diff --git a/README.md b/README.md index fb13c4b5c..24f90e2ce 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ async fn main() -> tide::Result<()> { Ok(()) } -async fn order_shoes(mut req: Request<()>) -> tide::Result { +async fn order_shoes(mut req: Request) -> tide::Result { let Animal { name, legs } = req.body_json().await?; Ok(format!("Hello, {}! I've put in an order for {} shoes", name, legs).into()) } diff --git a/examples/concurrent_listeners.rs b/examples/concurrent_listeners.rs index b79e6fb53..af9884203 100644 --- a/examples/concurrent_listeners.rs +++ b/examples/concurrent_listeners.rs @@ -6,7 +6,7 @@ async fn main() -> Result<(), std::io::Error> { let mut app = tide::new(); app.with(tide::log::LogMiddleware::new()); - app.at("/").get(|request: Request<_>| async move { + app.at("/").get(|request: Request| async move { Ok(format!( "Hi! You reached this app through: {}", request.local_addr().unwrap_or("an unknown port") diff --git a/examples/cookies.rs b/examples/cookies.rs index 966bd4479..2f86b21df 100644 --- a/examples/cookies.rs +++ b/examples/cookies.rs @@ -3,17 +3,17 @@ use tide::{Request, Response, StatusCode}; /// Tide will use the the `Cookies`'s `Extract` implementation to build this parameter. /// -async fn retrieve_cookie(req: Request<()>) -> tide::Result { +async fn retrieve_cookie(req: Request) -> tide::Result { Ok(format!("hello cookies: {:?}", req.cookie("hello").unwrap())) } -async fn insert_cookie(_req: Request<()>) -> tide::Result { +async fn insert_cookie(_req: Request) -> tide::Result { let mut res = Response::new(StatusCode::Ok); res.insert_cookie(Cookie::new("hello", "world")); Ok(res) } -async fn remove_cookie(_req: Request<()>) -> tide::Result { +async fn remove_cookie(_req: Request) -> tide::Result { let mut res = Response::new(StatusCode::Ok); res.remove_cookie(Cookie::named("hello")); Ok(res) diff --git a/examples/error_handling.rs b/examples/error_handling.rs index ca35277d0..7073f5f15 100644 --- a/examples/error_handling.rs +++ b/examples/error_handling.rs @@ -23,7 +23,7 @@ async fn main() -> Result<()> { })); app.at("/") - .get(|_req: Request<_>| async { Ok(Body::from_file("./does-not-exist").await?) }); + .get(|_req: Request| async { Ok(Body::from_file("./does-not-exist").await?) }); app.listen("127.0.0.1:8080").await?; diff --git a/examples/fib.rs b/examples/fib.rs index 2b0451ecb..bde8bb7d8 100644 --- a/examples/fib.rs +++ b/examples/fib.rs @@ -8,7 +8,7 @@ fn fib(n: usize) -> usize { } } -async fn fibsum(req: Request<()>) -> tide::Result { +async fn fibsum(req: Request) -> tide::Result { use std::time::Instant; let n: usize = req.param("n")?.parse().unwrap_or(0); // Start a stopwatch diff --git a/examples/graphql.rs b/examples/graphql.rs index c771afb49..6e4a020b6 100644 --- a/examples/graphql.rs +++ b/examples/graphql.rs @@ -2,7 +2,7 @@ use std::sync::{Arc, RwLock}; use juniper::{http::graphiql, http::GraphQLRequest, RootNode}; use lazy_static::lazy_static; -use tide::{http::mime, Body, Redirect, Request, Response, Server, StatusCode}; +use tide::{http::mime, Body, Redirect, Request, Response, StatusCode}; #[derive(Clone)] struct User { @@ -74,9 +74,9 @@ lazy_static! { static ref SCHEMA: Schema = Schema::new(QueryRoot {}, MutationRoot {}); } -async fn handle_graphql(mut request: Request) -> tide::Result { +async fn handle_graphql(mut request: Request) -> tide::Result { let query: GraphQLRequest = request.body_json().await?; - let response = query.execute(&SCHEMA, request.state()); + let response = query.execute(&SCHEMA, request.state::()); let status = if response.is_ok() { StatusCode::Ok } else { @@ -88,7 +88,7 @@ async fn handle_graphql(mut request: Request) -> tide::Result { .build()) } -async fn handle_graphiql(_: Request) -> tide::Result> { +async fn handle_graphiql(_: Request) -> tide::Result> { Ok(Response::builder(200) .body(graphiql::graphiql_source("/graphql")) .content_type(mime::HTML)) @@ -96,7 +96,7 @@ async fn handle_graphiql(_: Request) -> tide::Result> #[async_std::main] async fn main() -> std::io::Result<()> { - let mut app = Server::with_state(State { + let mut app = tide::with_state(State { users: Arc::new(RwLock::new(Vec::new())), }); app.at("/").get(Redirect::permanent("/graphiql")); diff --git a/examples/json.rs b/examples/json.rs index 085a359d7..49d2d45a7 100644 --- a/examples/json.rs +++ b/examples/json.rs @@ -13,7 +13,7 @@ async fn main() -> tide::Result<()> { let mut app = tide::new(); app.with(tide::log::LogMiddleware::new()); - app.at("/submit").post(|mut req: Request<()>| async move { + app.at("/submit").post(|mut req: Request| async move { let cat: Cat = req.body_json().await?; println!("cat name: {}", cat.name); diff --git a/examples/middleware.rs b/examples/middleware.rs index 39709149d..5dacc2067 100644 --- a/examples/middleware.rs +++ b/examples/middleware.rs @@ -1,5 +1,3 @@ -use std::future::Future; -use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -26,22 +24,17 @@ impl UserDatabase { // This is an example of a function middleware that uses the // application state. Because it depends on a specific request state, // it would likely be closely tied to a specific application -fn user_loader<'a>( - mut request: Request, - next: Next<'a, UserDatabase>, -) -> Pin + Send + 'a>> { - Box::pin(async { - if let Some(user) = request.state().find_user().await { - trace!("user loaded", {user: user.name}); - request.set_ext(user); - Ok(next.run(request).await) - // this middleware only needs to run before the endpoint, so - // it just passes through the result of Next - } else { - // do not run endpoints, we could not find a user - Ok(Response::new(StatusCode::Unauthorized)) - } - }) +async fn user_loader(mut request: Request, next: Next) -> Result { + if let Some(user) = request.state::().find_user().await { + trace!("user loaded", {user: user.name}); + request.set_ext(user); + Ok(next.run(request).await) + // this middleware only needs to run before the endpoint, so + // it just passes through the result of Next + } else { + // do not run endpoints, we could not find a user + Ok(Response::new(StatusCode::Unauthorized)) + } } // This is an example of middleware that keeps its own state and could @@ -62,8 +55,8 @@ impl RequestCounterMiddleware { struct RequestCount(usize); #[tide::utils::async_trait] -impl Middleware for RequestCounterMiddleware { - async fn handle(&self, mut req: Request, next: Next<'_, State>) -> Result { +impl Middleware for RequestCounterMiddleware { + async fn handle(&self, mut req: Request, next: Next) -> Result { let count = self.requests_counted.fetch_add(1, Ordering::Relaxed); trace!("request counter", { count: count }); req.set_ext(RequestCount(count)); @@ -115,12 +108,12 @@ async fn main() -> Result<()> { app.with(user_loader); app.with(RequestCounterMiddleware::new(0)); - app.with(Before(|mut request: Request| async move { + app.with(Before(|mut request: Request| async move { request.set_ext(std::time::Instant::now()); request })); - app.at("/").get(|req: Request<_>| async move { + app.at("/").get(|req: Request| async move { let count: &RequestCount = req.ext().unwrap(); let user: &User = req.ext().unwrap(); diff --git a/examples/sessions.rs b/examples/sessions.rs index f6e9e966f..5e7cc6497 100644 --- a/examples/sessions.rs +++ b/examples/sessions.rs @@ -15,7 +15,7 @@ async fn main() -> Result<(), std::io::Error> { )); app.with(tide::utils::Before( - |mut request: tide::Request<()>| async move { + |mut request: tide::Request| async move { let session = request.session_mut(); let visits: usize = session.get("visits").unwrap_or_default(); session.insert("visits", visits + 1).unwrap(); @@ -23,16 +23,15 @@ async fn main() -> Result<(), std::io::Error> { }, )); - app.at("/").get(|req: tide::Request<()>| async move { + app.at("/").get(|req: tide::Request| async move { let visits: usize = req.session().get("visits").unwrap(); Ok(format!("you have visited this website {} times", visits)) }); - app.at("/reset") - .get(|mut req: tide::Request<()>| async move { - req.session_mut().destroy(); - Ok(tide::Redirect::new("/")) - }); + app.at("/reset").get(|mut req: tide::Request| async move { + req.session_mut().destroy(); + Ok(tide::Redirect::new("/")) + }); app.listen("127.0.0.1:8080").await?; diff --git a/examples/state.rs b/examples/state.rs index 5679c8a62..e2f98389a 100644 --- a/examples/state.rs +++ b/examples/state.rs @@ -19,13 +19,13 @@ async fn main() -> tide::Result<()> { femme::start(); let mut app = tide::with_state(State::new()); app.with(tide::log::LogMiddleware::new()); - app.at("/").get(|req: tide::Request| async move { - let state = req.state(); + app.at("/").get(|req: tide::Request| async move { + let state = req.state::(); let value = state.value.load(Ordering::Relaxed); Ok(format!("{}\n", value)) }); - app.at("/inc").get(|req: tide::Request| async move { - let state = req.state(); + app.at("/inc").get(|req: tide::Request| async move { + let state = req.state::(); let value = state.value.fetch_add(1, Ordering::Relaxed) + 1; Ok(format!("{}\n", value)) }); diff --git a/examples/upload.rs b/examples/upload.rs index 4bf2da511..fcf90cdee 100644 --- a/examples/upload.rs +++ b/examples/upload.rs @@ -37,9 +37,10 @@ async fn main() -> Result<(), IoError> { // $ curl localhost:8080/README.md # this reads the file from the same temp directory app.at(":file") - .put(|req: Request| async move { + .put(|req: Request| async move { let path = req.param("file")?; - let fs_path = req.state().path().join(path); + let state = req.state::(); + let fs_path = state.path().join(path); let file = OpenOptions::new() .create(true) @@ -56,9 +57,9 @@ async fn main() -> Result<(), IoError> { Ok(json!({ "bytes": bytes_written })) }) - .get(|req: Request| async move { + .get(|req: Request| async move { let path = req.param("file")?; - let fs_path = req.state().path().join(path); + let fs_path = req.state::().path().join(path); if let Ok(body) = Body::from_file(fs_path).await { Ok(body.into()) diff --git a/src/cookies/middleware.rs b/src/cookies/middleware.rs index 145d24173..12d0da85a 100644 --- a/src/cookies/middleware.rs +++ b/src/cookies/middleware.rs @@ -15,7 +15,7 @@ use std::sync::{Arc, RwLock}; /// # use tide::{Request, Response, StatusCode}; /// # use tide::http::cookies::Cookie; /// let mut app = tide::Server::new(); -/// app.at("/get").get(|req: Request<()>| async move { +/// app.at("/get").get(|req: Request| async move { /// Ok(req.cookie("testCookie").unwrap().value().to_string()) /// }); /// app.at("/set").get(|_| async { @@ -35,8 +35,8 @@ impl CookiesMiddleware { } #[async_trait] -impl Middleware for CookiesMiddleware { - async fn handle(&self, mut ctx: Request, next: Next<'_, State>) -> crate::Result { +impl Middleware for CookiesMiddleware { + async fn handle(&self, mut ctx: Request, next: Next) -> crate::Result { let cookie_jar = if let Some(cookie_data) = ctx.ext::() { cookie_data.content.clone() } else { @@ -112,7 +112,7 @@ impl LazyJar { } impl CookieData { - pub(crate) fn from_request(req: &Request) -> Self { + pub(crate) fn from_request(req: &Request) -> Self { let jar = if let Some(cookie_headers) = req.header(&headers::COOKIE) { let mut jar = CookieJar::new(); for cookie_header in cookie_headers { diff --git a/src/endpoint.rs b/src/endpoint.rs index 2857c291c..fffac3e98 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -1,17 +1,15 @@ use async_std::future::Future; -use async_std::sync::Arc; use async_trait::async_trait; use http_types::Result; -use crate::middleware::Next; -use crate::{Middleware, Request, Response}; +use crate::{Request, Response}; /// An HTTP request handler. /// /// This trait is automatically implemented for `Fn` types, and so is rarely implemented /// directly by Tide users. /// -/// In practice, endpoints are functions that take a `Request` as an argument and +/// In practice, endpoints are functions that take a `Request` as an argument and /// return a type `T` that implements `Into`. /// /// # Examples @@ -23,7 +21,7 @@ use crate::{Middleware, Request, Response}; /// A simple endpoint that is invoked on a `GET` request and returns a `String`: /// /// ```no_run -/// async fn hello(_req: tide::Request<()>) -> tide::Result { +/// async fn hello(_req: tide::Request) -> tide::Result { /// Ok(String::from("hello")) /// } /// @@ -35,7 +33,7 @@ use crate::{Middleware, Request, Response}; /// /// ```no_run /// # use core::future::Future; -/// fn hello(_req: tide::Request<()>) -> impl Future> { +/// fn hello(_req: tide::Request) -> impl Future> { /// async_std::future::ready(Ok(String::from("hello"))) /// } /// @@ -45,90 +43,28 @@ use crate::{Middleware, Request, Response}; /// /// Tide routes will also accept endpoints with `Fn` signatures of this form, but using the `async` keyword has better ergonomics. #[async_trait] -pub trait Endpoint: Send + Sync + 'static { +pub trait Endpoint: Send + Sync + 'static { /// Invoke the endpoint within the given context - async fn call(&self, req: Request) -> crate::Result; + async fn call(&self, req: Request) -> crate::Result; } -pub(crate) type DynEndpoint = dyn Endpoint; - #[async_trait] -impl Endpoint for F +impl Endpoint for F where - State: Clone + Send + Sync + 'static, - F: Send + Sync + 'static + Fn(Request) -> Fut, + F: Send + Sync + 'static + Fn(Request) -> Fut, Fut: Future> + Send + 'static, Res: Into + 'static, { - async fn call(&self, req: Request) -> crate::Result { + async fn call(&self, req: Request) -> crate::Result { let fut = (self)(req); let res = fut.await?; Ok(res.into()) } } -pub(crate) struct MiddlewareEndpoint { - endpoint: E, - middleware: Vec>>, -} - -impl Clone for MiddlewareEndpoint { - fn clone(&self) -> Self { - Self { - endpoint: self.endpoint.clone(), - middleware: self.middleware.clone(), - } - } -} - -impl std::fmt::Debug for MiddlewareEndpoint { - fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - fmt, - "MiddlewareEndpoint (length: {})", - self.middleware.len(), - ) - } -} - -impl MiddlewareEndpoint -where - State: Clone + Send + Sync + 'static, - E: Endpoint, -{ - pub(crate) fn wrap_with_middleware( - ep: E, - middleware: &[Arc>], - ) -> Box + Send + Sync + 'static> { - if middleware.is_empty() { - Box::new(ep) - } else { - Box::new(Self { - endpoint: ep, - middleware: middleware.to_vec(), - }) - } - } -} - -#[async_trait] -impl Endpoint for MiddlewareEndpoint -where - State: Clone + Send + Sync + 'static, - E: Endpoint, -{ - async fn call(&self, req: Request) -> crate::Result { - let next = Next { - endpoint: &self.endpoint, - next_middleware: &self.middleware, - }; - Ok(next.run(req).await) - } -} - #[async_trait] -impl Endpoint for Box> { - async fn call(&self, request: Request) -> crate::Result { +impl Endpoint for Box { + async fn call(&self, request: Request) -> crate::Result { self.as_ref().call(request).await } } diff --git a/src/fs/serve_dir.rs b/src/fs/serve_dir.rs index 431dfcd90..f2a9b9b12 100644 --- a/src/fs/serve_dir.rs +++ b/src/fs/serve_dir.rs @@ -19,11 +19,8 @@ impl ServeDir { } #[async_trait::async_trait] -impl Endpoint for ServeDir -where - State: Clone + Send + Sync + 'static, -{ - async fn call(&self, req: Request) -> Result { +impl Endpoint for ServeDir { + async fn call(&self, req: Request) -> Result { let path = req.url().path(); let path = path .strip_prefix(&self.prefix.trim_end_matches('*')) @@ -80,11 +77,11 @@ mod test { }) } - fn request(path: &str) -> crate::Request<()> { + fn request(path: &str) -> crate::Request { let request = crate::http::Request::get( crate::http::Url::parse(&format!("http://localhost/{}", path)).unwrap(), ); - crate::Request::new((), request, vec![]) + crate::Request::new(request, vec![]) } #[async_std::test] diff --git a/src/fs/serve_file.rs b/src/fs/serve_file.rs index 2ed80e41a..754da6267 100644 --- a/src/fs/serve_file.rs +++ b/src/fs/serve_file.rs @@ -21,8 +21,8 @@ impl ServeFile { } #[async_trait] -impl Endpoint for ServeFile { - async fn call(&self, _: Request) -> Result { +impl Endpoint for ServeFile { + async fn call(&self, _: Request) -> Result { match Body::from_file(&self.path).await { Ok(body) => Ok(Response::builder(StatusCode::Ok).body(body).build()), Err(e) if e.kind() == io::ErrorKind::NotFound => { @@ -53,10 +53,10 @@ mod test { Ok(ServeFile::init(file_path)?) } - fn request(path: &str) -> crate::Request<()> { + fn request(path: &str) -> crate::Request { let request = crate::http::Request::get(Url::parse(&format!("http://localhost/{}", path)).unwrap()); - crate::Request::new((), request, vec![]) + crate::Request::new(request, vec![]) } #[async_std::test] diff --git a/src/lib.rs b/src/lib.rs index b9e739d8f..cdd6a8a99 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,7 +38,7 @@ //! Ok(()) //! } //! -//! async fn order_shoes(mut req: Request<()>) -> tide::Result { +//! async fn order_shoes(mut req: Request) -> tide::Result { //! let Animal { name, legs } = req.body_json().await?; //! Ok(format!("Hello, {}! I've put in an order for {} shoes", name, legs).into()) //! } @@ -75,6 +75,7 @@ mod response_builder; mod route; mod router; mod server; +mod state; pub mod convert; pub mod listener; @@ -96,6 +97,7 @@ pub use response::Response; pub use response_builder::ResponseBuilder; pub use route::Route; pub use server::Server; +pub use state::StateMiddleware; pub use http_types::{self as http, Body, Error, Status, StatusCode}; @@ -114,7 +116,7 @@ pub use http_types::{self as http, Body, Error, Status, StatusCode}; /// # Ok(()) }) } /// ``` #[must_use] -pub fn new() -> server::Server<()> { +pub fn new() -> server::Server { Server::new() } @@ -128,7 +130,7 @@ pub fn new() -> server::Server<()> { /// # use async_std::task::block_on; /// # fn main() -> Result<(), std::io::Error> { block_on(async { /// # -/// use tide::Request; +/// use tide::{Request}; /// /// /// The shared application state. /// #[derive(Clone)] @@ -143,18 +145,20 @@ pub fn new() -> server::Server<()> { /// /// // Initialize the application with state. /// let mut app = tide::with_state(state); -/// app.at("/").get(|req: Request| async move { -/// Ok(format!("Hello, {}!", &req.state().name)) +/// app.at("/").get(|req: Request| async move { +/// Ok(format!("Hello, {}!", &req.state::().name)) /// }); /// app.listen("127.0.0.1:8080").await?; /// # /// # Ok(()) }) } /// ``` -pub fn with_state(state: State) -> server::Server +pub fn with_state(state: State) -> server::Server where State: Clone + Send + Sync + 'static, { - Server::with_state(state) + let mut server = Server::new(); + server.with_state(state); + server } /// A specialized Result type for Tide. diff --git a/src/listener/concurrent_listener.rs b/src/listener/concurrent_listener.rs index 906f1f4c5..be3c7a92f 100644 --- a/src/listener/concurrent_listener.rs +++ b/src/listener/concurrent_listener.rs @@ -32,11 +32,11 @@ use futures_util::stream::{futures_unordered::FuturesUnordered, StreamExt}; ///``` #[derive(Default)] -pub struct ConcurrentListener { - listeners: Vec>>, +pub struct ConcurrentListener { + listeners: Vec>, } -impl ConcurrentListener { +impl ConcurrentListener { /// creates a new ConcurrentListener pub fn new() -> Self { Self { listeners: vec![] } @@ -58,7 +58,7 @@ impl ConcurrentListener { /// ``` pub fn add(&mut self, listener: L) -> io::Result<()> where - L: ToListener, + L: ToListener, { self.listeners.push(Box::new(listener.to_listener()?)); Ok(()) @@ -77,7 +77,7 @@ impl ConcurrentListener { /// # Ok(()) }) } pub fn with_listener(mut self, listener: L) -> Self where - L: ToListener, + L: ToListener, { self.add(listener).expect("Unable to add listener"); self @@ -85,11 +85,8 @@ impl ConcurrentListener { } #[async_trait::async_trait] -impl Listener for ConcurrentListener -where - State: Clone + Send + Sync + 'static, -{ - async fn bind(&mut self, app: Server) -> io::Result<()> { +impl Listener for ConcurrentListener { + async fn bind(&mut self, app: Server) -> io::Result<()> { for listener in self.listeners.iter_mut() { listener.bind(app.clone()).await?; } @@ -117,13 +114,13 @@ where } } -impl Debug for ConcurrentListener { +impl Debug for ConcurrentListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{:?}", self.listeners) } } -impl Display for ConcurrentListener { +impl Display for ConcurrentListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let string = self .listeners diff --git a/src/listener/failover_listener.rs b/src/listener/failover_listener.rs index 26f0ee196..8b418c25c 100644 --- a/src/listener/failover_listener.rs +++ b/src/listener/failover_listener.rs @@ -34,15 +34,12 @@ use crate::listener::ListenInfo; ///} ///``` #[derive(Default)] -pub struct FailoverListener { - listeners: Vec>>>, +pub struct FailoverListener { + listeners: Vec>>, index: Option, } -impl FailoverListener -where - State: Clone + Send + Sync + 'static, -{ +impl FailoverListener { /// creates a new FailoverListener pub fn new() -> Self { Self { @@ -69,7 +66,7 @@ where /// ``` pub fn add(&mut self, listener: L) -> io::Result<()> where - L: ToListener, + L: ToListener, { self.listeners.push(Some(Box::new(listener.to_listener()?))); Ok(()) @@ -88,7 +85,7 @@ where /// # Ok(()) }) } pub fn with_listener(mut self, listener: L) -> Self where - L: ToListener, + L: ToListener, { self.add(listener).expect("Unable to add listener"); self @@ -96,11 +93,8 @@ where } #[async_trait::async_trait] -impl Listener for FailoverListener -where - State: Clone + Send + Sync + 'static, -{ - async fn bind(&mut self, app: Server) -> io::Result<()> { +impl Listener for FailoverListener { + async fn bind(&mut self, app: Server) -> io::Result<()> { for (index, listener) in self.listeners.iter_mut().enumerate() { let listener = listener.as_deref_mut().expect("bind called twice"); match listener.bind(app.clone()).await { @@ -148,13 +142,13 @@ where } } -impl Debug for FailoverListener { +impl Debug for FailoverListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{:?}", self.listeners) } } -impl Display for FailoverListener { +impl Display for FailoverListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let string = self .listeners diff --git a/src/listener/mod.rs b/src/listener/mod.rs index 2d469872d..2adb965f2 100644 --- a/src/listener/mod.rs +++ b/src/listener/mod.rs @@ -35,14 +35,11 @@ pub(crate) use unix_listener::UnixListener; /// implement at least one [`ToListener`](crate::listener::ToListener) that /// outputs your Listener type. #[async_trait] -pub trait Listener: Debug + Display + Send + Sync + 'static -where - State: Send + Sync + 'static, -{ +pub trait Listener: Debug + Display + Send + Sync + 'static { /// Bind the listener. This starts the listening process by opening the /// necessary network ports, but not yet accepting incoming connections. This /// method must be called before `accept`. - async fn bind(&mut self, app: Server) -> io::Result<()>; + async fn bind(&mut self, app: Server) -> io::Result<()>; /// Start accepting incoming connections. This method must be called only /// after `bind` has succeeded. @@ -54,12 +51,8 @@ where } #[async_trait] -impl Listener for Box -where - L: Listener, - State: Send + Sync + 'static, -{ - async fn bind(&mut self, app: Server) -> io::Result<()> { +impl Listener for Box { + async fn bind(&mut self, app: Server) -> io::Result<()> { self.as_mut().bind(app).await } diff --git a/src/listener/parsed_listener.rs b/src/listener/parsed_listener.rs index ad2926a10..82143eaca 100644 --- a/src/listener/parsed_listener.rs +++ b/src/listener/parsed_listener.rs @@ -13,13 +13,13 @@ use std::fmt::{self, Debug, Display, Formatter}; /// /// This is currently crate-visible only, and tide users are expected /// to create these through [ToListener](crate::ToListener) conversions. -pub enum ParsedListener { +pub enum ParsedListener { #[cfg(unix)] - Unix(UnixListener), - Tcp(TcpListener), + Unix(UnixListener), + Tcp(TcpListener), } -impl Debug for ParsedListener { +impl Debug for ParsedListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { #[cfg(unix)] @@ -29,7 +29,7 @@ impl Debug for ParsedListener { } } -impl Display for ParsedListener { +impl Display for ParsedListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { #[cfg(unix)] @@ -40,11 +40,8 @@ impl Display for ParsedListener { } #[async_trait::async_trait] -impl Listener for ParsedListener -where - State: Clone + Send + Sync + 'static, -{ - async fn bind(&mut self, server: Server) -> io::Result<()> { +impl Listener for ParsedListener { + async fn bind(&mut self, server: Server) -> io::Result<()> { match self { #[cfg(unix)] Self::Unix(u) => u.bind(server).await, diff --git a/src/listener/tcp_listener.rs b/src/listener/tcp_listener.rs index 03be652eb..51cfc104a 100644 --- a/src/listener/tcp_listener.rs +++ b/src/listener/tcp_listener.rs @@ -18,14 +18,14 @@ use kv_log_macro::error; /// /// This is currently crate-visible only, and tide users are expected /// to create these through [ToListener](crate::ToListener) conversions. -pub struct TcpListener { +pub struct TcpListener { addrs: Option>, listener: Option, - server: Option>, + server: Option, info: Option, } -impl TcpListener { +impl TcpListener { pub fn from_addrs(addrs: Vec) -> Self { Self { addrs: Some(addrs), @@ -45,7 +45,7 @@ impl TcpListener { } } -fn handle_tcp(app: Server, stream: TcpStream) { +fn handle_tcp(app: Server, stream: TcpStream) { task::spawn(async move { let local_addr = stream.local_addr().ok(); let peer_addr = stream.peer_addr().ok(); @@ -63,11 +63,8 @@ fn handle_tcp(app: Server, stream: } #[async_trait::async_trait] -impl Listener for TcpListener -where - State: Clone + Send + Sync + 'static, -{ - async fn bind(&mut self, server: Server) -> io::Result<()> { +impl Listener for TcpListener { + async fn bind(&mut self, server: Server) -> io::Result<()> { assert!(self.server.is_none(), "`bind` should only be called once"); self.server = Some(server); @@ -127,7 +124,7 @@ where } } -impl fmt::Debug for TcpListener { +impl fmt::Debug for TcpListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("TcpListener") .field("listener", &self.listener) @@ -144,7 +141,7 @@ impl fmt::Debug for TcpListener { } } -impl Display for TcpListener { +impl Display for TcpListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let http_fmt = |a| format!("http://{}", a); match &self.listener { diff --git a/src/listener/to_listener.rs b/src/listener/to_listener.rs index b25b8f36b..79b125c38 100644 --- a/src/listener/to_listener.rs +++ b/src/listener/to_listener.rs @@ -47,9 +47,9 @@ use async_std::io; /// ``` /// # Other implementations /// See below for additional provided implementations of ToListener. -pub trait ToListener { +pub trait ToListener { /// What listener are we converting into? - type Listener: Listener; + type Listener: Listener; /// Transform self into a /// [`Listener`](crate::listener::Listener). Unless self is diff --git a/src/listener/to_listener_impls.rs b/src/listener/to_listener_impls.rs index 1e92fef0d..bcf80a838 100644 --- a/src/listener/to_listener_impls.rs +++ b/src/listener/to_listener_impls.rs @@ -5,11 +5,8 @@ use crate::http::url::Url; use async_std::io; use std::net::ToSocketAddrs; -impl ToListener for Url -where - State: Clone + Send + Sync + 'static, -{ - type Listener = ParsedListener; +impl ToListener for Url { + type Listener = ParsedListener; fn to_listener(self) -> io::Result { match self.scheme() { @@ -51,31 +48,22 @@ where } } -impl ToListener for String -where - State: Clone + Send + Sync + 'static, -{ - type Listener = ParsedListener; +impl ToListener for String { + type Listener = ParsedListener; fn to_listener(self) -> io::Result { - ToListener::::to_listener(self.as_str()) + ToListener::to_listener(self.as_str()) } } -impl ToListener for &String -where - State: Clone + Send + Sync + 'static, -{ - type Listener = ParsedListener; +impl ToListener for &String { + type Listener = ParsedListener; fn to_listener(self) -> io::Result { - ToListener::::to_listener(self.as_str()) + ToListener::to_listener(self.as_str()) } } -impl ToListener for &str -where - State: Clone + Send + Sync + 'static, -{ - type Listener = ParsedListener; +impl ToListener for &str { + type Listener = ParsedListener; fn to_listener(self) -> io::Result { if let Ok(socket_addrs) = self.to_socket_addrs() { @@ -83,7 +71,7 @@ where socket_addrs.collect(), ))) } else if let Ok(url) = Url::parse(self) { - ToListener::::to_listener(url) + ToListener::to_listener(url) } else { Err(io::Error::new( io::ErrorKind::InvalidInput, @@ -94,72 +82,51 @@ where } #[cfg(unix)] -impl ToListener for async_std::path::PathBuf -where - State: Clone + Send + Sync + 'static, -{ - type Listener = UnixListener; +impl ToListener for async_std::path::PathBuf { + type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_path(self)) } } #[cfg(unix)] -impl ToListener for std::path::PathBuf -where - State: Clone + Send + Sync + 'static, -{ - type Listener = UnixListener; +impl ToListener for std::path::PathBuf { + type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_path(self)) } } -impl ToListener for async_std::net::TcpListener -where - State: Clone + Send + Sync + 'static, -{ - type Listener = TcpListener; +impl ToListener for async_std::net::TcpListener { + type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_listener(self)) } } -impl ToListener for std::net::TcpListener -where - State: Clone + Send + Sync + 'static, -{ - type Listener = TcpListener; +impl ToListener for std::net::TcpListener { + type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_listener(self)) } } -impl ToListener for (String, u16) -where - State: Clone + Send + Sync + 'static, -{ - type Listener = TcpListener; +impl ToListener for (String, u16) { + type Listener = TcpListener; fn to_listener(self) -> io::Result { - ToListener::::to_listener((self.0.as_str(), self.1)) + ToListener::to_listener((self.0.as_str(), self.1)) } } -impl ToListener for (&String, u16) -where - State: Clone + Send + Sync + 'static, -{ - type Listener = TcpListener; +impl ToListener for (&String, u16) { + type Listener = TcpListener; fn to_listener(self) -> io::Result { - ToListener::::to_listener((self.0.as_str(), self.1)) + ToListener::to_listener((self.0.as_str(), self.1)) } } -impl ToListener for (&str, u16) -where - State: Clone + Send + Sync + 'static, -{ - type Listener = TcpListener; +impl ToListener for (&str, u16) { + type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_addrs(self.to_socket_addrs()?.collect())) @@ -167,31 +134,22 @@ where } #[cfg(unix)] -impl ToListener for async_std::os::unix::net::UnixListener -where - State: Clone + Send + Sync + 'static, -{ - type Listener = UnixListener; +impl ToListener for async_std::os::unix::net::UnixListener { + type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_listener(self)) } } #[cfg(unix)] -impl ToListener for std::os::unix::net::UnixListener -where - State: Clone + Send + Sync + 'static, -{ - type Listener = UnixListener; +impl ToListener for std::os::unix::net::UnixListener { + type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_listener(self)) } } -impl ToListener for TcpListener -where - State: Clone + Send + Sync + 'static, -{ +impl ToListener for TcpListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) @@ -199,62 +157,46 @@ where } #[cfg(unix)] -impl ToListener for UnixListener -where - State: Clone + Send + Sync + 'static, -{ +impl ToListener for UnixListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for ConcurrentListener -where - State: Clone + Send + Sync + 'static, -{ +impl ToListener for ConcurrentListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for ParsedListener -where - State: Clone + Send + Sync + 'static, -{ +impl ToListener for ParsedListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for FailoverListener -where - State: Clone + Send + Sync + 'static, -{ +impl ToListener for FailoverListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for std::net::SocketAddr -where - State: Clone + Send + Sync + 'static, -{ - type Listener = TcpListener; +impl ToListener for std::net::SocketAddr { + type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_addrs(vec![self])) } } -impl ToListener for Vec +impl ToListener for Vec where - L: ToListener, - State: Clone + Send + Sync + 'static, + L: ToListener, { - type Listener = ConcurrentListener; + type Listener = ConcurrentListener; fn to_listener(self) -> io::Result { let mut concurrent_listener = ConcurrentListener::new(); for listener in self { @@ -268,7 +210,7 @@ where mod parse_tests { use super::*; - fn listen>(listener: L) -> io::Result { + fn listen(listener: L) -> io::Result { listener.to_listener() } diff --git a/src/listener/unix_listener.rs b/src/listener/unix_listener.rs index 85c717f7e..6d4a716f7 100644 --- a/src/listener/unix_listener.rs +++ b/src/listener/unix_listener.rs @@ -19,14 +19,14 @@ use kv_log_macro::error; /// /// This is currently crate-visible only, and tide users are expected /// to create these through [ToListener](crate::ToListener) conversions. -pub struct UnixListener { +pub struct UnixListener { path: Option, listener: Option, - server: Option>, + server: Option, info: Option, } -impl UnixListener { +impl UnixListener { pub fn from_path(path: impl Into) -> Self { Self { path: Some(path.into()), @@ -46,7 +46,7 @@ impl UnixListener { } } -fn handle_unix(app: Server, stream: UnixStream) { +fn handle_unix(app: Server, stream: UnixStream) { task::spawn(async move { let local_addr = unix_socket_addr_to_string(stream.local_addr()); let peer_addr = unix_socket_addr_to_string(stream.peer_addr()); @@ -64,11 +64,8 @@ fn handle_unix(app: Server, stream: } #[async_trait::async_trait] -impl Listener for UnixListener -where - State: Clone + Send + Sync + 'static, -{ - async fn bind(&mut self, server: Server) -> io::Result<()> { +impl Listener for UnixListener { + async fn bind(&mut self, server: Server) -> io::Result<()> { assert!(self.server.is_none(), "`bind` should only be called once"); self.server = Some(server); @@ -125,7 +122,7 @@ where } } -impl fmt::Debug for UnixListener { +impl fmt::Debug for UnixListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("UnixListener") .field("listener", &self.listener) @@ -142,7 +139,7 @@ impl fmt::Debug for UnixListener { } } -impl Display for UnixListener { +impl Display for UnixListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match &self.listener { Some(listener) => { diff --git a/src/log/middleware.rs b/src/log/middleware.rs index ff4b37840..323ab6f1e 100644 --- a/src/log/middleware.rs +++ b/src/log/middleware.rs @@ -27,11 +27,7 @@ impl LogMiddleware { } /// Log a request and a response. - async fn log<'a, State: Clone + Send + Sync + 'static>( - &'a self, - mut req: Request, - next: Next<'a, State>, - ) -> crate::Result { + async fn log(&self, mut req: Request, next: Next) -> crate::Result { if req.ext::().is_some() { return Ok(next.run(req).await); } @@ -95,8 +91,8 @@ impl LogMiddleware { } #[async_trait::async_trait] -impl Middleware for LogMiddleware { - async fn handle(&self, req: Request, next: Next<'_, State>) -> crate::Result { +impl Middleware for LogMiddleware { + async fn handle(&self, req: Request, next: Next) -> crate::Result { self.log(req, next).await } } diff --git a/src/middleware.rs b/src/middleware.rs index 7e1ca9a90..f26207448 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -2,17 +2,15 @@ use std::sync::Arc; -use crate::endpoint::DynEndpoint; -use crate::{Request, Response}; +use crate::{Endpoint, Request, Response}; use async_trait::async_trait; use std::future::Future; -use std::pin::Pin; /// Middleware that wraps around the remaining middleware chain. #[async_trait] -pub trait Middleware: Send + Sync + 'static { +pub trait Middleware: Send + Sync + 'static { /// Asynchronously handle the request, and return a response. - async fn handle(&self, request: Request, next: Next<'_, State>) -> crate::Result; + async fn handle(&self, request: Request, next: Next) -> crate::Result; /// Set the middleware's name. By default it uses the type signature. fn name(&self) -> &str { @@ -21,43 +19,87 @@ pub trait Middleware: Send + Sync + 'static { } #[async_trait] -impl Middleware for F +impl Middleware for F where - State: Clone + Send + Sync + 'static, - F: Send - + Sync - + 'static - + for<'a> Fn( - Request, - Next<'a, State>, - ) -> Pin + 'a + Send>>, + F: Send + Sync + 'static + Fn(Request, Next) -> Fut, + Fut: Future> + Send + 'static, + Res: Into + 'static, { - async fn handle(&self, req: Request, next: Next<'_, State>) -> crate::Result { - (self)(req, next).await + async fn handle(&self, req: Request, next: Next) -> crate::Result { + let fut = (self)(req, next); + let res = fut.await?; + Ok(res.into()) } } /// The remainder of a middleware chain, including the endpoint. #[allow(missing_debug_implementations)] -pub struct Next<'a, State> { - pub(crate) endpoint: &'a DynEndpoint, - pub(crate) next_middleware: &'a [Arc>], +pub struct Next { + cursor: usize, + endpoint: Arc, + middleware: Arc>>, } -impl Next<'_, State> { +impl Next { + /// Creates a new Next middleware with an arc to the endpoint and middleware + pub fn new(endpoint: impl Endpoint, middleware: Vec>) -> Self { + Self { + cursor: 0, + endpoint: Arc::new(endpoint), + middleware: Arc::new(middleware), + } + } + + /// Creates a new Next middleware from the given endpoint with empty middleware + pub fn from_endpoint(endpoint: impl Endpoint) -> Self { + Self { + cursor: 0, + endpoint: Arc::new(endpoint), + middleware: Arc::default(), + } + } + + /// Creates a new Next middleware from an existing next middleware (as an endpoint) + pub fn from_next(endpoint: Next, middleware: Arc>>) -> Self { + Self { + cursor: 0, + endpoint: Arc::new(endpoint), + middleware, + } + } + /// Asynchronously execute the remaining middleware chain. - pub async fn run(mut self, req: Request) -> Response { - if let Some((current, next)) = self.next_middleware.split_first() { - self.next_middleware = next; - match current.handle(req, self).await { - Ok(request) => request, + pub async fn run(mut self, req: Request) -> Response { + if let Some(mid) = self.middleware.get(self.cursor) { + self.cursor += 1; + match mid.to_owned().handle(req, self).await { + Ok(response) => response, Err(err) => err.into(), } } else { match self.endpoint.call(req).await { - Ok(request) => request, + Ok(response) => response, Err(err) => err.into(), } } } } + +#[async_trait] +impl Endpoint for Next { + async fn call(&self, req: Request) -> crate::Result { + let next = self.clone(); + let response = next.run(req).await; + Ok(response) + } +} + +impl Clone for Next { + fn clone(&self) -> Self { + Next { + cursor: 0, + endpoint: self.endpoint.clone(), + middleware: self.middleware.clone(), + } + } +} diff --git a/src/redirect.rs b/src/redirect.rs index acb50f92e..ad9113a05 100644 --- a/src/redirect.rs +++ b/src/redirect.rs @@ -28,7 +28,7 @@ use crate::{Endpoint, Request, Response}; /// # use tide::{Response, Redirect, Request, StatusCode}; /// # fn next_product() -> Option { None } /// # #[allow(dead_code)] -/// async fn route_handler(request: Request<()>) -> tide::Result { +/// async fn route_handler(request: Request) -> tide::Result { /// if let Some(product_url) = next_product() { /// Ok(Redirect::new(product_url).into()) /// } else { @@ -86,12 +86,11 @@ impl> Redirect { } #[async_trait::async_trait] -impl Endpoint for Redirect +impl Endpoint for Redirect where - State: Clone + Send + Sync + 'static, T: AsRef + Send + Sync + 'static, { - async fn call(&self, _req: Request) -> crate::Result { + async fn call(&self, _req: Request) -> crate::Result { Ok(self.into()) } } diff --git a/src/request.rs b/src/request.rs index 854154e66..1f8f81231 100644 --- a/src/request.rs +++ b/src/request.rs @@ -23,26 +23,26 @@ pin_project_lite::pin_project! { /// Requests also provide *extensions*, a type map primarily used for low-level /// communication between middleware and endpoints. #[derive(Debug)] - pub struct Request { - pub(crate) state: State, + pub struct Request { #[pin] pub(crate) req: http::Request, pub(crate) route_params: Vec>, } } -impl Request { +impl Request { /// Create a new `Request`. pub(crate) fn new( - state: State, req: http_types::Request, route_params: Vec>, ) -> Self { - Self { - state, - req, - route_params, - } + Self { req, route_params } + } + + /// Returns the current app state + pub fn state(&self) -> &T { + self.ext::() + .expect("request state not set for type, did you call app.with_state?") } /// Access the request's HTTP method. @@ -56,7 +56,7 @@ impl Request { /// use tide::Request; /// /// let mut app = tide::new(); - /// app.at("/").get(|req: Request<()>| async move { + /// app.at("/").get(|req: Request| async move { /// assert_eq!(req.method(), http_types::Method::Get); /// Ok("") /// }); @@ -80,7 +80,7 @@ impl Request { /// use tide::Request; /// /// let mut app = tide::new(); - /// app.at("/").get(|req: Request<()>| async move { + /// app.at("/").get(|req: Request| async move { /// assert_eq!(req.url(), &"/".parse::().unwrap()); /// Ok("") /// }); @@ -104,7 +104,7 @@ impl Request { /// use tide::Request; /// /// let mut app = tide::new(); - /// app.at("/").get(|req: Request<()>| async move { + /// app.at("/").get(|req: Request| async move { /// assert_eq!(req.version(), Some(http_types::Version::Http1_1)); /// Ok("") /// }); @@ -175,7 +175,7 @@ impl Request { /// use tide::Request; /// /// let mut app = tide::new(); - /// app.at("/").get(|req: Request<()>| async move { + /// app.at("/").get(|req: Request| async move { /// assert_eq!(req.header("X-Forwarded-For").unwrap(), "127.0.0.1"); /// Ok("") /// }); @@ -260,12 +260,6 @@ impl Request { self.req.ext_mut().insert(val) } - #[must_use] - /// Access application scoped state. - pub fn state(&self) -> &State { - &self.state - } - /// Extract and parse a route parameter by name. /// /// Returns the parameter as a `&str`, borrowed from this `Request`. @@ -284,7 +278,7 @@ impl Request { /// # /// use tide::{Request, Result}; /// - /// async fn greet(req: Request<()>) -> Result { + /// async fn greet(req: Request) -> Result { /// let name = req.param("name").unwrap_or("world"); /// Ok(format!("Hello, {}!", name)) /// } @@ -316,7 +310,7 @@ impl Request { /// # /// use tide::{Request, Result}; /// - /// async fn greet(req: Request<()>) -> Result { + /// async fn greet(req: Request) -> Result { /// let name = req.wildcard().unwrap_or("world"); /// Ok(format!("Hello, {}!", name)) /// } @@ -352,7 +346,7 @@ impl Request { /// selections: HashMap, /// } /// - /// let req: Request<()> = http::Request::get("https://httpbin.org/get?page=2&selections[width]=narrow&selections[height]=tall").into(); + /// let req: Request = http::Request::get("https://httpbin.org/get?page=2&selections[width]=narrow&selections[height]=tall").into(); /// let Index { page, selections } = req.query().unwrap(); /// assert_eq!(page, 2); /// assert_eq!(selections["width"], "narrow"); @@ -365,7 +359,7 @@ impl Request { /// format: &'q str, /// } /// - /// let req: Request<()> = http::Request::get("https://httpbin.org/get?format=bananna").into(); + /// let req: Request = http::Request::get("https://httpbin.org/get?format=bananna").into(); /// let Query { format } = req.query().unwrap(); /// assert_eq!(format, "bananna"); /// ``` @@ -407,7 +401,7 @@ impl Request { /// use tide::Request; /// /// let mut app = tide::new(); - /// app.at("/").get(|mut req: Request<()>| async move { + /// app.at("/").get(|mut req: Request| async move { /// let _body: Vec = req.body_bytes().await.unwrap(); /// Ok("") /// }); @@ -441,7 +435,7 @@ impl Request { /// use tide::Request; /// /// let mut app = tide::new(); - /// app.at("/").get(|mut req: Request<()>| async move { + /// app.at("/").get(|mut req: Request| async move { /// let _body: String = req.body_string().await.unwrap(); /// Ok("") /// }); @@ -481,7 +475,7 @@ impl Request { /// legs: u8 /// } /// - /// app.at("/").post(|mut req: tide::Request<()>| async move { + /// app.at("/").post(|mut req: tide::Request| async move { /// let animal: Animal = req.body_form().await?; /// Ok(format!( /// "hello, {}! i've put in an order for {} shoes", @@ -556,31 +550,31 @@ impl Request { } } -impl AsRef for Request { +impl AsRef for Request { fn as_ref(&self) -> &http::Request { &self.req } } -impl AsMut for Request { +impl AsMut for Request { fn as_mut(&mut self) -> &mut http::Request { &mut self.req } } -impl AsRef for Request { +impl AsRef for Request { fn as_ref(&self) -> &http::Headers { self.req.as_ref() } } -impl AsMut for Request { +impl AsMut for Request { fn as_mut(&mut self) -> &mut http::Headers { self.req.as_mut() } } -impl Read for Request { +impl Read for Request { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -590,27 +584,27 @@ impl Read for Request { } } -impl From> for http::Request { - fn from(request: Request) -> http::Request { +impl From for http::Request { + fn from(request: Request) -> http::Request { request.req } } -impl From for Request { - fn from(request: http_types::Request) -> Request { - Request::new(State::default(), request, vec![]) +impl From for Request { + fn from(request: http_types::Request) -> Request { + Request::new(request, vec![]) } } -impl From> for Response { - fn from(mut request: Request) -> Response { +impl From for Response { + fn from(mut request: Request) -> Response { let mut res = Response::new(StatusCode::Ok); res.set_body(request.take_body()); res } } -impl IntoIterator for Request { +impl IntoIterator for Request { type Item = (HeaderName, HeaderValues); type IntoIter = http_types::headers::IntoIter; @@ -621,7 +615,7 @@ impl IntoIterator for Request { } } -impl<'a, State> IntoIterator for &'a Request { +impl<'a> IntoIterator for &'a Request { type Item = (&'a HeaderName, &'a HeaderValues); type IntoIter = http_types::headers::Iter<'a>; @@ -631,7 +625,7 @@ impl<'a, State> IntoIterator for &'a Request { } } -impl<'a, State> IntoIterator for &'a mut Request { +impl<'a> IntoIterator for &'a mut Request { type Item = (&'a HeaderName, &'a mut HeaderValues); type IntoIter = http_types::headers::IterMut<'a>; @@ -641,7 +635,7 @@ impl<'a, State> IntoIterator for &'a mut Request { } } -impl Index for Request { +impl Index for Request { type Output = HeaderValues; /// Returns a reference to the value corresponding to the supplied name. @@ -655,7 +649,7 @@ impl Index for Request { } } -impl Index<&str> for Request { +impl Index<&str> for Request { type Output = HeaderValues; /// Returns a reference to the value corresponding to the supplied name. diff --git a/src/response.rs b/src/response.rs index 044b9f222..3ba1466b3 100644 --- a/src/response.rs +++ b/src/response.rs @@ -422,9 +422,21 @@ impl Response { self.res.ext().get() } - /// Set a response scoped extension value. - pub fn insert_ext(&mut self, val: T) { - self.res.ext_mut().insert(val); + /// Get a mutable reference to value stored in response extensions. + #[must_use] + pub fn ext_mut(&mut self) -> Option<&mut T> { + self.res.ext_mut().get_mut() + } + + /// Set a response extension value. + pub fn set_ext(&mut self, val: T) -> Option { + self.res.ext_mut().insert(val) + } + + /// Returns the current app state set from StateMiddleware + pub fn state(&self) -> &T { + self.ext::() + .expect("response state not set for type, did you call app.with_state?") } /// Create a `tide::Response` from a type that can be converted into an diff --git a/src/route.rs b/src/route.rs index e52889fda..d277c1669 100644 --- a/src/route.rs +++ b/src/route.rs @@ -3,9 +3,8 @@ use std::io; use std::path::Path; use std::sync::Arc; -use crate::endpoint::MiddlewareEndpoint; use crate::fs::{ServeDir, ServeFile}; -use crate::{router::Router, Endpoint, Middleware}; +use crate::{router::Router, Endpoint, Middleware, Next}; use kv_log_macro::trace; @@ -18,10 +17,10 @@ use kv_log_macro::trace; /// /// [`Server::at`]: ./struct.Server.html#method.at #[allow(missing_debug_implementations)] -pub struct Route<'a, State> { - router: &'a mut Router, +pub struct Route<'a> { + router: &'a mut Router, path: String, - middleware: Vec>>, + middleware: Vec>, /// Indicates whether the path of current route is treated as a prefix. Set by /// [`strip_prefix`]. /// @@ -29,8 +28,8 @@ pub struct Route<'a, State> { prefix: bool, } -impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { - pub(crate) fn new(router: &'a mut Router, path: String) -> Route<'a, State> { +impl<'a> Route<'a> { + pub(crate) fn new(router: &'a mut Router, path: String) -> Route<'a> { Route { router, path, @@ -40,7 +39,7 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { } /// Extend the route with the given `path`. - pub fn at<'b>(&'b mut self, path: &str) -> Route<'b, State> { + pub fn at<'b>(&'b mut self, path: &str) -> Route<'b> { let mut p = self.path.clone(); if !p.ends_with('/') && !path.starts_with('/') { @@ -78,10 +77,7 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { } /// Apply the given middleware to the current route. - pub fn with(&mut self, middleware: M) -> &mut Self - where - M: Middleware, - { + pub fn with(&mut self, middleware: impl Middleware) -> &mut Self { trace!( "Adding middleware {} to route {:?}", middleware.name(), @@ -113,8 +109,8 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { /// let mut example = tide::with_state("world"); /// example /// .at("/") - /// .get(|req: tide::Request<&'static str>| async move { - /// Ok(format!("Hello {state}!", state = req.state())) + /// .get(|req: tide::Request| async move { + /// Ok(format!("Hello {state}!", state = req.ext::<&str>().unwrap())) /// }); /// example /// }); @@ -125,11 +121,7 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { /// ``` /// /// [`Server`]: struct.Server.html - pub fn nest(&mut self, service: crate::Server) -> &mut Self - where - State: Clone + Send + Sync + 'static, - InnerState: Clone + Send + Sync + 'static, - { + pub fn nest(&mut self, service: crate::Server) -> &mut Self { let prefix = self.prefix; self.prefix = true; @@ -182,21 +174,15 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { } /// Add an endpoint for the given HTTP method - pub fn method(&mut self, method: http_types::Method, ep: impl Endpoint) -> &mut Self { + pub fn method(&mut self, method: http_types::Method, ep: impl Endpoint) -> &mut Self { if self.prefix { let ep = StripPrefixEndpoint::new(ep); let wildcard = self.at("*"); - wildcard.router.add( - &wildcard.path, - method, - MiddlewareEndpoint::wrap_with_middleware(ep, &wildcard.middleware), - ); + let next = Next::new(ep, wildcard.middleware.clone()); + wildcard.router.add(&wildcard.path, method, next); } else { - self.router.add( - &self.path, - method, - MiddlewareEndpoint::wrap_with_middleware(ep, &self.middleware), - ); + let next = Next::new(ep, self.middleware.clone()); + self.router.add(&self.path, method, next); } self } @@ -204,73 +190,69 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { /// Add an endpoint for all HTTP methods, as a fallback. /// /// Routes with specific HTTP methods will be tried first. - pub fn all(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn all(&mut self, ep: impl Endpoint) -> &mut Self { if self.prefix { let ep = StripPrefixEndpoint::new(ep); let wildcard = self.at("*"); - wildcard.router.add_all( - &wildcard.path, - MiddlewareEndpoint::wrap_with_middleware(ep, &wildcard.middleware), - ); + let next = Next::new(ep, wildcard.middleware.clone()); + wildcard.router.add_all(&wildcard.path, next); } else { - self.router.add_all( - &self.path, - MiddlewareEndpoint::wrap_with_middleware(ep, &self.middleware), - ); + let next = Next::new(ep, self.middleware.clone()); + self.router.add_all(&self.path, next); } self } /// Add an endpoint for `GET` requests - pub fn get(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn get(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Get, ep); self } /// Add an endpoint for `HEAD` requests - pub fn head(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn head(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Head, ep); self } /// Add an endpoint for `PUT` requests - pub fn put(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn put(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Put, ep); self } /// Add an endpoint for `POST` requests - pub fn post(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn post(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Post, ep); self } /// Add an endpoint for `DELETE` requests - pub fn delete(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn delete(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Delete, ep); self } /// Add an endpoint for `OPTIONS` requests - pub fn options(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn options(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Options, ep); self } /// Add an endpoint for `CONNECT` requests - pub fn connect(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn connect(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Connect, ep); self } /// Add an endpoint for `PATCH` requests - pub fn patch(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn patch(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Patch, ep); self } /// Add an endpoint for `TRACE` requests - pub fn trace(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn trace(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Trace, ep); self } @@ -292,14 +274,12 @@ impl Clone for StripPrefixEndpoint { } #[async_trait::async_trait] -impl Endpoint for StripPrefixEndpoint +impl Endpoint for StripPrefixEndpoint where - State: Clone + Send + Sync + 'static, - E: Endpoint, + E: Endpoint, { - async fn call(&self, req: crate::Request) -> crate::Result { + async fn call(&self, req: crate::Request) -> crate::Result { let crate::Request { - state, mut req, route_params, } = req; @@ -312,12 +292,6 @@ where req.url_mut().set_path(rest); - self.0 - .call(crate::Request { - state, - req, - route_params, - }) - .await + self.0.call(crate::Request { req, route_params }).await } } diff --git a/src/router.rs b/src/router.rs index 70b1ab350..fe40b061b 100644 --- a/src/router.rs +++ b/src/router.rs @@ -1,20 +1,19 @@ use routefinder::{Captures, Router as MethodRouter}; use std::collections::HashMap; -use crate::endpoint::DynEndpoint; -use crate::{Request, Response, StatusCode}; +use crate::{Next, Request, Response, StatusCode}; /// The routing table used by `Server` /// /// Internally, we have a separate state machine per http method; indexing /// by the method first allows the table itself to be more efficient. #[allow(missing_debug_implementations)] -pub(crate) struct Router { - method_map: HashMap>>>, - all_method_router: MethodRouter>>, +pub(crate) struct Router { + method_map: HashMap>, + all_method_router: MethodRouter, } -impl std::fmt::Debug for Router { +impl std::fmt::Debug for Router { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Router") .field("method_map", &self.method_map) @@ -24,12 +23,12 @@ impl std::fmt::Debug for Router { } /// The result of routing a URL -pub(crate) struct Selection<'a, State> { - pub(crate) endpoint: &'a DynEndpoint, +pub(crate) struct Selection { + pub(crate) next: Next, pub(crate) params: Captures<'static, 'static>, } -impl Router { +impl Router { pub(crate) fn new() -> Self { Router { method_map: HashMap::default(), @@ -37,12 +36,7 @@ impl Router { } } - pub(crate) fn add( - &mut self, - path: &str, - method: http_types::Method, - ep: Box>, - ) { + pub(crate) fn add(&mut self, path: &str, method: http_types::Method, ep: Next) { self.method_map .entry(method) .or_insert_with(MethodRouter::new) @@ -50,23 +44,23 @@ impl Router { .unwrap() } - pub(crate) fn add_all(&mut self, path: &str, ep: Box>) { + pub(crate) fn add_all(&mut self, path: &str, ep: Next) { self.all_method_router.add(path, ep).unwrap() } - pub(crate) fn route(&self, path: &str, method: http_types::Method) -> Selection<'_, State> { + pub(crate) fn route(&self, path: &str, method: http_types::Method) -> Selection { if let Some(m) = self .method_map .get(&method) .and_then(|r| r.best_match(path)) { Selection { - endpoint: m.handler(), + next: m.handler().clone(), params: m.captures().into_owned(), } } else if let Some(m) = self.all_method_router.best_match(path) { Selection { - endpoint: m.handler(), + next: m.handler().clone(), params: m.captures().into_owned(), } } else if method == http_types::Method::Head { @@ -83,26 +77,22 @@ impl Router { // If this `path` can be handled by a callback registered with a different HTTP method // should return 405 Method Not Allowed Selection { - endpoint: &method_not_allowed, + next: Next::from_endpoint(method_not_allowed), params: Captures::default(), } } else { Selection { - endpoint: ¬_found_endpoint, + next: Next::from_endpoint(not_found_endpoint), params: Captures::default(), } } } } -async fn not_found_endpoint( - _req: Request, -) -> crate::Result { +async fn not_found_endpoint(_req: Request) -> crate::Result { Ok(Response::new(StatusCode::NotFound)) } -async fn method_not_allowed( - _req: Request, -) -> crate::Result { +async fn method_not_allowed(_req: Request) -> crate::Result { Ok(Response::new(StatusCode::MethodNotAllowed)) } diff --git a/src/security/cors.rs b/src/security/cors.rs index f85d66529..2bdf2f961 100644 --- a/src/security/cors.rs +++ b/src/security/cors.rs @@ -136,8 +136,8 @@ impl CorsMiddleware { } #[async_trait::async_trait] -impl Middleware for CorsMiddleware { - async fn handle(&self, req: Request, next: Next<'_, State>) -> Result { +impl Middleware for CorsMiddleware { + async fn handle(&self, req: Request, next: Next) -> Result { // TODO: how should multiple origin values be handled? let origins = req.header(&headers::ORIGIN).cloned(); @@ -282,7 +282,7 @@ mod test { .unwrap() } - fn app() -> crate::Server<()> { + fn app() -> crate::Server { let mut app = crate::Server::new(); app.at(ENDPOINT).get(|_| async { Ok("Hello World") }); diff --git a/src/server.rs b/src/server.rs index 3b09fc400..ad196702c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -9,6 +9,7 @@ use crate::cookies; use crate::listener::{Listener, ToListener}; use crate::middleware::{Middleware, Next}; use crate::router::{Router, Selection}; +use crate::state::StateMiddleware; use crate::{Endpoint, Request, Route}; /// An HTTP server. @@ -26,9 +27,8 @@ use crate::{Endpoint, Request, Route}; /// - Middleware extends the base Tide framework with additional request or /// response processing, such as compression, default headers, or logging. To /// add middleware to an app, use the [`Server::with`] method. -pub struct Server { - router: Arc>, - state: State, +pub struct Server { + router: Arc, /// Holds the middleware stack. /// /// Note(Fishrock123): We do actually want this structure. @@ -37,10 +37,16 @@ pub struct Server { /// The inner Arc-s allow MiddlewareEndpoint-s to be cloned internally. /// We don't use a Mutex around the Vec here because adding a middleware during execution should be an error. #[allow(clippy::rc_buffer)] - middleware: Arc>>>, + middleware: Arc>>, } -impl Server<()> { +impl Default for Server { + fn default() -> Self { + Self::new() + } +} + +impl Server { /// Create a new Tide server. /// /// # Examples @@ -57,20 +63,15 @@ impl Server<()> { /// ``` #[must_use] pub fn new() -> Self { - Self::with_state(()) - } -} - -impl Default for Server<()> { - fn default() -> Self { - Self::new() + Self { + router: Arc::new(Router::new()), + middleware: Arc::new(vec![ + #[cfg(feature = "cookies")] + Arc::new(cookies::CookiesMiddleware::new()), + ]), + } } -} -impl Server -where - State: Clone + Send + Sync + 'static, -{ /// Create a new Tide server with shared application scoped state. /// /// Application scoped state is useful for storing items @@ -81,7 +82,7 @@ where /// # use async_std::task::block_on; /// # fn main() -> Result<(), std::io::Error> { block_on(async { /// # - /// use tide::Request; + /// use tide::{Request}; /// /// /// The shared application state. /// #[derive(Clone)] @@ -95,23 +96,17 @@ where /// }; /// /// // Initialize the application with state. - /// let mut app = tide::with_state(state); - /// app.at("/").get(|req: Request| async move { - /// Ok(format!("Hello, {}!", &req.state().name)) + /// let mut app = tide::new(); + /// app.with_state(state); + /// app.at("/").get(|req: Request| async move { + /// Ok(format!("Hello, {}!", &req.state::().name)) /// }); /// app.listen("127.0.0.1:8080").await?; /// # /// # Ok(()) }) } /// ``` - pub fn with_state(state: State) -> Self { - Self { - router: Arc::new(Router::new()), - middleware: Arc::new(vec![ - #[cfg(feature = "cookies")] - Arc::new(cookies::CookiesMiddleware::new()), - ]), - state, - } + pub fn with_state(&mut self, state: S) { + self.with(StateMiddleware::new(state)); } /// Add a new route at the given `path`, relative to root. @@ -160,7 +155,7 @@ where /// There is no fallback route matching, i.e. either a resource is a full /// match or not, which means that the order of adding resources has no /// effect. - pub fn at<'a>(&'a mut self, path: &str) -> Route<'a, State> { + pub fn at<'a>(&'a mut self, path: &str) -> Route<'a> { let router = Arc::get_mut(&mut self.router) .expect("Registering routes is not possible after the Server has started"); Route::new(router, path.to_owned()) @@ -175,10 +170,7 @@ where /// /// Middleware can only be added at the "top level" of an application, and is processed in the /// order in which it is applied. - pub fn with(&mut self, middleware: M) -> &mut Self - where - M: Middleware, - { + pub fn with(&mut self, middleware: impl Middleware) -> &mut Self { trace!("Adding middleware {}", middleware.name()); let m = Arc::get_mut(&mut self.middleware) .expect("Registering middleware is not possible after the Server has started"); @@ -203,7 +195,7 @@ where /// # /// # Ok(()) }) } /// ``` - pub async fn listen>(self, listener: L) -> io::Result<()> { + pub async fn listen(self, listener: L) -> io::Result<()> { let mut listener = listener.to_listener()?; listener.bind(self).await?; for info in listener.info().iter() { @@ -242,10 +234,7 @@ where /// # /// # Ok(()) }) } /// ``` - pub async fn bind>( - self, - listener: L, - ) -> io::Result<>::Listener> { + pub async fn bind(self, listener: L) -> io::Result<::Listener> { let mut listener = listener.to_listener()?; listener.bind(self).await?; Ok(listener) @@ -280,64 +269,36 @@ where Res: From, { let req = req.into(); - let Self { - router, - state, - middleware, - } = self.clone(); let method = req.method().to_owned(); - let Selection { endpoint, params } = router.route(req.url().path(), method); + let Selection { next, params } = self.router.route(req.url().path(), method); let route_params = vec![params]; - let req = Request::new(state, req, route_params); - - let next = Next { - endpoint, - next_middleware: &middleware, - }; - + let req = Request::new(req, route_params); + let next = Next::from_next(next, self.middleware.clone()); let res = next.run(req).await; let res: http_types::Response = res.into(); Ok(res.into()) } - - /// Gets a reference to the server's state. This is useful for testing and nesting: - /// - /// # Example - /// - /// ```rust - /// # #[derive(Clone)] struct SomeAppState; - /// let mut app = tide::with_state(SomeAppState); - /// let mut admin = tide::with_state(app.state().clone()); - /// admin.at("/").get(|_| async { Ok("nested app with cloned state") }); - /// app.at("/").nest(admin); - /// ``` - pub fn state(&self) -> &State { - &self.state - } } -impl std::fmt::Debug for Server { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Server").finish() - } -} - -impl Clone for Server { +impl Clone for Server { fn clone(&self) -> Self { Self { router: self.router.clone(), - state: self.state.clone(), middleware: self.middleware.clone(), } } } +impl std::fmt::Debug for Server { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Server").finish() + } +} + #[async_trait::async_trait] -impl - Endpoint for Server -{ - async fn call(&self, req: Request) -> crate::Result { +impl Endpoint for Server { + async fn call(&self, req: Request) -> crate::Result { let Request { req, mut route_params, @@ -345,25 +306,17 @@ impl http_client::HttpClient for Server { +impl http_client::HttpClient for Server { async fn send(&self, req: crate::http::Request) -> crate::http::Result { self.respond(req).await } diff --git a/src/sessions/middleware.rs b/src/sessions/middleware.rs index e81d4540b..fcd041567 100644 --- a/src/sessions/middleware.rs +++ b/src/sessions/middleware.rs @@ -28,20 +28,20 @@ const BASE64_DIGEST_LEN: usize = 44; /// b"we recommend you use std::env::var(\"TIDE_SECRET\").unwrap().as_bytes() instead of a fixed value" /// )); /// -/// app.with(tide::utils::Before(|mut request: tide::Request<()>| async move { +/// app.with(tide::utils::Before(|mut request: tide::Request| async move { /// let session = request.session_mut(); /// let visits: usize = session.get("visits").unwrap_or_default(); /// session.insert("visits", visits + 1).unwrap(); /// request /// })); /// -/// app.at("/").get(|req: tide::Request<()>| async move { +/// app.at("/").get(|req: tide::Request| async move { /// let visits: usize = req.session().get("visits").unwrap(); /// Ok(format!("you have visited this website {} times", visits)) /// }); /// /// app.at("/reset") -/// .get(|mut req: tide::Request<()>| async move { +/// .get(|mut req: tide::Request| async move { /// req.session_mut().destroy(); /// Ok(tide::Redirect::new("/")) /// }); @@ -77,12 +77,11 @@ impl std::fmt::Debug for SessionMiddleware { } #[async_trait] -impl Middleware for SessionMiddleware +impl Middleware for SessionMiddleware where Store: SessionStore, - State: Clone + Send + Sync + 'static, { - async fn handle(&self, mut request: Request, next: Next<'_, State>) -> crate::Result { + async fn handle(&self, mut request: Request, next: Next) -> crate::Result { let cookie = request.cookie(&self.cookie_name); let cookie_value = cookie .clone() diff --git a/src/sse/endpoint.rs b/src/sse/endpoint.rs index fb81c26e8..c8c3dc822 100644 --- a/src/sse/endpoint.rs +++ b/src/sse/endpoint.rs @@ -11,38 +11,33 @@ use std::marker::PhantomData; use std::sync::Arc; /// Create an endpoint that can handle SSE connections. -pub fn endpoint(handler: F) -> SseEndpoint +pub fn endpoint(handler: F) -> SseEndpoint where - State: Clone + Send + Sync + 'static, - F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, + F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { SseEndpoint { handler: Arc::new(handler), - __state: PhantomData, } } /// An endpoint that can handle SSE connections. #[derive(Debug)] -pub struct SseEndpoint +pub struct SseEndpoint where - State: Clone + Send + Sync + 'static, - F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, + F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { handler: Arc, - __state: PhantomData, } #[async_trait::async_trait] -impl Endpoint for SseEndpoint +impl Endpoint for SseEndpoint where - State: Clone + Send + Sync + 'static, - F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, + F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { - async fn call(&self, req: Request) -> Result { + async fn call(&self, req: Request) -> Result { let handler = self.handler.clone(); let (sender, encoder) = async_sse::encode(); task::spawn(async move { diff --git a/src/sse/upgrade.rs b/src/sse/upgrade.rs index 6cccda96d..5aa1287a9 100644 --- a/src/sse/upgrade.rs +++ b/src/sse/upgrade.rs @@ -9,10 +9,9 @@ use async_std::io::BufReader; use async_std::task; /// Upgrade an existing HTTP connection to an SSE connection. -pub fn upgrade(req: Request, handler: F) -> Response +pub fn upgrade(req: Request, handler: F) -> Response where - State: Clone + Send + Sync + 'static, - F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, + F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { let (sender, encoder) = async_sse::encode(); diff --git a/src/state/middleware.rs b/src/state/middleware.rs new file mode 100644 index 000000000..4cc4f3d08 --- /dev/null +++ b/src/state/middleware.rs @@ -0,0 +1,24 @@ +use crate::{utils::async_trait, Middleware, Next, Request}; + +/// Sets data onto the request extensions +#[derive(Debug)] +pub struct StateMiddleware { + data: T, +} + +impl StateMiddleware { + /// Creates a new state middleware with the provided state + pub fn new(data: T) -> Self { + Self { data } + } +} + +#[async_trait] +impl Middleware for StateMiddleware { + async fn handle(&self, mut request: Request, next: Next) -> crate::Result { + request.set_ext(self.data.clone()); + let mut response = next.run(request).await; + response.set_ext(self.data.clone()); + Ok(response) + } +} diff --git a/src/state/mod.rs b/src/state/mod.rs new file mode 100644 index 000000000..b0f78900b --- /dev/null +++ b/src/state/mod.rs @@ -0,0 +1,3 @@ +mod middleware; + +pub use middleware::StateMiddleware; diff --git a/src/utils.rs b/src/utils.rs index fde11b3ea..0832f3a0e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -16,7 +16,7 @@ use std::future::Future; /// use std::time::Instant; /// /// let mut app = tide::new(); -/// app.with(utils::Before(|mut request: Request<()>| async move { +/// app.with(utils::Before(|mut request: Request| async move { /// request.set_ext(Instant::now()); /// request /// })); @@ -25,13 +25,12 @@ use std::future::Future; pub struct Before(pub F); #[async_trait] -impl Middleware for Before +impl Middleware for Before where - State: Clone + Send + Sync + 'static, - F: Fn(Request) -> Fut + Send + Sync + 'static, - Fut: Future> + Send + Sync + 'static, + F: Fn(Request) -> Fut + Send + Sync + 'static, + Fut: Future + Send + Sync + 'static, { - async fn handle(&self, request: Request, next: Next<'_, State>) -> crate::Result { + async fn handle(&self, request: Request, next: Next) -> crate::Result { let request = (self.0)(request).await; Ok(next.run(request).await) } @@ -59,13 +58,12 @@ where #[derive(Debug)] pub struct After(pub F); #[async_trait] -impl Middleware for After +impl Middleware for After where - State: Clone + Send + Sync + 'static, F: Fn(Response) -> Fut + Send + Sync + 'static, Fut: Future + Send + Sync + 'static, { - async fn handle(&self, request: Request, next: Next<'_, State>) -> crate::Result { + async fn handle(&self, request: Request, next: Next) -> crate::Result { let response = next.run(request).await; (self.0)(response).await } diff --git a/tests/cookies.rs b/tests/cookies.rs index 3eed212e2..2c4bfbba7 100644 --- a/tests/cookies.rs +++ b/tests/cookies.rs @@ -6,7 +6,7 @@ use tide::{Request, Response, Server, StatusCode}; static COOKIE_NAME: &str = "testCookie"; -async fn retrieve_cookie(req: Request<()>) -> tide::Result { +async fn retrieve_cookie(req: Request) -> tide::Result { Ok(format!( "{} and also {}", req.cookie(COOKIE_NAME).unwrap().value(), @@ -14,19 +14,19 @@ async fn retrieve_cookie(req: Request<()>) -> tide::Result { )) } -async fn insert_cookie(_req: Request<()>) -> tide::Result { +async fn insert_cookie(_req: Request) -> tide::Result { let mut res = Response::new(StatusCode::Ok); res.insert_cookie(Cookie::new(COOKIE_NAME, "NewCookieValue")); Ok(res) } -async fn remove_cookie(_req: Request<()>) -> tide::Result { +async fn remove_cookie(_req: Request) -> tide::Result { let mut res = Response::new(StatusCode::Ok); res.remove_cookie(Cookie::named(COOKIE_NAME)); Ok(res) } -async fn set_multiple_cookie(_req: Request<()>) -> tide::Result { +async fn set_multiple_cookie(_req: Request) -> tide::Result { let mut res = Response::new(StatusCode::Ok); res.insert_cookie(Cookie::new("C1", "V1")); res.insert_cookie(Cookie::new("C2", "V2")); diff --git a/tests/endpoint.rs b/tests/endpoint.rs index 7d376997a..ba8dcf162 100644 --- a/tests/endpoint.rs +++ b/tests/endpoint.rs @@ -3,7 +3,7 @@ use tide::Response; #[async_std::test] async fn should_accept_boxed_endpoints() { - fn endpoint() -> Box> { + fn endpoint() -> Box { Box::new(|_| async { Ok("hello world") }) } diff --git a/tests/function_middleware.rs b/tests/function_middleware.rs index fd3680a61..b41d374bd 100644 --- a/tests/function_middleware.rs +++ b/tests/function_middleware.rs @@ -1,28 +1,21 @@ -use std::future::Future; -use std::pin::Pin; use tide::http::{self, url::Url, Method}; mod test_utils; -fn auth_middleware<'a>( - request: tide::Request<()>, - next: tide::Next<'a, ()>, -) -> Pin + 'a + Send>> { +async fn auth_middleware(request: tide::Request, next: tide::Next) -> tide::Result { let authenticated = match request.header("X-Auth") { Some(header) => header == "secret_key", None => false, }; - Box::pin(async move { - if authenticated { - Ok(next.run(request).await) - } else { - Ok(tide::Response::new(tide::StatusCode::Unauthorized)) - } - }) + if authenticated { + Ok(next.run(request).await) + } else { + Ok(tide::Response::new(tide::StatusCode::Unauthorized)) + } } -async fn echo_path(req: tide::Request) -> tide::Result { +async fn echo_path(req: tide::Request) -> tide::Result { Ok(req.url().path().to_string()) } diff --git a/tests/nested.rs b/tests/nested.rs index 7aaeb945e..c5169ad25 100644 --- a/tests/nested.rs +++ b/tests/nested.rs @@ -1,5 +1,6 @@ mod test_utils; use test_utils::ServerTestingExt; +use tide::Request; #[async_std::test] async fn nested() -> tide::Result<()> { @@ -18,7 +19,7 @@ async fn nested() -> tide::Result<()> { #[async_std::test] async fn nested_middleware() -> tide::Result<()> { - let echo_path = |req: tide::Request<()>| async move { Ok(req.url().path().to_string()) }; + let echo_path = |req: Request| async move { Ok(req.url().path().to_string()) }; let mut app = tide::new(); let mut inner_app = tide::new(); inner_app.with(tide::utils::After(|mut res: tide::Response| async move { @@ -51,8 +52,9 @@ async fn nested_middleware() -> tide::Result<()> { async fn nested_with_different_state() -> tide::Result<()> { let mut outer = tide::new(); let mut inner = tide::with_state(42); - inner.at("/").get(|req: tide::Request| async move { - let num = req.state(); + inner.at("/").get(|req: Request| async move { + let num = req.state::(); + println!("{:?}", req.state::()); Ok(format!("the number is {}", num)) }); outer.at("/").get(|_| async { Ok("Hello, world!") }); diff --git a/tests/params.rs b/tests/params.rs index 0d551b515..0da56f0ad 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -3,7 +3,7 @@ use tide::{self, Request, Response, Result}; #[async_std::test] async fn test_missing_param() -> tide::Result<()> { - async fn greet(req: Request<()>) -> Result { + async fn greet(req: Request) -> Result { assert_eq!(req.param("name")?, "Param \"name\" not found"); Ok(Response::new(200)) } @@ -19,7 +19,7 @@ async fn test_missing_param() -> tide::Result<()> { #[async_std::test] async fn hello_world_parametrized() -> Result<()> { - async fn greet(req: tide::Request<()>) -> Result> { + async fn greet(req: tide::Request) -> Result> { let body = format!("{} says hello", req.param("name").unwrap_or("nori")); Ok(Response::builder(200).body(body)) } diff --git a/tests/route_middleware.rs b/tests/route_middleware.rs index ba89a15d8..c815288b3 100644 --- a/tests/route_middleware.rs +++ b/tests/route_middleware.rs @@ -14,19 +14,15 @@ impl TestMiddleware { } #[async_trait::async_trait] -impl Middleware for TestMiddleware { - async fn handle( - &self, - req: tide::Request, - next: tide::Next<'_, State>, - ) -> tide::Result { +impl Middleware for TestMiddleware { + async fn handle(&self, req: tide::Request, next: tide::Next) -> tide::Result { let mut res = next.run(req).await; res.insert_header(self.0.clone(), self.1); Ok(res) } } -async fn echo_path(req: tide::Request) -> tide::Result { +async fn echo_path(req: tide::Request) -> tide::Result { Ok(req.url().path().to_string()) } diff --git a/tests/serve_dir.rs b/tests/serve_dir.rs index 932eb8a60..62f09854d 100644 --- a/tests/serve_dir.rs +++ b/tests/serve_dir.rs @@ -3,11 +3,11 @@ use tide::{http, Result, Server}; use std::fs::{self, File}; use std::io::Write; -fn api() -> Box> { +fn api() -> Box { Box::new(|_| async { Ok("api") }) } -fn app(tempdir: &tempfile::TempDir) -> Result> { +fn app(tempdir: &tempfile::TempDir) -> Result { let static_dir = tempdir.path().join("static"); fs::create_dir(&static_dir)?; diff --git a/tests/server.rs b/tests/server.rs index db9c3c54d..cc22d569c 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -12,7 +12,7 @@ fn hello_world() -> tide::Result<()> { let port = test_utils::find_port().await; let server = task::spawn(async move { let mut app = tide::new(); - app.at("/").get(move |mut req: Request<()>| async move { + app.at("/").get(move |mut req: Request| async move { assert_eq!(req.body_string().await.unwrap(), "nori".to_string()); assert!(req.local_addr().unwrap().contains(&port.to_string())); assert!(req.peer_addr().is_some()); @@ -75,7 +75,7 @@ fn json() -> tide::Result<()> { let port = test_utils::find_port().await; let server = task::spawn(async move { let mut app = tide::new(); - app.at("/").get(|mut req: Request<()>| async move { + app.at("/").get(|mut req: Request| async move { let mut counter: Counter = req.body_json().await.unwrap(); assert_eq!(counter.count, 0); counter.count = 1; diff --git a/tests/sessions.rs b/tests/sessions.rs index 6fe544394..8066f0cc9 100644 --- a/tests/sessions.rs +++ b/tests/sessions.rs @@ -22,13 +22,13 @@ async fn test_basic_sessions() -> tide::Result<()> { b"12345678901234567890123456789012345", )); - app.with(Before(|mut request: tide::Request<()>| async move { + app.with(Before(|mut request: tide::Request| async move { let visits: usize = request.session().get("visits").unwrap_or_default(); request.session_mut().insert("visits", visits + 1).unwrap(); request })); - app.at("/").get(|req: tide::Request<()>| async move { + app.at("/").get(|req: tide::Request| async move { let visits: usize = req.session().get("visits").unwrap(); Ok(format!("you have visited this website {} times", visits)) }); @@ -71,14 +71,14 @@ async fn test_customized_sessions() -> tide::Result<()> { ); app.at("/").get(|_| async { Ok("/") }); - app.at("/nested").get(|req: tide::Request<()>| async move { + app.at("/nested").get(|req: tide::Request| async move { Ok(format!( "/nested {}", req.session().get::("visits").unwrap_or_default() )) }); app.at("/nested/incr") - .get(|mut req: tide::Request<()>| async move { + .get(|mut req: tide::Request| async move { let mut visits: usize = req.session().get("visits").unwrap_or_default(); visits += 1; req.session_mut().insert("visits", visits)?; @@ -137,22 +137,21 @@ async fn test_session_destruction() -> tide::Result<()> { b"12345678901234567890123456789012345", )); - app.with(Before(|mut request: tide::Request<()>| async move { + app.with(Before(|mut request: tide::Request| async move { let visits: usize = request.session().get("visits").unwrap_or_default(); request.session_mut().insert("visits", visits + 1).unwrap(); request })); - app.at("/").get(|req: tide::Request<()>| async move { + app.at("/").get(|req: tide::Request| async move { let visits: usize = req.session().get("visits").unwrap(); Ok(format!("you have visited this website {} times", visits)) }); - app.at("/logout") - .post(|mut req: tide::Request<()>| async move { - req.session_mut().destroy(); - Ok(Response::new(200)) - }); + app.at("/logout").post(|mut req: tide::Request| async move { + req.session_mut().destroy(); + Ok(Response::new(200)) + }); let response = app.get("/").await?; let cookies = Cookies::from_response(&response); diff --git a/tests/test_utils.rs b/tests/test_utils.rs index 9ea19d0f5..0422a287d 100644 --- a/tests/test_utils.rs +++ b/tests/test_utils.rs @@ -60,7 +60,7 @@ pub trait ServerTestingExt { } } -impl ServerTestingExt for tide::Server { +impl ServerTestingExt for tide::Server { fn client(&self) -> Client { let config = Config::new() .set_http_client(self.clone()) diff --git a/tests/unix.rs b/tests/unix.rs index bc28a5ed0..0151ac8dd 100644 --- a/tests/unix.rs +++ b/tests/unix.rs @@ -16,7 +16,7 @@ mod unix_tests { let server = task::spawn(async move { let mut app = tide::new(); - app.at("/").get(|req: tide::Request<()>| async move { + app.at("/").get(|req: tide::Request| async move { Ok(req.local_addr().unwrap().to_string()) }); app.listen(sock_path).await?; diff --git a/tests/wildcard.rs b/tests/wildcard.rs index bb45752c2..7045fdf32 100644 --- a/tests/wildcard.rs +++ b/tests/wildcard.rs @@ -2,7 +2,7 @@ mod test_utils; use test_utils::ServerTestingExt; use tide::{Error, Request, StatusCode}; -async fn add_one(req: Request<()>) -> Result { +async fn add_one(req: Request) -> Result { let num: i64 = req .param("num")? .parse() @@ -10,7 +10,7 @@ async fn add_one(req: Request<()>) -> Result { Ok((num + 1).to_string()) } -async fn add_two(req: Request<()>) -> Result { +async fn add_two(req: Request) -> Result { let one: i64 = req .param("one")? .parse() @@ -22,14 +22,14 @@ async fn add_two(req: Request<()>) -> Result { Ok((one + two).to_string()) } -async fn echo_param(req: Request<()>) -> tide::Result { +async fn echo_param(req: Request) -> tide::Result { match req.param("param") { Ok(path) => Ok(path.into()), Err(_) => Ok(StatusCode::NotFound.into()), } } -async fn echo_wildcard(req: Request<()>) -> tide::Result { +async fn echo_wildcard(req: Request) -> tide::Result { match req.wildcard() { Some(path) => Ok(path.into()), None => Ok(StatusCode::NotFound.into()),