diff --git a/Cargo.lock b/Cargo.lock index d3aa0b2953a0..a99803e1bc4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8744,6 +8744,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", + "tokio-util", "tower 0.4.13", "tower-http", "tracing", diff --git a/crates/rpc/rpc-builder/Cargo.toml b/crates/rpc/rpc-builder/Cargo.toml index cc72c2ebf92e..b9b511a078bc 100644 --- a/crates/rpc/rpc-builder/Cargo.toml +++ b/crates/rpc/rpc-builder/Cargo.toml @@ -50,6 +50,8 @@ metrics.workspace = true serde = { workspace = true, features = ["derive"] } thiserror.workspace = true tracing.workspace = true +tokio-util = { workspace = true } +tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } [dev-dependencies] reth-chainspec.workspace = true diff --git a/crates/rpc/rpc-builder/src/lib.rs b/crates/rpc/rpc-builder/src/lib.rs index 72b53efe674c..ceafe206531a 100644 --- a/crates/rpc/rpc-builder/src/lib.rs +++ b/crates/rpc/rpc-builder/src/lib.rs @@ -226,6 +226,9 @@ pub use eth::EthHandlers; mod metrics; pub use metrics::{MeteredRequestFuture, RpcRequestMetricsService}; +// Rpc rate limiter +pub mod rate_limiter; + /// Convenience function for starting a server in one step. #[allow(clippy::too_many_arguments)] pub async fn launch( diff --git a/crates/rpc/rpc-builder/src/rate_limiter.rs b/crates/rpc/rpc-builder/src/rate_limiter.rs new file mode 100644 index 000000000000..85df0eee61c6 --- /dev/null +++ b/crates/rpc/rpc-builder/src/rate_limiter.rs @@ -0,0 +1,116 @@ +//! [`jsonrpsee`] helper layer for rate limiting certain methods. + +use jsonrpsee::{server::middleware::rpc::RpcServiceT, types::Request, MethodResponse}; +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{ready, Context, Poll}, +}; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio_util::sync::PollSemaphore; +use tower::Layer; + +/// Rate limiter for the RPC server. +/// +/// Rate limits expensive calls such as debug_ and trace_. +#[derive(Debug, Clone)] +pub struct RpcRequestRateLimiter { + inner: Arc, +} + +impl RpcRequestRateLimiter { + /// Create a new rate limit layer with the given number of permits. + pub fn new(rate_limit: usize) -> Self { + Self { + inner: Arc::new(RpcRequestRateLimiterInner { + call_guard: PollSemaphore::new(Arc::new(Semaphore::new(rate_limit))), + }), + } + } +} + +impl Layer for RpcRequestRateLimiter { + type Service = RpcRequestRateLimitingService; + + fn layer(&self, inner: S) -> Self::Service { + RpcRequestRateLimitingService::new(inner, self.clone()) + } +} + +/// Rate Limiter for the RPC server +#[derive(Debug, Clone)] +struct RpcRequestRateLimiterInner { + /// Semaphore to rate limit calls + call_guard: PollSemaphore, +} + +/// A [`RpcServiceT`] middleware that rate limits RPC calls to the server. +#[derive(Debug, Clone)] +pub struct RpcRequestRateLimitingService { + /// The rate limiter for RPC requests + rate_limiter: RpcRequestRateLimiter, + /// The inner service being wrapped + inner: S, +} + +impl RpcRequestRateLimitingService { + /// Create a new rate limited service. + pub const fn new(service: S, rate_limiter: RpcRequestRateLimiter) -> Self { + Self { inner: service, rate_limiter } + } +} + +impl<'a, S> RpcServiceT<'a> for RpcRequestRateLimitingService +where + S: RpcServiceT<'a> + Send + Sync + Clone + 'static, +{ + type Future = RateLimitingRequestFuture; + + fn call(&self, req: Request<'a>) -> Self::Future { + let method_name = req.method_name(); + if method_name.starts_with("trace_") || method_name.starts_with("debug_") { + RateLimitingRequestFuture { + fut: self.inner.call(req), + guard: Some(self.rate_limiter.inner.call_guard.clone()), + permit: None, + } + } else { + // if we don't need to rate limit, then there + // is no need to get a semaphore permit + RateLimitingRequestFuture { fut: self.inner.call(req), guard: None, permit: None } + } + } +} + +/// Response future. +#[pin_project::pin_project] +pub struct RateLimitingRequestFuture { + #[pin] + fut: F, + guard: Option, + permit: Option, +} + +impl std::fmt::Debug for RateLimitingRequestFuture { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("RateLimitingRequestFuture") + } +} + +impl> Future for RateLimitingRequestFuture { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + if let Some(guard) = this.guard.as_mut() { + *this.permit = ready!(guard.poll_acquire(cx)); + *this.guard = None; + } + let res = this.fut.poll(cx); + if res.is_ready() { + *this.permit = None; + } + res + } +}