-
I've been scratching my head trying to figure out how to return a pub async fn handle_rate_limiting<S, B>(
req: ServiceRequest,
limiter: Arc<Mutex<Limiter>>,
service: Arc<S>,
) -> Result<ServiceResponse<B>, Error>
where
S: Service<ServiceRequest, Response=ServiceResponse<B>, Error=Error>,
S::Future: 'static,
B: 'static + MessageBody,
{
let ip = match req.peer_addr().map(|addr| addr.ip().to_string()) {
Some(ip) => ip,
None => return Ok(ServiceResponse::new(req.request().clone(), HttpResponse::new(StatusCode::BAD_REQUEST))) // expected `HttpResponse<B>`, found `HttpResponse`
};
println!("Request from IP: {}", ip);
let now = Utc::now();
{
let mut limiter = limiter.lock().unwrap();
let (last_request_time, request_count) = limiter.ip_addresses.entry(ip.clone())
.or_insert((now, 0));
println!("IP: {} - Last Request Time: {}, Request Count: {}", ip, last_request_time, request_count);
if now - *last_request_time <= Duration::seconds(20) {
if *request_count >= 2 {
println!("IP: {} - Too Many Requests", ip);
return Ok(ServiceResponse::new(req.request().clone(), HttpResponse::new(StatusCode::TOO_MANY_REQUESTS))); // expected `HttpResponse<B>`, found `HttpResponse`
} else {
*request_count += 1;
println!("IP: {} - Incremented Request Count: {}", ip, request_count);
}
} else {
// Reset time and count after 20 seconds
*last_request_time = now;
*request_count = 1;
println!("IP: {} - Reset Request Count and Time", ip);
}
}
let res = service.call(req).await?;
Ok(res)
} I have tried using Can anyone please explain what I am missing here? Note that the answer to this similar question does not have a solution that works here. Using My full middleware definition is as follows: use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use chrono::{DateTime, Duration, Utc};
use actix_service::{Service, Transform};
use actix_web::dev::{ServiceRequest, ServiceResponse};
use actix_web::{Error, HttpResponse};
use futures::future::{ok, Ready};
use std::task::{Context, Poll};
use actix_web::body::{BoxBody, EitherBody, MessageBody};
use actix_web::http::StatusCode;
pub struct Limiter {
pub ip_addresses: HashMap<String, (DateTime<Utc>, usize)>,
}
pub struct RateLimiter {
pub(crate) limiter: Arc<Mutex<Limiter>>,
}
impl RateLimiter {
pub fn new(limiter: Arc<Mutex<Limiter>>) -> Self {
Self { limiter }
}
}
pub struct RateLimiterMiddleware<S> {
pub(crate) service: Arc<S>,
pub(crate) limiter: Arc<Mutex<Limiter>>,
}
impl<S, B> Transform<S, ServiceRequest> for RateLimiter
where
S: Service<ServiceRequest, Response=ServiceResponse<EitherBody<B, BoxBody>>, Error=Error> + 'static,
S::Future: 'static,
B: 'static + MessageBody,
{
type Response=ServiceResponse<EitherBody<B, BoxBody>>;
type Error = Error;
type Transform = RateLimiterMiddleware<S>;
type InitError = ();
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(RateLimiterMiddleware {
service: Arc::new(service),
limiter: self.limiter.clone(),
})
}
}
impl<S, B> Service<ServiceRequest> for RateLimiterMiddleware<S>
where
S: Service<ServiceRequest, Response=ServiceResponse<EitherBody<B, BoxBody>>, Error=Error> + 'static,
S::Future: 'static,
B: 'static + MessageBody,
{
type Response=ServiceResponse<EitherBody<B, BoxBody>>;
type Error = Error;
type Future = Pin<Box<dyn Future<Output=Result<Self::Response, Self::Error>>>>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&self, req: ServiceRequest) -> Self::Future {
let limiter = Arc::clone(&self.limiter);
let service = Arc::clone(&self.service);
Box::pin(handle_rate_limiting(req, limiter, service))
}
}
pub async fn handle_rate_limiting<S, B>(
req: ServiceRequest,
limiter: Arc<Mutex<Limiter>>,
service: Arc<S>,
) -> Result<ServiceResponse<EitherBody<B, BoxBody>>, Error>
where
S: Service<ServiceRequest, Response=ServiceResponse<EitherBody<B, BoxBody>>, Error=Error>,
S::Future: 'static,
B: 'static,
{
let ip = match req.peer_addr().map(|addr| addr.ip().to_string()) {
Some(ip) => ip,
None => return Ok(ServiceResponse::new(req.request().clone(), HttpResponse::new(StatusCode::BAD_REQUEST).map_into_boxed_body().map_into_right_body())) // expected `HttpResponse<EitherBody<B>>`, found `HttpResponse`
};
println!("Request from IP: {}", ip);
let now = Utc::now();
{
let mut limiter = limiter.lock().unwrap();
let (last_request_time, request_count) = limiter.ip_addresses.entry(ip.clone())
.or_insert((now, 0));
println!("IP: {} - Last Request Time: {}, Request Count: {}", ip, last_request_time, request_count);
if now - *last_request_time <= Duration::seconds(20) {
if *request_count >= 2 {
println!("IP: {} - Too Many Requests", ip);
return Ok(ServiceResponse::new(req.request().clone(), HttpResponse::new(StatusCode::TOO_MANY_REQUESTS).map_into_boxed_body().map_into_right_body()));
} else {
*request_count += 1;
println!("IP: {} - Incremented Request Count: {}", ip, request_count);
}
} else {
// Reset time and count after 20 seconds
*last_request_time = now;
*request_count = 1;
println!("IP: {} - Reset Request Count and Time", ip);
}
}
let res = service.call(req).await?;
Ok(res)
} I am trying to wrap my endpoints as shown below: let limiter = Arc::new(Mutex::new(Limiter {
ip_addresses: HashMap::new(),
}));
HttpServer::new(move || {
App::new()
.app_data(web::Data::new(app_state.clone()))
.wrap(Logger::default())
.wrap_fn(move |req, srv| {
let srv = Arc::new(srv);
let limiter = Arc::clone(&limiter);
async move {
handle_rate_limiting(req, limiter, srv).await
}
}) |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 4 replies
-
Kinda duplicate of #3428 You can't coerce something to a You're on the right track using ) -> Result<ServiceResponse<EitherBody<B>>, Error> In essensce, you're wrapping a |
Beta Was this translation helpful? Give feedback.
Kinda duplicate of #3428
You can't coerce something to a
B
here any more than other generics in Rust.You're on the right track using
EitherBody
and the helper methods. The missing piece is in the return type:In essensce, you're wrapping a
Svc(ResponseBody = B)
into aSvc(ResponseBody = EitherBody<B, BoxBody>)
, whereB
is still only from theservice.call(req).await
result, andBoxBody
is from your other responses.