From ac0abd37bd5e3bb6088fadf06336921b6104674d Mon Sep 17 00:00:00 2001 From: Thomas Holloway Date: Mon, 18 Jul 2022 14:51:55 -0500 Subject: [PATCH] Erase type on server state and place in Arc any hashmap --- Cargo.toml | 1 + examples/graphql.rs | 10 +- examples/middleware.rs | 10 +- examples/state.rs | 12 +-- examples/upload.rs | 12 +-- src/endpoint.rs | 2 +- src/lib.rs | 17 ++-- src/listener/concurrent_listener.rs | 21 ++-- src/listener/failover_listener.rs | 24 ++--- src/listener/mod.rs | 15 +-- src/listener/parsed_listener.rs | 17 ++-- src/listener/tcp_listener.rs | 19 ++-- src/listener/to_listener.rs | 4 +- src/listener/to_listener_impls.rs | 142 ++++++++-------------------- src/listener/unix_listener.rs | 19 ++-- src/request.rs | 21 +++- src/route.rs | 14 ++- src/security/cors.rs | 2 +- src/server.rs | 81 ++++++---------- src/sessions/middleware.rs | 3 +- src/sse/endpoint.rs | 11 +-- src/sse/upgrade.rs | 3 +- src/state.rs | 76 +++++++++++++++ tests/nested.rs | 13 +-- tests/serve_dir.rs | 2 +- tests/test_utils.rs | 2 +- 26 files changed, 247 insertions(+), 306 deletions(-) create mode 100644 src/state.rs diff --git a/Cargo.toml b/Cargo.toml index 3a0f7ee91..7da6f38ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ serde = "1.0.117" serde_json = "1.0.59" routefinder = "0.5.0" regex = "1.5.5" +hashbrown = "0.12.3" [dev-dependencies] async-std = { version = "1.6.5", features = ["unstable", "attributes"] } diff --git a/examples/graphql.rs b/examples/graphql.rs index eee046ce5..13db866c4 100644 --- a/examples/graphql.rs +++ b/examples/graphql.rs @@ -2,7 +2,7 @@ use std::sync::{Arc, RwLock}; use juniper::{http::graphiql, http::GraphQLRequest, RootNode}; use lazy_static::lazy_static; -use tide::{http::mime, Body, Redirect, Request, RequestState, Response, Server, StatusCode}; +use tide::{http::mime, Body, Redirect, Request, Response, Server, StatusCode}; #[derive(Clone)] struct User { @@ -76,7 +76,7 @@ lazy_static! { async fn handle_graphql(mut request: Request) -> tide::Result { let query: GraphQLRequest = request.body_json().await?; - let response = query.execute(&SCHEMA, request.state()); + let response = query.execute(&SCHEMA, request.state::()); let status = if response.is_ok() { StatusCode::Ok } else { @@ -105,9 +105,3 @@ async fn main() -> std::io::Result<()> { app.listen("0.0.0.0:8080").await?; Ok(()) } - -impl RequestState for Request { - fn state(&self) -> &State { - self.ext::().unwrap() - } -} diff --git a/examples/middleware.rs b/examples/middleware.rs index e706e2ce9..3e0589b83 100644 --- a/examples/middleware.rs +++ b/examples/middleware.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use tide::http::mime; use tide::utils::{After, Before}; -use tide::{Middleware, Next, Request, RequestState, Response, Result, StatusCode}; +use tide::{Middleware, Next, Request, Response, Result, StatusCode}; #[derive(Debug)] struct User { @@ -24,7 +24,7 @@ impl UserDatabase { // application state. Because it depends on a specific request state, // it would likely be closely tied to a specific application async fn user_loader(mut request: Request, next: Next) -> Result { - if let Some(user) = request.state().find_user().await { + if let Some(user) = request.state::().find_user().await { tide::log::trace!("user loaded", {user: user.name}); request.set_ext(user); Ok(next.run(request).await) @@ -125,9 +125,3 @@ async fn main() -> Result<()> { app.listen("127.0.0.1:8080").await?; Ok(()) } - -impl RequestState for Request { - fn state(&self) -> &UserDatabase { - self.ext::().unwrap() - } -} diff --git a/examples/state.rs b/examples/state.rs index 95fc0b530..c25aac40e 100644 --- a/examples/state.rs +++ b/examples/state.rs @@ -1,8 +1,6 @@ use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; -use tide::RequestState; - #[derive(Clone)] struct State { value: Arc, @@ -22,21 +20,15 @@ async fn main() -> tide::Result<()> { let mut app = tide::with_state(State::new()); app.with(tide::log::LogMiddleware::new()); app.at("/").get(|req: tide::Request| async move { - let state = req.state(); + let state = req.state::(); let value = state.value.load(Ordering::Relaxed); Ok(format!("{}\n", value)) }); app.at("/inc").get(|req: tide::Request| async move { - let state = req.state(); + let state = req.state::(); let value = state.value.fetch_add(1, Ordering::Relaxed) + 1; Ok(format!("{}\n", value)) }); app.listen("127.0.0.1:8080").await?; Ok(()) } - -impl RequestState for tide::Request { - fn state(&self) -> &State { - self.ext::().unwrap() - } -} diff --git a/examples/upload.rs b/examples/upload.rs index 84344ef7d..42d44c88e 100644 --- a/examples/upload.rs +++ b/examples/upload.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use async_std::{fs::OpenOptions, io}; use tempfile::TempDir; use tide::prelude::*; -use tide::{Body, Request, RequestState, Response, StatusCode}; +use tide::{Body, Request, Response, StatusCode}; #[derive(Clone)] struct TempDirState { @@ -24,12 +24,6 @@ impl TempDirState { } } -impl RequestState for Request { - fn state(&self) -> &TempDirState { - self.ext::().unwrap() - } -} - #[async_std::main] async fn main() -> Result<(), IoError> { // tide::log::start(); @@ -44,7 +38,7 @@ async fn main() -> Result<(), IoError> { app.at(":file") .put(|req: Request| async move { let path = req.param("file")?; - let state = req.state(); + let state = req.state::(); let fs_path = state.path().join(path); let file = OpenOptions::new() @@ -64,7 +58,7 @@ async fn main() -> Result<(), IoError> { }) .get(|req: Request| async move { let path = req.param("file")?; - let fs_path = req.state().path().join(path); + let fs_path = req.state::().path().join(path); if let Ok(body) = Body::from_file(fs_path).await { Ok(body.into()) diff --git a/src/endpoint.rs b/src/endpoint.rs index 427430e08..6b3b2ca24 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -11,7 +11,7 @@ use crate::{Middleware, Request, Response}; /// This trait is automatically implemented for `Fn` types, and so is rarely implemented /// directly by Tide users. /// -/// In practice, endpoints are functions that take a `Request` as an argument and +/// In practice, endpoints are functions that take a `Request` as an argument and /// return a type `T` that implements `Into`. /// /// # Examples diff --git a/src/lib.rs b/src/lib.rs index 26771a43d..c651bd372 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,6 +77,7 @@ mod response_builder; mod route; mod router; mod server; +mod state; pub mod convert; pub mod listener; @@ -97,8 +98,8 @@ pub use request::Request; pub use response::Response; pub use response_builder::ResponseBuilder; pub use route::Route; -pub use server::RequestState; pub use server::Server; +pub use state::State; pub use http_types::{self as http, Body, Error, Status, StatusCode}; @@ -117,7 +118,7 @@ pub use http_types::{self as http, Body, Error, Status, StatusCode}; /// # Ok(()) }) } /// ``` #[must_use] -pub fn new() -> server::Server<()> { +pub fn new() -> server::Server { Server::new() } @@ -131,7 +132,7 @@ pub fn new() -> server::Server<()> { /// # use async_std::task::block_on; /// # fn main() -> Result<(), std::io::Error> { block_on(async { /// # -/// use tide::{Request, RequestState}; +/// use tide::{Request}; /// /// /// The shared application state. /// #[derive(Clone)] @@ -144,22 +145,16 @@ pub fn new() -> server::Server<()> { /// name: "Nori".to_string() /// }; /// -/// impl RequestState for Request { -/// fn state(&self) -> &State { -/// self.ext::().unwrap() -/// } -/// } -/// /// // Initialize the application with state. /// let mut app = tide::with_state(state); /// app.at("/").get(|req: Request| async move { -/// Ok(format!("Hello, {}!", &req.state().name)) +/// Ok(format!("Hello, {}!", &req.state::().name)) /// }); /// app.listen("127.0.0.1:8080").await?; /// # /// # Ok(()) }) } /// ``` -pub fn with_state(state: State) -> server::Server +pub fn with_state(state: State) -> server::Server where State: Clone + Send + Sync + 'static, { diff --git a/src/listener/concurrent_listener.rs b/src/listener/concurrent_listener.rs index 0cc52bd8e..5d59cfb2a 100644 --- a/src/listener/concurrent_listener.rs +++ b/src/listener/concurrent_listener.rs @@ -33,11 +33,11 @@ use futures_util::stream::{futures_unordered::FuturesUnordered, StreamExt}; ///``` #[derive(Default)] -pub struct ConcurrentListener { - listeners: Vec>>, +pub struct ConcurrentListener { + listeners: Vec>, } -impl ConcurrentListener { +impl ConcurrentListener { /// creates a new ConcurrentListener pub fn new() -> Self { Self { listeners: vec![] } @@ -59,7 +59,7 @@ impl ConcurrentListener { /// ``` pub fn add(&mut self, listener: L) -> io::Result<()> where - L: ToListener, + L: ToListener, { self.listeners.push(Box::new(listener.to_listener()?)); Ok(()) @@ -78,7 +78,7 @@ impl ConcurrentListener { /// # Ok(()) }) } pub fn with_listener(mut self, listener: L) -> Self where - L: ToListener, + L: ToListener, { self.add(listener).expect("Unable to add listener"); self @@ -86,11 +86,8 @@ impl ConcurrentListener { } #[async_trait::async_trait] -impl Listener for ConcurrentListener -where - State: Clone + Send + Sync + 'static, -{ - async fn bind(&mut self, app: Server) -> io::Result<()> { +impl Listener for ConcurrentListener { + async fn bind(&mut self, app: Server) -> io::Result<()> { for listener in self.listeners.iter_mut() { listener.bind(app.clone()).await?; } @@ -118,13 +115,13 @@ where } } -impl Debug for ConcurrentListener { +impl Debug for ConcurrentListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{:?}", self.listeners) } } -impl Display for ConcurrentListener { +impl Display for ConcurrentListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let string = self .listeners diff --git a/src/listener/failover_listener.rs b/src/listener/failover_listener.rs index d4eea4d33..4fe37c2a2 100644 --- a/src/listener/failover_listener.rs +++ b/src/listener/failover_listener.rs @@ -34,15 +34,12 @@ use crate::listener::ListenInfo; ///} ///``` #[derive(Default)] -pub struct FailoverListener { - listeners: Vec>>>, +pub struct FailoverListener { + listeners: Vec>>, index: Option, } -impl FailoverListener -where - State: Clone + Send + Sync + 'static, -{ +impl FailoverListener { /// creates a new FailoverListener pub fn new() -> Self { Self { @@ -69,7 +66,7 @@ where /// ``` pub fn add(&mut self, listener: L) -> io::Result<()> where - L: ToListener, + L: ToListener, { self.listeners.push(Some(Box::new(listener.to_listener()?))); Ok(()) @@ -88,7 +85,7 @@ where /// # Ok(()) }) } pub fn with_listener(mut self, listener: L) -> Self where - L: ToListener, + L: ToListener, { self.add(listener).expect("Unable to add listener"); self @@ -96,11 +93,8 @@ where } #[async_trait::async_trait] -impl Listener for FailoverListener -where - State: Clone + Send + Sync + 'static, -{ - async fn bind(&mut self, app: Server) -> io::Result<()> { +impl Listener for FailoverListener { + async fn bind(&mut self, app: Server) -> io::Result<()> { for (index, listener) in self.listeners.iter_mut().enumerate() { let listener = listener.as_deref_mut().expect("bind called twice"); match listener.bind(app.clone()).await { @@ -148,13 +142,13 @@ where } } -impl Debug for FailoverListener { +impl Debug for FailoverListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{:?}", self.listeners) } } -impl Display for FailoverListener { +impl Display for FailoverListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let string = self .listeners diff --git a/src/listener/mod.rs b/src/listener/mod.rs index 2d469872d..2adb965f2 100644 --- a/src/listener/mod.rs +++ b/src/listener/mod.rs @@ -35,14 +35,11 @@ pub(crate) use unix_listener::UnixListener; /// implement at least one [`ToListener`](crate::listener::ToListener) that /// outputs your Listener type. #[async_trait] -pub trait Listener: Debug + Display + Send + Sync + 'static -where - State: Send + Sync + 'static, -{ +pub trait Listener: Debug + Display + Send + Sync + 'static { /// Bind the listener. This starts the listening process by opening the /// necessary network ports, but not yet accepting incoming connections. This /// method must be called before `accept`. - async fn bind(&mut self, app: Server) -> io::Result<()>; + async fn bind(&mut self, app: Server) -> io::Result<()>; /// Start accepting incoming connections. This method must be called only /// after `bind` has succeeded. @@ -54,12 +51,8 @@ where } #[async_trait] -impl Listener for Box -where - L: Listener, - State: Send + Sync + 'static, -{ - async fn bind(&mut self, app: Server) -> io::Result<()> { +impl Listener for Box { + async fn bind(&mut self, app: Server) -> io::Result<()> { self.as_mut().bind(app).await } diff --git a/src/listener/parsed_listener.rs b/src/listener/parsed_listener.rs index ad2926a10..82143eaca 100644 --- a/src/listener/parsed_listener.rs +++ b/src/listener/parsed_listener.rs @@ -13,13 +13,13 @@ use std::fmt::{self, Debug, Display, Formatter}; /// /// This is currently crate-visible only, and tide users are expected /// to create these through [ToListener](crate::ToListener) conversions. -pub enum ParsedListener { +pub enum ParsedListener { #[cfg(unix)] - Unix(UnixListener), - Tcp(TcpListener), + Unix(UnixListener), + Tcp(TcpListener), } -impl Debug for ParsedListener { +impl Debug for ParsedListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { #[cfg(unix)] @@ -29,7 +29,7 @@ impl Debug for ParsedListener { } } -impl Display for ParsedListener { +impl Display for ParsedListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { #[cfg(unix)] @@ -40,11 +40,8 @@ impl Display for ParsedListener { } #[async_trait::async_trait] -impl Listener for ParsedListener -where - State: Clone + Send + Sync + 'static, -{ - async fn bind(&mut self, server: Server) -> io::Result<()> { +impl Listener for ParsedListener { + async fn bind(&mut self, server: Server) -> io::Result<()> { match self { #[cfg(unix)] Self::Unix(u) => u.bind(server).await, diff --git a/src/listener/tcp_listener.rs b/src/listener/tcp_listener.rs index 7b86a013a..90f2dde45 100644 --- a/src/listener/tcp_listener.rs +++ b/src/listener/tcp_listener.rs @@ -17,14 +17,14 @@ use async_std::{io, task}; /// /// This is currently crate-visible only, and tide users are expected /// to create these through [ToListener](crate::ToListener) conversions. -pub struct TcpListener { +pub struct TcpListener { addrs: Option>, listener: Option, - server: Option>, + server: Option, info: Option, } -impl TcpListener { +impl TcpListener { pub fn from_addrs(addrs: Vec) -> Self { Self { addrs: Some(addrs), @@ -44,7 +44,7 @@ impl TcpListener { } } -fn handle_tcp(app: Server, stream: TcpStream) { +fn handle_tcp(app: Server, stream: TcpStream) { task::spawn(async move { let local_addr = stream.local_addr().ok(); let peer_addr = stream.peer_addr().ok(); @@ -62,11 +62,8 @@ fn handle_tcp(app: Server, stream: } #[async_trait::async_trait] -impl Listener for TcpListener -where - State: Clone + Send + Sync + 'static, -{ - async fn bind(&mut self, server: Server) -> io::Result<()> { +impl Listener for TcpListener { + async fn bind(&mut self, server: Server) -> io::Result<()> { assert!(self.server.is_none(), "`bind` should only be called once"); self.server = Some(server); @@ -126,7 +123,7 @@ where } } -impl fmt::Debug for TcpListener { +impl fmt::Debug for TcpListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("TcpListener") .field("listener", &self.listener) @@ -143,7 +140,7 @@ impl fmt::Debug for TcpListener { } } -impl Display for TcpListener { +impl Display for TcpListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let http_fmt = |a| format!("http://{}", a); match &self.listener { diff --git a/src/listener/to_listener.rs b/src/listener/to_listener.rs index b25b8f36b..79b125c38 100644 --- a/src/listener/to_listener.rs +++ b/src/listener/to_listener.rs @@ -47,9 +47,9 @@ use async_std::io; /// ``` /// # Other implementations /// See below for additional provided implementations of ToListener. -pub trait ToListener { +pub trait ToListener { /// What listener are we converting into? - type Listener: Listener; + type Listener: Listener; /// Transform self into a /// [`Listener`](crate::listener::Listener). Unless self is diff --git a/src/listener/to_listener_impls.rs b/src/listener/to_listener_impls.rs index 1e92fef0d..bcf80a838 100644 --- a/src/listener/to_listener_impls.rs +++ b/src/listener/to_listener_impls.rs @@ -5,11 +5,8 @@ use crate::http::url::Url; use async_std::io; use std::net::ToSocketAddrs; -impl ToListener for Url -where - State: Clone + Send + Sync + 'static, -{ - type Listener = ParsedListener; +impl ToListener for Url { + type Listener = ParsedListener; fn to_listener(self) -> io::Result { match self.scheme() { @@ -51,31 +48,22 @@ where } } -impl ToListener for String -where - State: Clone + Send + Sync + 'static, -{ - type Listener = ParsedListener; +impl ToListener for String { + type Listener = ParsedListener; fn to_listener(self) -> io::Result { - ToListener::::to_listener(self.as_str()) + ToListener::to_listener(self.as_str()) } } -impl ToListener for &String -where - State: Clone + Send + Sync + 'static, -{ - type Listener = ParsedListener; +impl ToListener for &String { + type Listener = ParsedListener; fn to_listener(self) -> io::Result { - ToListener::::to_listener(self.as_str()) + ToListener::to_listener(self.as_str()) } } -impl ToListener for &str -where - State: Clone + Send + Sync + 'static, -{ - type Listener = ParsedListener; +impl ToListener for &str { + type Listener = ParsedListener; fn to_listener(self) -> io::Result { if let Ok(socket_addrs) = self.to_socket_addrs() { @@ -83,7 +71,7 @@ where socket_addrs.collect(), ))) } else if let Ok(url) = Url::parse(self) { - ToListener::::to_listener(url) + ToListener::to_listener(url) } else { Err(io::Error::new( io::ErrorKind::InvalidInput, @@ -94,72 +82,51 @@ where } #[cfg(unix)] -impl ToListener for async_std::path::PathBuf -where - State: Clone + Send + Sync + 'static, -{ - type Listener = UnixListener; +impl ToListener for async_std::path::PathBuf { + type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_path(self)) } } #[cfg(unix)] -impl ToListener for std::path::PathBuf -where - State: Clone + Send + Sync + 'static, -{ - type Listener = UnixListener; +impl ToListener for std::path::PathBuf { + type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_path(self)) } } -impl ToListener for async_std::net::TcpListener -where - State: Clone + Send + Sync + 'static, -{ - type Listener = TcpListener; +impl ToListener for async_std::net::TcpListener { + type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_listener(self)) } } -impl ToListener for std::net::TcpListener -where - State: Clone + Send + Sync + 'static, -{ - type Listener = TcpListener; +impl ToListener for std::net::TcpListener { + type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_listener(self)) } } -impl ToListener for (String, u16) -where - State: Clone + Send + Sync + 'static, -{ - type Listener = TcpListener; +impl ToListener for (String, u16) { + type Listener = TcpListener; fn to_listener(self) -> io::Result { - ToListener::::to_listener((self.0.as_str(), self.1)) + ToListener::to_listener((self.0.as_str(), self.1)) } } -impl ToListener for (&String, u16) -where - State: Clone + Send + Sync + 'static, -{ - type Listener = TcpListener; +impl ToListener for (&String, u16) { + type Listener = TcpListener; fn to_listener(self) -> io::Result { - ToListener::::to_listener((self.0.as_str(), self.1)) + ToListener::to_listener((self.0.as_str(), self.1)) } } -impl ToListener for (&str, u16) -where - State: Clone + Send + Sync + 'static, -{ - type Listener = TcpListener; +impl ToListener for (&str, u16) { + type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_addrs(self.to_socket_addrs()?.collect())) @@ -167,31 +134,22 @@ where } #[cfg(unix)] -impl ToListener for async_std::os::unix::net::UnixListener -where - State: Clone + Send + Sync + 'static, -{ - type Listener = UnixListener; +impl ToListener for async_std::os::unix::net::UnixListener { + type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_listener(self)) } } #[cfg(unix)] -impl ToListener for std::os::unix::net::UnixListener -where - State: Clone + Send + Sync + 'static, -{ - type Listener = UnixListener; +impl ToListener for std::os::unix::net::UnixListener { + type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_listener(self)) } } -impl ToListener for TcpListener -where - State: Clone + Send + Sync + 'static, -{ +impl ToListener for TcpListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) @@ -199,62 +157,46 @@ where } #[cfg(unix)] -impl ToListener for UnixListener -where - State: Clone + Send + Sync + 'static, -{ +impl ToListener for UnixListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for ConcurrentListener -where - State: Clone + Send + Sync + 'static, -{ +impl ToListener for ConcurrentListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for ParsedListener -where - State: Clone + Send + Sync + 'static, -{ +impl ToListener for ParsedListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for FailoverListener -where - State: Clone + Send + Sync + 'static, -{ +impl ToListener for FailoverListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for std::net::SocketAddr -where - State: Clone + Send + Sync + 'static, -{ - type Listener = TcpListener; +impl ToListener for std::net::SocketAddr { + type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_addrs(vec![self])) } } -impl ToListener for Vec +impl ToListener for Vec where - L: ToListener, - State: Clone + Send + Sync + 'static, + L: ToListener, { - type Listener = ConcurrentListener; + type Listener = ConcurrentListener; fn to_listener(self) -> io::Result { let mut concurrent_listener = ConcurrentListener::new(); for listener in self { @@ -268,7 +210,7 @@ where mod parse_tests { use super::*; - fn listen>(listener: L) -> io::Result { + fn listen(listener: L) -> io::Result { listener.to_listener() } diff --git a/src/listener/unix_listener.rs b/src/listener/unix_listener.rs index d99a21d30..ddab3a025 100644 --- a/src/listener/unix_listener.rs +++ b/src/listener/unix_listener.rs @@ -18,14 +18,14 @@ use async_std::{io, task}; /// /// This is currently crate-visible only, and tide users are expected /// to create these through [ToListener](crate::ToListener) conversions. -pub struct UnixListener { +pub struct UnixListener { path: Option, listener: Option, - server: Option>, + server: Option, info: Option, } -impl UnixListener { +impl<'server> UnixListener { pub fn from_path(path: impl Into) -> Self { Self { path: Some(path.into()), @@ -45,7 +45,7 @@ impl UnixListener { } } -fn handle_unix(app: Server, stream: UnixStream) { +fn handle_unix<'listener>(app: Server, stream: UnixStream) { task::spawn(async move { let local_addr = unix_socket_addr_to_string(stream.local_addr()); let peer_addr = unix_socket_addr_to_string(stream.peer_addr()); @@ -63,11 +63,8 @@ fn handle_unix(app: Server, stream: } #[async_trait::async_trait] -impl Listener for UnixListener -where - State: Clone + Send + Sync + 'static, -{ - async fn bind(&mut self, server: Server) -> io::Result<()> { +impl Listener for UnixListener { + async fn bind(&mut self, server: Server) -> io::Result<()> { assert!(self.server.is_none(), "`bind` should only be called once"); self.server = Some(server); @@ -124,7 +121,7 @@ where } } -impl fmt::Debug for UnixListener { +impl fmt::Debug for UnixListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("UnixListener") .field("listener", &self.listener) @@ -141,7 +138,7 @@ impl fmt::Debug for UnixListener { } } -impl Display for UnixListener { +impl Display for UnixListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match &self.listener { Some(listener) => { diff --git a/src/request.rs b/src/request.rs index d69c777d4..8a4fa95be 100644 --- a/src/request.rs +++ b/src/request.rs @@ -4,6 +4,7 @@ use routefinder::Captures; use std::ops::Index; use std::pin::Pin; +use std::sync::Arc; #[cfg(feature = "cookies")] use crate::cookies::CookieData; @@ -12,7 +13,7 @@ use crate::http::cookies::Cookie; use crate::http::format_err; use crate::http::headers::{self, HeaderName, HeaderValues, ToHeaderValues}; use crate::http::{self, Body, Method, Mime, StatusCode, Url, Version}; -use crate::Response; +use crate::{Response, State}; pin_project_lite::pin_project! { /// An HTTP request. @@ -24,6 +25,7 @@ pin_project_lite::pin_project! { /// communication between middleware and endpoints. #[derive(Debug)] pub struct Request { + pub(crate) app_state: Arc, #[pin] pub(crate) req: http::Request, pub(crate) route_params: Vec>, @@ -36,20 +38,29 @@ impl Request { req: http_types::Request, route_params: Vec>, ) -> Self { - Self { req, route_params } + Self { + app_state: Arc::new(State::default()), + req, + route_params, + } } /// Create a new `Request`. - pub(crate) fn with_state( - state: S, + pub(crate) fn with_state( + state: Arc, req: http_types::Request, route_params: Vec>, ) -> Self { let mut req = Request::new(req, route_params); - req.set_ext::(state); + req.app_state = state; req } + /// Returns the current app state + pub fn state(&self) -> &T { + &self.app_state.get::().unwrap() + } + /// Access the request's HTTP method. /// /// # Examples diff --git a/src/route.rs b/src/route.rs index 1374fb813..ff08bb960 100644 --- a/src/route.rs +++ b/src/route.rs @@ -121,10 +121,7 @@ impl<'a> Route<'a> { /// ``` /// /// [`Server`]: struct.Server.html - pub fn nest(&mut self, service: crate::Server) -> &mut Self - where - InnerState: Clone + Send + Sync + 'static, - { + pub fn nest(&mut self, service: crate::Server) -> &mut Self { let prefix = self.prefix; self.prefix = true; @@ -295,6 +292,7 @@ where let crate::Request { mut req, route_params, + app_state, } = req; let rest = route_params @@ -305,6 +303,12 @@ where req.url_mut().set_path(rest); - self.0.call(crate::Request { req, route_params }).await + self.0 + .call(crate::Request { + req, + route_params, + app_state, + }) + .await } } diff --git a/src/security/cors.rs b/src/security/cors.rs index 78e65b1d5..2bdf2f961 100644 --- a/src/security/cors.rs +++ b/src/security/cors.rs @@ -282,7 +282,7 @@ mod test { .unwrap() } - fn app() -> crate::Server<()> { + fn app() -> crate::Server { let mut app = crate::Server::new(); app.at(ENDPOINT).get(|_| async { Ok("Hello World") }); diff --git a/src/server.rs b/src/server.rs index e62e658f8..55b5c5b0c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -9,6 +9,7 @@ use crate::listener::{Listener, ToListener}; use crate::log; use crate::middleware::{Middleware, Next}; use crate::router::{Router, Selection}; +use crate::state::State; use crate::{Endpoint, Request, Route}; /// An HTTP server. @@ -26,9 +27,9 @@ use crate::{Endpoint, Request, Route}; /// - Middleware extends the base Tide framework with additional request or /// response processing, such as compression, default headers, or logging. To /// add middleware to an app, use the [`Server::with`] method. -pub struct Server { +pub struct Server { router: Arc, - state: State, + state: Arc, /// Holds the middleware stack. /// /// Note(Fishrock123): We do actually want this structure. @@ -40,7 +41,7 @@ pub struct Server { middleware: Arc>>, } -impl Server<()> { +impl Server { /// Create a new Tide server. /// /// # Examples @@ -61,16 +62,13 @@ impl Server<()> { } } -impl Default for Server<()> { +impl Default for Server { fn default() -> Self { Self::new() } } -impl Server -where - State: Clone + Send + Sync + 'static, -{ +impl Server { /// Create a new Tide server with shared application scoped state. /// /// Application scoped state is useful for storing items @@ -81,7 +79,7 @@ where /// # use async_std::task::block_on; /// # fn main() -> Result<(), std::io::Error> { block_on(async { /// # - /// use tide::{Request, RequestState}; + /// use tide::{Request}; /// /// /// The shared application state. /// #[derive(Clone)] @@ -94,29 +92,23 @@ where /// name: "Nori".to_string() /// }; /// - /// impl RequestState for Request { - /// fn state(&self) -> &State { - /// self.ext::().unwrap() - /// } - /// } - /// /// // Initialize the application with state. /// let mut app = tide::with_state(state); /// app.at("/").get(|req: Request| async move { - /// Ok(format!("Hello, {}!", &req.state().name)) + /// Ok(format!("Hello, {}!", &req.state::().name)) /// }); /// app.listen("127.0.0.1:8080").await?; /// # /// # Ok(()) }) } /// ``` - pub fn with_state(state: State) -> Self { + pub fn with_state(state: S) -> Self { Self { router: Arc::new(Router::new()), middleware: Arc::new(vec![ #[cfg(feature = "cookies")] Arc::new(cookies::CookiesMiddleware::new()), ]), - state, + state: Arc::new(State::with(state)), } } @@ -206,7 +198,7 @@ where /// # /// # Ok(()) }) } /// ``` - pub async fn listen>(self, listener: L) -> io::Result<()> { + pub async fn listen(self, listener: L) -> io::Result<()> { let mut listener = listener.to_listener()?; listener.bind(self).await?; for info in listener.info().iter() { @@ -245,10 +237,7 @@ where /// # /// # Ok(()) }) } /// ``` - pub async fn bind>( - self, - listener: L, - ) -> io::Result<>::Listener> { + pub async fn bind(self, listener: L) -> io::Result<::Listener> { let mut listener = listener.to_listener()?; listener.bind(self).await?; Ok(listener) @@ -283,18 +272,13 @@ where Res: From, { let req = req.into(); - let Self { - router, - state, - middleware, - } = self.clone(); let method = req.method().to_owned(); - let Selection { endpoint, params } = router.route(req.url().path(), method); + let Selection { endpoint, params } = self.router.route(req.url().path(), method); let route_params = vec![params]; - let req = Request::with_state(state, req, route_params); + let req = Request::with_state(self.state.clone(), req, route_params); - let next = Next::new(endpoint, middleware); + let next = Next::new(endpoint, self.middleware.clone()); let res = next.run(req).await; let res: http_types::Response = res.into(); Ok(res.into()) @@ -311,18 +295,12 @@ where /// admin.at("/").get(|_| async { Ok("nested app with cloned state") }); /// app.at("/").nest(admin); /// ``` - pub fn state(&self) -> &State { + pub fn state(&self) -> &Arc { &self.state } } -impl std::fmt::Debug for Server { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Server").finish() - } -} - -impl Clone for Server { +impl Clone for Server { fn clone(&self) -> Self { Self { router: self.router.clone(), @@ -332,8 +310,14 @@ impl Clone for Server { } } +impl std::fmt::Debug for Server { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Server").finish() + } +} + #[async_trait::async_trait] -impl Endpoint for Server { +impl Endpoint for Server { async fn call(&self, req: Request) -> crate::Result { let Request { req, @@ -342,27 +326,18 @@ impl Endpoint for Server } = req; let path = req.url().path().to_owned(); let method = req.method().to_owned(); - let router = self.router.clone(); - let middleware = self.middleware.clone(); - let state = self.state.clone(); - let Selection { endpoint, params } = router.route(&path, method); + let Selection { endpoint, params } = self.router.route(&path, method); route_params.push(params); - let req = Request::with_state(state, req, route_params); + let req = Request::with_state(self.state.clone(), req, route_params); - let next = Next::new(endpoint, middleware); + let next = Next::new(endpoint, self.middleware.clone()); Ok(next.run(req).await) } } -/// Request extension trait that returns a reference to the State -pub trait RequestState { - /// Extends the Request to be able to return a - fn state(&self) -> &State; -} - #[crate::utils::async_trait] -impl http_client::HttpClient for Server { +impl http_client::HttpClient for Server { async fn send(&self, req: crate::http::Request) -> crate::http::Result { self.respond(req).await } diff --git a/src/sessions/middleware.rs b/src/sessions/middleware.rs index 90de8f28a..2fd01e3c2 100644 --- a/src/sessions/middleware.rs +++ b/src/sessions/middleware.rs @@ -74,10 +74,9 @@ impl std::fmt::Debug for SessionMiddleware { } #[async_trait] -impl Middleware for SessionMiddleware +impl Middleware for SessionMiddleware where Store: SessionStore, - State: Clone + Send + Sync + 'static, { async fn handle(&self, mut request: Request, next: Next) -> crate::Result { let cookie = request.cookie(&self.cookie_name); diff --git a/src/sse/endpoint.rs b/src/sse/endpoint.rs index 929e7859a..c8c3dc822 100644 --- a/src/sse/endpoint.rs +++ b/src/sse/endpoint.rs @@ -11,34 +11,29 @@ use std::marker::PhantomData; use std::sync::Arc; /// Create an endpoint that can handle SSE connections. -pub fn endpoint(handler: F) -> SseEndpoint +pub fn endpoint(handler: F) -> SseEndpoint where - State: Clone + Send + Sync + 'static, F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { SseEndpoint { handler: Arc::new(handler), - __state: PhantomData, } } /// An endpoint that can handle SSE connections. #[derive(Debug)] -pub struct SseEndpoint +pub struct SseEndpoint where - State: Clone + Send + Sync + 'static, F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { handler: Arc, - __state: PhantomData, } #[async_trait::async_trait] -impl Endpoint for SseEndpoint +impl Endpoint for SseEndpoint where - State: Clone + Send + Sync + 'static, F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { diff --git a/src/sse/upgrade.rs b/src/sse/upgrade.rs index 4171d96b7..5aa1287a9 100644 --- a/src/sse/upgrade.rs +++ b/src/sse/upgrade.rs @@ -9,9 +9,8 @@ use async_std::io::BufReader; use async_std::task; /// Upgrade an existing HTTP connection to an SSE connection. -pub fn upgrade(req: Request, handler: F) -> Response +pub fn upgrade(req: Request, handler: F) -> Response where - State: Clone + Send + Sync + 'static, F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 000000000..00022df25 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,76 @@ +// Originally from https://github.com/http-rs/http-types/blob/main/src/extensions.rs +// +// Implementation is based on +// - https://github.com/trillium-rs/trillium/blob/main/http/src/state_set.rs +// - https://github.com/hyperium/http/blob/master/src/extensions.rs +// - https://github.com/kardeiz/type-map/blob/master/src/lib.rs +use std::{ + any::{Any, TypeId}, + hash::{BuildHasherDefault, Hasher}, +}; + +use hashbrown::HashMap; + +/// Store and retrieve values by +/// [`TypeId`](https://doc.rust-lang.org/std/any/struct.TypeId.html). This +/// allows storing arbitrary data that implements `Sync + Send + +/// 'static`. +#[derive(Default, Debug)] +pub struct State(HashMap, BuildHasherDefault>); + +// With TypeIds as keys, there's no need to hash them. So we simply use an identy hasher. +#[derive(Default)] +struct IdHasher(u64); + +impl Hasher for IdHasher { + fn write(&mut self, _: &[u8]) { + unreachable!("TypeId calls write_u64"); + } + + #[inline] + fn write_u64(&mut self, id: u64) { + self.0 = id; + } + + #[inline] + fn finish(&self) -> u64 { + self.0 + } +} + +impl State { + /// Create an empty `StateSet`. + pub fn new() -> Self { + Self::default() + } + + /// Create a `State` with a default inserted value. + pub fn with(val: S) -> Self { + let mut state = State::new(); + state.insert(val); + state + } + + /// Insert a value into this `State`. + /// + /// If a value of this type already exists, it will be returned. + pub fn insert(&mut self, val: T) -> Option { + self.0 + .insert(TypeId::of::(), Box::new(val)) + .and_then(|boxed| (boxed as Box).downcast().ok().map(|boxed| *boxed)) + } + + /// Get a reference to a value previously inserted on this `State`. + pub fn get(&self) -> Option<&T> { + self.0 + .get(&TypeId::of::()) + .and_then(|boxed| (&**boxed as &(dyn Any)).downcast_ref()) + } + + /// Get a mutable reference to a value previously inserted on this `State`. + pub fn get_mut(&mut self) -> Option<&mut T> { + self.0 + .get_mut(&TypeId::of::()) + .and_then(|boxed| (&mut **boxed as &mut (dyn Any)).downcast_mut()) + } +} diff --git a/tests/nested.rs b/tests/nested.rs index fb37301f1..106cab7c5 100644 --- a/tests/nested.rs +++ b/tests/nested.rs @@ -1,6 +1,6 @@ mod test_utils; use test_utils::ServerTestingExt; -use tide::{Request, RequestState}; +use tide::Request; #[async_std::test] async fn nested() -> tide::Result<()> { @@ -48,7 +48,7 @@ async fn nested_middleware() -> tide::Result<()> { Ok(()) } -#[derive(Clone)] +#[derive(Clone, Debug)] struct Num(i32); #[async_std::test] @@ -56,7 +56,8 @@ async fn nested_with_different_state() -> tide::Result<()> { let mut outer = tide::new(); let mut inner = tide::with_state(Num(42)); inner.at("/").get(|req: Request| async move { - let num = req.state().0; + let num = req.state::().0; + println!("{:?}", req.state::()); Ok(format!("the number is {}", num)) }); outer.at("/").get(|_| async { Ok("Hello, world!") }); @@ -66,9 +67,3 @@ async fn nested_with_different_state() -> tide::Result<()> { assert_eq!(outer.get("/").recv_string().await?, "Hello, world!"); Ok(()) } - -impl RequestState for Request { - fn state(&self) -> &Num { - self.ext::().unwrap() - } -} diff --git a/tests/serve_dir.rs b/tests/serve_dir.rs index d05cce1b8..62f09854d 100644 --- a/tests/serve_dir.rs +++ b/tests/serve_dir.rs @@ -7,7 +7,7 @@ fn api() -> Box { Box::new(|_| async { Ok("api") }) } -fn app(tempdir: &tempfile::TempDir) -> Result> { +fn app(tempdir: &tempfile::TempDir) -> Result { let static_dir = tempdir.path().join("static"); fs::create_dir(&static_dir)?; diff --git a/tests/test_utils.rs b/tests/test_utils.rs index 9ea19d0f5..0422a287d 100644 --- a/tests/test_utils.rs +++ b/tests/test_utils.rs @@ -60,7 +60,7 @@ pub trait ServerTestingExt { } } -impl ServerTestingExt for tide::Server { +impl ServerTestingExt for tide::Server { fn client(&self) -> Client { let config = Config::new() .set_http_client(self.clone())