diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index e3ad56ae0..7788904fd 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -54,10 +54,3 @@ contributors the benefit of the doubt and having a sincere willingness to admit that you *might* be wrong is critical for any successful open collaboration. Don't be a bad actor. - -## Developer Certificate of Origin -All contributors must read and agree to the [Developer Certificate of -Origin (DCO)](../CERTIFICATE). - -The DCO allows us to accept contributions from people to the project, similarly -to how a license allows us to distribute our code. diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index fabd2c8c2..000000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,32 +0,0 @@ ---- -name: Bug Report -about: Create a report to help us improve ---- - -## Bug Report - -#### Current Behavior - - - -#### Code/Gist - - - -#### Expected behavior/code - - - -#### Environment - -- Rust toolchain version(s): -- OS: [e.g. OSX 10.13.4, Windows 10] - -#### Possible Solution - - -#### Additional context/Screenshots - \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index a391c04b0..000000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,16 +0,0 @@ ---- -name: Feature Request -about: Suggest an idea for this project ---- - -## Feature Request - -## Detailed Description - - -## Context - - - -## Possible Implementation - diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md deleted file mode 100644 index 895e2abba..000000000 --- a/.github/ISSUE_TEMPLATE/question.md +++ /dev/null @@ -1,13 +0,0 @@ ---- -name: Question -about: Ask any question about the Tide framework ---- - - -## Question - - - -## Additional context - - diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md deleted file mode 100644 index 521c25269..000000000 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ /dev/null @@ -1,28 +0,0 @@ - - -## Description - - -## Motivation and Context - - - -## How Has This Been Tested? - - - - -## Types of changes - -- [ ] Bug fix (non-breaking change which fixes an issue) -- [ ] New feature (non-breaking change which adds functionality) -- [ ] Breaking change (fix or feature that would cause existing functionality to change) - -## Checklist: - - -- [ ] My change requires a change to the documentation. -- [ ] I have updated the documentation accordingly. -- [ ] I have read the [CONTRIBUTING](https://github.com/rustasync/tide/blob/master/.github/CONTRIBUTING.md) document. -- [ ] I have added tests to cover my changes. -- [ ] All new and existing tests passed. diff --git a/.github/stale.yml b/.github/stale.yml deleted file mode 100644 index b8550d0fb..000000000 --- a/.github/stale.yml +++ /dev/null @@ -1,17 +0,0 @@ -# Configuration for probot-stale - https://github.com/probot/stale - -daysUntilStale: 90 -daysUntilClose: 7 -exemptLabels: - - pinned - - security -exemptProjects: false -exemptMilestones: false -staleLabel: wontfix -markComment: > - This issue has been automatically marked as stale because it has not had - recent activity. It will be closed if no further activity occurs. Thank you - for your contributions. -unmarkComment: false -closeComment: false -limitPerRun: 30 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b0b938d31..f21dcee02 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -78,7 +78,7 @@ jobs: - uses: actions-rs/toolchain@v1 with: - toolchain: stable + toolchain: nightly override: true - name: setup diff --git a/Cargo.toml b/Cargo.toml index 394e64f11..4ea9eaea2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "tide" version = "0.13.0" -description = "Serve the web – HTTP server framework" +description = "A minimal and pragmatic Rust web application framework built for rapid development" authors = [ "Aaron Turon ", "Yoshua Wuyts ", @@ -26,7 +26,7 @@ rustdoc-args = ["--cfg", "feature=\"docs\""] [features] default = ["h1-server", "logger", "sessions"] h1-server = ["async-h1"] -logger = [] +logger = ["femme"] docs = ["unstable"] sessions = ["async-session"] unstable = [] @@ -39,10 +39,12 @@ async-session = { version = "2.0.0", optional = true } async-sse = "4.0.0" async-std = { version = "1.6.0", features = ["unstable"] } async-trait = "0.1.36" -femme = "2.0.1" +femme = { version = "2.0.1", optional = true } futures-util = "0.3.5" -http-types = "2.2.1" +http-types = "2.4.0" +http-client = { version = "6.0.0", default-features = false } kv-log-macro = "1.0.4" +log = { version = "0.4.13", features = ["kv_unstable_std"] } pin-project-lite = "0.1.7" route-recognizer = "0.2.0" serde = "1.0.102" @@ -56,7 +58,7 @@ lazy_static = "1.4.0" logtest = "2.0.0" portpicker = "0.1.0" serde = { version = "1.0.102", features = ["derive"] } -surf = { version = "2.0.0-alpha.3", default-features = false, features = ["h1-client"] } +surf = { version = "2.0.0-alpha.7", default-features = false, features = ["h1-client"] } tempfile = "3.1.0" [[test]] diff --git a/README.md b/README.md index d7f74e4c0..c9806b68b 100644 --- a/README.md +++ b/README.md @@ -41,11 +41,16 @@ -A modular web framework built around async/await +Tide is a minimal and pragmatic Rust web application framework built for +rapid development. It comes with a robust set of features that make building +async web applications and APIs easier and more fun. ## Getting started -Add two dependencies to your project's `Cargo.toml` file: `tide` itself, and `async-std` with the feature `attributes` enabled: +In order to build a web app in Rust you need an HTTP server, and an async +runtime. After running `cargo init` add the following lines to your +`Cargo.toml` file: + ```toml # Example, use the version numbers you need tide = "0.13.0" @@ -54,25 +59,43 @@ async-std = { version = "1.6.0", features = ["attributes"] } ## Examples -**Hello World** +Create an HTTP server that receives a JSON body, validates it, and responds +with a confirmation message. ```rust +use tide::Request; +use tide::prelude::*; + +#[derive(Debug, Deserialize)] +struct Animal { + name: String, + legs: u8, +} + #[async_std::main] -async fn main() -> Result<(), std::io::Error> { - tide::log::start(); +async fn main() -> tide::Result<()> { let mut app = tide::new(); - app.at("/").get(|_| async { Ok("Hello, world!") }); + app.at("/orders/shoes").post(order_shoes); app.listen("127.0.0.1:8080").await?; Ok(()) } + +async fn order_shoes(mut req: Request, _state: tide::State<()>) -> tide::Result { + let Animal { name, legs } = req.body_json().await?; + Ok(format!("Hello, {}! I've put in an order for {} shoes", name, legs).into()) +} ``` -To try [the included examples](https://github.com/http-rs/tide/tree/main/examples), check out this repository and run ```sh -$ cargo run --example # shows a list of available examples -$ cargo run --example hello +$ curl localhost:8000/orders/shoes -d '{ "name": "Chashu", "legs": 4 }' +Hello, Chashu! I've put in an order for 4 shoes + +$ curl localhost:8000/orders/shoes -d '{ "name": "Mary Millipede", "legs": 750 }' +number too large to fit in target type ``` +See more examples in the [examples](https://github.com/http-rs/tide/tree/main/examples) directory. + ## Tide's design: - [Rising Tide: building a modular web framework in the open](https://rustasync.github.io/team/2018/09/11/tide.html) - [Routing and extraction in Tide: a first sketch](https://rustasync.github.io/team/2018/10/16/tide-routing.html) @@ -104,6 +127,7 @@ team. Use at your own risk. * [tide-trace](https://github.com/no9/tide-trace) * [tide-tracing](https://github.com/ethanboxx/tide-tracing) * [opentelemetry-tide](https://github.com/asaaki/opentelemetry-tide) +* [driftwood](https://github.com/jbr/driftwood) http logging middleware ### Session Stores * [async-redis-session](https://github.com/jbr/async-redis-session) diff --git a/examples/catflap.rs b/examples/catflap.rs index 8532f76a1..32e2a6caa 100644 --- a/examples/catflap.rs +++ b/examples/catflap.rs @@ -4,7 +4,7 @@ async fn main() -> Result<(), std::io::Error> { use std::{env, net::TcpListener, os::unix::io::FromRawFd}; tide::log::start(); let mut app = tide::new(); - app.at("/").get(|_| async { Ok(CHANGE_THIS_TEXT) }); + app.at("/").get(|_, _| async { Ok(CHANGE_THIS_TEXT) }); const CHANGE_THIS_TEXT: &str = "hello world!"; diff --git a/examples/chunked.rs b/examples/chunked.rs index 601d4bafd..44a911736 100644 --- a/examples/chunked.rs +++ b/examples/chunked.rs @@ -4,7 +4,7 @@ use tide::Body; async fn main() -> Result<(), std::io::Error> { tide::log::start(); let mut app = tide::new(); - app.at("/").get(|_| async { + app.at("/").get(|_, _| async { // File sends are chunked by default. Ok(Body::from_file(file!()).await?) }); diff --git a/examples/concurrent_listeners.rs b/examples/concurrent_listeners.rs index 10213f578..0a8b71f43 100644 --- a/examples/concurrent_listeners.rs +++ b/examples/concurrent_listeners.rs @@ -5,10 +5,10 @@ async fn main() -> Result<(), std::io::Error> { tide::log::start(); let mut app = tide::new(); - app.at("/").get(|request: Request<_>| async move { + app.at("/").get(|req: Request, _| async move { Ok(format!( "Hi! You reached this app through: {}", - request.local_addr().unwrap_or("an unknown port") + req.local_addr().unwrap_or("an unknown port") )) }); diff --git a/examples/cookies.rs b/examples/cookies.rs index 1ae154d57..cec3883f3 100644 --- a/examples/cookies.rs +++ b/examples/cookies.rs @@ -1,19 +1,19 @@ use tide::http::Cookie; -use tide::{Request, Response, StatusCode}; +use tide::{Request, Response, State, StatusCode}; /// Tide will use the the `Cookies`'s `Extract` implementation to build this parameter. /// -async fn retrieve_cookie(req: Request<()>) -> tide::Result { +async fn retrieve_cookie(req: Request, _: State<()>) -> tide::Result { Ok(format!("hello cookies: {:?}", req.cookie("hello").unwrap())) } -async fn insert_cookie(_req: Request<()>) -> tide::Result { +async fn insert_cookie(_req: Request, _: State<()>) -> tide::Result { let mut res = Response::new(StatusCode::Ok); res.insert_cookie(Cookie::new("hello", "world")); Ok(res) } -async fn remove_cookie(_req: Request<()>) -> tide::Result { +async fn remove_cookie(_req: Request, _: State<()>) -> tide::Result { let mut res = Response::new(StatusCode::Ok); res.remove_cookie(Cookie::named("hello")); Ok(res) diff --git a/examples/error_handling.rs b/examples/error_handling.rs index 43bf96fdb..62a5631c8 100644 --- a/examples/error_handling.rs +++ b/examples/error_handling.rs @@ -1,7 +1,7 @@ use std::io::ErrorKind; use tide::utils::After; -use tide::{Body, Request, Response, Result, StatusCode}; +use tide::{Body, Response, Result, StatusCode}; #[async_std::main] async fn main() -> Result<()> { @@ -22,7 +22,7 @@ async fn main() -> Result<()> { })); app.at("/") - .get(|_req: Request<_>| async { Ok(Body::from_file("./does-not-exist").await?) }); + .get(|_, _| async { Ok(Body::from_file("./does-not-exist").await?) }); app.listen("127.0.0.1:8080").await?; diff --git a/examples/fib.rs b/examples/fib.rs index 39e122b2b..e57fd8b10 100644 --- a/examples/fib.rs +++ b/examples/fib.rs @@ -1,4 +1,4 @@ -use tide::Request; +use tide::{Request, State}; fn fib(n: usize) -> usize { if n == 0 || n == 1 { @@ -8,9 +8,9 @@ fn fib(n: usize) -> usize { } } -async fn fibsum(req: Request<()>) -> tide::Result { +async fn fibsum(req: Request, _state: State<()>) -> tide::Result { use std::time::Instant; - let n: usize = req.param("n").unwrap_or(0); + let n: usize = req.param("n")?.parse().unwrap_or(0); // Start a stopwatch let start = Instant::now(); // Compute the nth number in the fibonacci sequence diff --git a/examples/graphql.rs b/examples/graphql.rs index c771afb49..c9505107e 100644 --- a/examples/graphql.rs +++ b/examples/graphql.rs @@ -74,9 +74,9 @@ lazy_static! { static ref SCHEMA: Schema = Schema::new(QueryRoot {}, MutationRoot {}); } -async fn handle_graphql(mut request: Request) -> tide::Result { +async fn handle_graphql(mut request: Request, state: tide::State) -> tide::Result { let query: GraphQLRequest = request.body_json().await?; - let response = query.execute(&SCHEMA, request.state()); + let response = query.execute(&SCHEMA, &state); let status = if response.is_ok() { StatusCode::Ok } else { @@ -88,7 +88,10 @@ async fn handle_graphql(mut request: Request) -> tide::Result { .build()) } -async fn handle_graphiql(_: Request) -> tide::Result> { +async fn handle_graphiql( + _: Request, + _state: tide::State, +) -> tide::Result> { Ok(Response::builder(200) .body(graphiql::graphiql_source("/graphql")) .content_type(mime::HTML)) diff --git a/examples/hello.rs b/examples/hello.rs index f8dee0b2d..7a012387f 100644 --- a/examples/hello.rs +++ b/examples/hello.rs @@ -2,7 +2,7 @@ async fn main() -> Result<(), std::io::Error> { tide::log::start(); let mut app = tide::new(); - app.at("/").get(|_| async { Ok("Hello, world!") }); + app.at("/").get(|_, _| async { Ok("Hello, world!") }); app.listen("127.0.0.1:8080").await?; Ok(()) } diff --git a/examples/json.rs b/examples/json.rs index 59f0d2fc2..8603d5b18 100644 --- a/examples/json.rs +++ b/examples/json.rs @@ -12,7 +12,7 @@ async fn main() -> tide::Result<()> { tide::log::start(); let mut app = tide::new(); - app.at("/submit").post(|mut req: Request<()>| async move { + app.at("/submit").post(|mut req: Request, _| async move { let cat: Cat = req.body_json().await?; println!("cat name: {}", cat.name); @@ -23,7 +23,7 @@ async fn main() -> tide::Result<()> { Ok(Body::from_json(&cat)?) }); - app.at("/animals").get(|_| async { + app.at("/animals").get(|_, _| async { Ok(json!({ "meta": { "count": 2 }, "animals": [ diff --git a/examples/middleware.rs b/examples/middleware.rs index a7488a59c..289ed0b5f 100644 --- a/examples/middleware.rs +++ b/examples/middleware.rs @@ -22,18 +22,21 @@ impl UserDatabase { } } +type State = tide::State; + // This is an example of a function middleware that uses the // application state. Because it depends on a specific request state, // it would likely be closely tied to a specific application fn user_loader<'a>( - mut request: Request, + mut request: Request, + state: State, next: Next<'a, UserDatabase>, ) -> Pin + Send + 'a>> { Box::pin(async { - if let Some(user) = request.state().find_user().await { + if let Some(user) = state.find_user().await { tide::log::trace!("user loaded", {user: user.name}); request.set_ext(user); - Ok(next.run(request).await) + Ok(next.run(request, state).await) // this middleware only needs to run before the endpoint, so // it just passes through the result of Next } else { @@ -61,13 +64,20 @@ impl RequestCounterMiddleware { struct RequestCount(usize); #[tide::utils::async_trait] -impl Middleware for RequestCounterMiddleware { - async fn handle(&self, mut req: Request, next: Next<'_, State>) -> Result { +impl Middleware + for RequestCounterMiddleware +{ + async fn handle( + &self, + mut req: Request, + state: tide::State, + next: Next<'_, ServerState>, + ) -> Result { let count = self.requests_counted.fetch_add(1, Ordering::Relaxed); tide::log::trace!("request counter", { count: count }); req.set_ext(RequestCount(count)); - let mut res = next.run(req).await; + let mut res = next.run(req, state).await; res.insert_header("request-number", count.to_string()); Ok(res) @@ -114,12 +124,12 @@ async fn main() -> Result<()> { app.with(user_loader); app.with(RequestCounterMiddleware::new(0)); - app.with(Before(|mut request: Request| async move { - request.set_ext(std::time::Instant::now()); - request + app.with(Before(|mut req: Request, state: State| async move { + req.set_ext(std::time::Instant::now()); + (req, state) })); - app.at("/").get(|req: Request<_>| async move { + app.at("/").get(|req: Request, _| async move { let count: &RequestCount = req.ext().unwrap(); let user: &User = req.ext().unwrap(); diff --git a/examples/nested.rs b/examples/nested.rs index aabe131e6..c51e3432f 100644 --- a/examples/nested.rs +++ b/examples/nested.rs @@ -2,11 +2,12 @@ async fn main() -> Result<(), std::io::Error> { tide::log::start(); let mut app = tide::new(); - app.at("/").get(|_| async { Ok("Root") }); + app.at("/").get(|_, _| async { Ok("Root") }); app.at("/api").nest({ let mut api = tide::new(); - api.at("/hello").get(|_| async { Ok("Hello, world") }); - api.at("/goodbye").get(|_| async { Ok("Goodbye, world") }); + api.at("/hello").get(|_, _| async { Ok("Hello, world") }); + api.at("/goodbye") + .get(|_, _| async { Ok("Goodbye, world") }); api }); app.listen("127.0.0.1:8080").await?; diff --git a/examples/redirect.rs b/examples/redirect.rs index 3647330a6..177e8aa4a 100644 --- a/examples/redirect.rs +++ b/examples/redirect.rs @@ -4,13 +4,13 @@ use tide::{Redirect, Response, StatusCode}; async fn main() -> Result<(), std::io::Error> { tide::log::start(); let mut app = tide::new(); - app.at("/").get(|_| async { Ok("Root") }); + app.at("/").get(|_, _| async { Ok("Root") }); // Redirect hackers to YouTube. app.at("/.env") .get(Redirect::new("https://www.youtube.com/watch?v=dQw4w9WgXcQ")); - app.at("/users-page").get(|_| async { + app.at("/users-page").get(|_, _| async { Ok(if signed_in() { Response::new(StatusCode::Ok) } else { diff --git a/examples/sessions.rs b/examples/sessions.rs index 16779dc7a..2152b85df 100644 --- a/examples/sessions.rs +++ b/examples/sessions.rs @@ -14,21 +14,21 @@ async fn main() -> Result<(), std::io::Error> { )); app.with(tide::utils::Before( - |mut request: tide::Request<()>| async move { - let session = request.session_mut(); + |mut req: tide::Request, state| async move { + let session = req.session_mut(); let visits: usize = session.get("visits").unwrap_or_default(); session.insert("visits", visits + 1).unwrap(); - request + (req, state) }, )); - app.at("/").get(|req: tide::Request<()>| async move { + app.at("/").get(|req: tide::Request, _| async move { let visits: usize = req.session().get("visits").unwrap(); Ok(format!("you have visited this website {} times", visits)) }); app.at("/reset") - .get(|mut req: tide::Request<()>| async move { + .get(|mut req: tide::Request, _| async move { req.session_mut().destroy(); Ok(tide::Redirect::new("/")) }); diff --git a/examples/sse.rs b/examples/sse.rs index 5d78fde83..7b7d4f877 100644 --- a/examples/sse.rs +++ b/examples/sse.rs @@ -4,11 +4,12 @@ use tide::sse; async fn main() -> Result<(), std::io::Error> { tide::log::start(); let mut app = tide::new(); - app.at("/sse").get(sse::endpoint(|_req, sender| async move { - sender.send("fruit", "banana", None).await?; - sender.send("fruit", "apple", None).await?; - Ok(()) - })); + app.at("/sse") + .get(sse::endpoint(|_req, _state, sender| async move { + sender.send("fruit", "banana", None).await?; + sender.send("fruit", "apple", None).await?; + Ok(()) + })); app.listen("localhost:8080").await?; Ok(()) } diff --git a/examples/static_file.rs b/examples/static_file.rs index 6439ed8d4..fe1c6c5a5 100644 --- a/examples/static_file.rs +++ b/examples/static_file.rs @@ -2,7 +2,7 @@ async fn main() -> Result<(), std::io::Error> { tide::log::start(); let mut app = tide::new(); - app.at("/").get(|_| async { Ok("visit /src/*") }); + app.at("/").get(|_, _| async { Ok("visit /src/*") }); app.at("/src").serve_dir("src/")?; app.listen("127.0.0.1:8080").await?; Ok(()) diff --git a/examples/upload.rs b/examples/upload.rs index 5ac0e1d06..ed476c15b 100644 --- a/examples/upload.rs +++ b/examples/upload.rs @@ -24,6 +24,8 @@ impl TempDirState { } } +type State = tide::State; + #[async_std::main] async fn main() -> Result<(), IoError> { tide::log::start(); @@ -35,9 +37,9 @@ async fn main() -> Result<(), IoError> { // $ curl localhost:8080/README.md # this reads the file from the same temp directory app.at(":file") - .put(|req: Request| async move { - let path: String = req.param("file")?; - let fs_path = req.state().path().join(path); + .put(|req: Request, state: State| async move { + let path = req.param("file")?; + let fs_path = state.path().join(path); let file = OpenOptions::new() .create(true) @@ -54,9 +56,9 @@ async fn main() -> Result<(), IoError> { Ok(json!({ "bytes": bytes_written })) }) - .get(|req: Request| async move { - let path: String = req.param("file")?; - let fs_path = req.state().path().join(path); + .get(|req: Request, state: State| async move { + let path = req.param("file")?; + let fs_path = state.path().join(path); if let Ok(body) = Body::from_file(fs_path).await { Ok(body.into()) diff --git a/src/cookies/middleware.rs b/src/cookies/middleware.rs index cc9473cf7..e5cc2359f 100644 --- a/src/cookies/middleware.rs +++ b/src/cookies/middleware.rs @@ -1,5 +1,5 @@ use crate::response::CookieEvent; -use crate::{Middleware, Next, Request}; +use crate::{Middleware, Next, Request, State}; use async_trait::async_trait; use crate::http::cookies::{Cookie, CookieJar, Delta}; @@ -15,10 +15,10 @@ use std::sync::{Arc, RwLock}; /// # use tide::{Request, Response, StatusCode}; /// # use tide::http::cookies::Cookie; /// let mut app = tide::Server::new(); -/// app.at("/get").get(|req: Request<()>| async move { +/// app.at("/get").get(|req: Request, _| async move { /// Ok(req.cookie("testCookie").unwrap().value().to_string()) /// }); -/// app.at("/set").get(|_| async { +/// app.at("/set").get(|_, _| async { /// let mut res = Response::new(StatusCode::Ok); /// res.insert_cookie(Cookie::new("testCookie", "NewCookieValue")); /// Ok(res) @@ -35,19 +35,27 @@ impl CookiesMiddleware { } #[async_trait] -impl Middleware for CookiesMiddleware { - async fn handle(&self, mut ctx: Request, next: Next<'_, State>) -> crate::Result { - let cookie_jar = if let Some(cookie_data) = ctx.ext::() { +impl Middleware for CookiesMiddleware +where + ServerState: Clone + Send + Sync + 'static, +{ + async fn handle( + &self, + mut req: Request, + state: State, + next: Next<'_, ServerState>, + ) -> crate::Result { + let cookie_jar = if let Some(cookie_data) = req.ext::() { cookie_data.content.clone() } else { - let cookie_data = CookieData::from_request(&ctx); + let cookie_data = CookieData::from_request(&req, &state); // no cookie data in ext context, so we try to create it let content = cookie_data.content.clone(); - ctx.set_ext(cookie_data); + req.set_ext(cookie_data); content }; - let mut res = next.run(ctx).await; + let mut res = next.run(req, state).await; // Don't do anything if there are no cookies. if res.cookie_events.is_empty() { @@ -112,7 +120,7 @@ impl LazyJar { } impl CookieData { - pub(crate) fn from_request(req: &Request) -> Self { + pub(crate) fn from_request(req: &Request, _state: &S) -> Self { let jar = if let Some(cookie_headers) = req.header(&headers::COOKIE) { let mut jar = CookieJar::new(); for cookie_header in cookie_headers { diff --git a/src/endpoint.rs b/src/endpoint.rs index 0fecf0b9b..a44116aea 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -4,14 +4,14 @@ use async_trait::async_trait; use http_types::Result; use crate::middleware::Next; -use crate::{Middleware, Request, Response}; +use crate::{Middleware, Request, Response, State}; /// An HTTP request handler. /// /// This trait is automatically implemented for `Fn` types, and so is rarely implemented /// directly by Tide users. /// -/// In practice, endpoints are functions that take a `Request` as an argument and +/// In practice, endpoints are functions that take a `Request, state: State` as an argument and /// return a type `T` that implements `Into`. /// /// # Examples @@ -23,7 +23,7 @@ use crate::{Middleware, Request, Response}; /// A simple endpoint that is invoked on a `GET` request and returns a `String`: /// /// ```no_run -/// async fn hello(_req: tide::Request<()>) -> tide::Result { +/// async fn hello(_req: tide::Request, _state: tide::State<()>) -> tide::Result { /// Ok(String::from("hello")) /// } /// @@ -35,7 +35,7 @@ use crate::{Middleware, Request, Response}; /// /// ```no_run /// # use core::future::Future; -/// fn hello(_req: tide::Request<()>) -> impl Future> { +/// fn hello(_req: tide::Request, _state: tide::State<()>) -> impl Future> { /// async_std::future::ready(Ok(String::from("hello"))) /// } /// @@ -45,34 +45,34 @@ use crate::{Middleware, Request, Response}; /// /// Tide routes will also accept endpoints with `Fn` signatures of this form, but using the `async` keyword has better ergonomics. #[async_trait] -pub trait Endpoint: Send + Sync + 'static { +pub trait Endpoint: Send + Sync + 'static { /// Invoke the endpoint within the given context - async fn call(&self, req: Request) -> crate::Result; + async fn call(&self, req: Request, state: State) -> crate::Result; } -pub(crate) type DynEndpoint = dyn Endpoint; +pub(crate) type DynEndpoint = dyn Endpoint; #[async_trait] -impl Endpoint for F +impl Endpoint for F where - State: Clone + Send + Sync + 'static, - F: Send + Sync + 'static + Fn(Request) -> Fut, + ServerState: Clone + Send + Sync + 'static, + F: Send + Sync + 'static + Fn(Request, State) -> Fut, Fut: Future> + Send + 'static, Res: Into + 'static, { - async fn call(&self, req: Request) -> crate::Result { - let fut = (self)(req); + async fn call(&self, req: Request, state: State) -> crate::Result { + let fut = (self)(req, state); let res = fut.await?; Ok(res.into()) } } -pub struct MiddlewareEndpoint { +pub struct MiddlewareEndpoint { endpoint: E, - middleware: Vec>>, + middleware: Vec>>, } -impl Clone for MiddlewareEndpoint { +impl Clone for MiddlewareEndpoint { fn clone(&self) -> Self { Self { endpoint: self.endpoint.clone(), @@ -81,7 +81,7 @@ impl Clone for MiddlewareEndpoint { } } -impl std::fmt::Debug for MiddlewareEndpoint { +impl std::fmt::Debug for MiddlewareEndpoint { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( fmt, @@ -91,12 +91,12 @@ impl std::fmt::Debug for MiddlewareEndpoint { } } -impl MiddlewareEndpoint +impl MiddlewareEndpoint where - State: Clone + Send + Sync + 'static, - E: Endpoint, + ServerState: Clone + Send + Sync + 'static, + E: Endpoint, { - pub fn wrap_with_middleware(ep: E, middleware: &[Arc>]) -> Self { + pub fn wrap_with_middleware(ep: E, middleware: &[Arc>]) -> Self { Self { endpoint: ep, middleware: middleware.to_vec(), @@ -105,16 +105,25 @@ where } #[async_trait] -impl Endpoint for MiddlewareEndpoint +impl Endpoint for MiddlewareEndpoint where - State: Clone + Send + Sync + 'static, - E: Endpoint, + ServerState: Clone + Send + Sync + 'static, + E: Endpoint, { - async fn call(&self, req: Request) -> crate::Result { + async fn call(&self, req: Request, state: State) -> crate::Result { let next = Next { endpoint: &self.endpoint, next_middleware: &self.middleware, }; - Ok(next.run(req).await) + Ok(next.run(req, state).await) + } +} + +#[async_trait] +impl Endpoint + for Box> +{ + async fn call(&self, req: Request, state: State) -> crate::Result { + self.as_ref().call(req, state).await } } diff --git a/src/fs/serve_dir.rs b/src/fs/serve_dir.rs index 1dff4f46c..8f10b04ee 100644 --- a/src/fs/serve_dir.rs +++ b/src/fs/serve_dir.rs @@ -1,4 +1,4 @@ -use crate::log; +use crate::{log, State}; use crate::{Body, Endpoint, Request, Response, Result, StatusCode}; use async_std::path::PathBuf as AsyncPathBuf; @@ -19,11 +19,11 @@ impl ServeDir { } #[async_trait::async_trait] -impl Endpoint for ServeDir +impl Endpoint for ServeDir where - State: Clone + Send + Sync + 'static, + ServerState: Clone + Send + Sync + 'static, { - async fn call(&self, req: Request) -> Result { + async fn call(&self, req: Request, _state: State) -> Result { let path = req.url().path(); let path = path.trim_start_matches(&self.prefix); let path = path.trim_start_matches('/'); @@ -77,11 +77,10 @@ mod test { }) } - fn request(path: &str) -> crate::Request<()> { - let request = crate::http::Request::get( - crate::http::Url::parse(&format!("http://localhost/{}", path)).unwrap(), - ); - crate::Request::new((), request, vec![]) + fn request(path: &str) -> crate::Request { + let url = crate::http::Url::parse(&format!("http://localhost/{}", path)).unwrap(); + let req = crate::http::Request::get(url); + crate::Request::new(req, vec![]) } #[async_std::test] @@ -91,7 +90,7 @@ mod test { let req = request("static/foo"); - let res = serve_dir.call(req).await.unwrap(); + let res = serve_dir.call(req, State::new(())).await.unwrap(); let mut res: crate::http::Response = res.into(); assert_eq!(res.status(), 200); @@ -105,7 +104,7 @@ mod test { let req = request("static/bar"); - let res = serve_dir.call(req).await.unwrap(); + let res = serve_dir.call(req, State::new(())).await.unwrap(); let res: crate::http::Response = res.into(); assert_eq!(res.status(), 404); diff --git a/src/lib.rs b/src/lib.rs index 706d20373..3d46be201 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,197 +1,56 @@ -//! # Serve the web -//! -//! Tide is a friendly HTTP server built for casual Rustaceans and veterans alike. It's completely -//! modular, and built directly for `async/await`. Whether it's a quick webhook, or an L7 load -//! balancer, Tide will make it work. -//! -//! # Features -//! -//! - __Fast:__ Written in Rust, and built on Futures, Tide is incredibly efficient. -//! - __Friendly:__ With thorough documentation, and a complete API, Tide helps cover your every -//! need. -//! - __Minimal:__ With only a few concepts to learn, Tide is easy to pick up and become productive -//! with. +//! Tide is a minimal and pragmatic Rust web application framework built for +//! rapid development. It comes with a robust set of features that make +//! building async web applications and APIs easier and more fun. //! //! # Getting started //! -//! Add two dependencies to your project's `Cargo.toml` file: `tide` itself, and `async-std` with the feature `attributes` enabled: +//! In order to build a web app in Rust you need an HTTP server, and an async +//! runtime. After running `cargo init` add the following lines to your +//! `Cargo.toml` file: +//! //! ```toml //! # Example, use the version numbers you need -//! tide = "0.7.0" -//! async-std = { version = "1.5.0", features = ["attributes"] } +//! tide = "0.13.0" +//! async-std = { version = "1.6.0", features = ["attributes"] } //!``` //! //! # Examples //! -//! __hello world__ -//! ```no_run -//! # use async_std::task::block_on; -//! # fn main() -> Result<(), std::io::Error> { block_on(async { -//! # -//! let mut app = tide::new(); -//! app.at("/").get(|_| async { Ok("Hello, world!") }); -//! app.listen("127.0.0.1:8080").await?; -//! # -//! # Ok(()) }) } -//! ``` +//! Create an HTTP server that receives a JSON body, validates it, and responds with a +//! confirmation message. //! -//! __echo server__ //! ```no_run -//! # use async_std::task::block_on; -//! # fn main() -> Result<(), std::io::Error> { block_on(async { -//! # -//! let mut app = tide::new(); -//! app.at("/").get(|req| async { Ok(req) }); -//! app.listen("127.0.0.1:8080").await?; -//! # -//! # Ok(()) }) } -//! ```` -//! -//! __send and receive json__ -//! ```no_run -//! # use async_std::task::block_on; -//! # fn main() -> Result<(), std::io::Error> { block_on(async { -//! # use tide::{Body, Request, Response}; -//! # -//! #[derive(Debug, serde::Deserialize, serde::Serialize)] -//! struct Counter { count: usize } -//! -//! let mut app = tide::new(); -//! app.at("/").get(|mut req: Request<()>| async move { -//! let mut counter: Counter = req.body_json().await?; -//! println!("count is {}", counter.count); -//! counter.count += 1; -//! let mut res = Response::new(200); -//! res.set_body(Body::from_json(&counter)?); -//! Ok(res) -//! }); -//! app.listen("127.0.0.1:8080").await?; -//! # -//! # Ok(()) }) } -//! ``` +//! use tide::Request; +//! use tide::prelude::*; //! -//! # Concepts -//! -//! ## Request-Response -//! -//! Each Tide endpoint takes a [`Request`] and returns a [`Response`]. Because async functions -//! allow us to wait without blocking, this makes Tide feel similar to synchronous servers. Except -//! it's incredibly efficient. -//! -//! ```txt -//! async fn endpoint(req: Request) -> Result; -//! ``` -//! -//! ## Middleware -//! -//! Middleware wrap each request and response pair, allowing code to be run before the endpoint, -//! and after each endpoint. Additionally each handler can choose to never yield to the endpoint -//! and abort early. This is useful for e.g. authentication middleware. Tide's middleware works -//! like a stack. A simplified example of the logger middleware is something like this: -//! -//! ```ignore -//! async fn log(req: Request, next: Next) -> tide::Result { -//! println!("Incoming request from {} on url {}", req.peer_addr(), req.url()); -//! let res = next().await?; -//! println!("Outgoing response with status {}", res.status()); -//! res +//! #[derive(Debug, Deserialize)] +//! struct Animal { +//! name: String, +//! legs: u8, //! } -//! ``` -//! -//! As a new request comes in, we perform some logic. Then we yield to the next -//! middleware (or endpoint, we don't know when we yield to `next`), and once that's -//! done, we return the Response. We can decide to not yield to `next` at any stage, -//! and abort early. This can then be used in applications using the [`Server::middleware`] -//! method. -//! -//! ## State -//! -//! Middleware often needs to share values with the endpoint. This is done through "request scoped -//! state". Request scoped state is built using a typemap that's available through -//! [`Request::ext`]. //! -//! If the endpoint needs to share values with middleware, response scoped state can be set via -//! [`Response::insert_ext`] and is available through [`Response::ext`]. -//! -//! Application scoped state is used when a complete application needs access to a particular -//! value. Examples of this include: database connections, websocket connections, or -//! network-enabled config. Every `Request` has an inner value that must -//! implement `Send + Sync + Clone`, and can thus freely be shared between requests. -//! -//! By default `tide::new` will use `()` as the shared state. But if you want to -//! create a new app with shared state you can use the [`with_state`] function. -//! -//! ## Extension Traits -//! -//! Sometimes having application and request scoped context can require a bit of setup. There are -//! cases where it'd be nice if things were a little easier. This is why Tide -//! encourages people to write _extension traits_. -//! -//! By using an _extension trait_ you can extend [`Request`] or [`Response`] with more -//! functionality. For example, an authentication package could implement a `user` method on -//! `Request`, to access the authenticated user provided by middleware. -//! -//! Extension traits are written by defining a trait + trait impl for the struct that's being -//! extended: -//! -//! ```no_run -//! # use tide::Request; -//! # -//! pub trait RequestExt { -//! fn bark(&self) -> String; -//! } -//! -//! impl RequestExt for Request { -//! fn bark(&self) -> String { -//! "woof".to_string() -//! } -//! } -//! ``` -//! -//! Tide apps will then have access to the `bark` method on `Request`: -//! -//! ```no_run -//! # use tide::Request; -//! # -//! # pub trait RequestExt { -//! # fn bark(&self) -> String; -//! # } -//! # -//! # impl RequestExt for Request { -//! # fn bark(&self) -> String { -//! # "woof".to_string() -//! # } -//! # } -//! # //! #[async_std::main] -//! async fn main() -> Result<(), std::io::Error> { +//! async fn main() -> tide::Result<()> { //! let mut app = tide::new(); -//! app.at("/").get(|req: Request<()>| async move { Ok(req.bark()) }); -//! app.listen("127.0.0.1:8080").await +//! app.at("/orders/shoes").post(order_shoes); +//! app.listen("127.0.0.1:8080").await?; +//! Ok(()) //! } -//! ``` //! -//! # HTTP Version 1.1 only +//! async fn order_shoes(mut req: Request, _state: tide::State<()>) -> tide::Result { +//! let Animal { name, legs } = req.body_json().await?; +//! Ok(format!("Hello, {}! I've put in an order for {} shoes", name, legs).into()) +//! } +//! ```` //! -//! Tide's default backend currently only supports HTTP/1.1. In order -//! to use nginx as reverse proxy for Tide, your upstream proxy -//! configuration must include this line: +//! ```sh +//! $ curl localhost:8000/orders/shoes -d '{ "name": "Chashu", "legs": 4 }' +//! Hello, Chashu! I've put in an order for 4 shoes //! -//! ```text -//! proxy_http_version 1.1; +//! $ curl localhost:8000/orders/shoes -d '{ "name": "Mary Millipede", "legs": 750 }' +//! number too large to fit in target type //! ``` -//! -//! # API Stability -//! -//! It's still early in Tide's development cycle. While the general shape of Tide might have -//! roughly established, the exact traits and function parameters may change between versions. In -//! practice this means that building your core business on Tide is probably not a wise idea... -//! yet. -//! -//! However we *are* committed to closely following semver, and documenting any and all breaking -//! changes we make. Also as time goes on you may find that fewer and fewer changes occur, until we -//! eventually remove this notice entirely. The goal of Tide is to build a premier HTTP experience -//! for Async Rust. We have a long journey ahead of us. But we're excited you're here with us! +//! See more examples in the [examples](https://github.com/http-rs/tide/tree/main/examples) directory. #![cfg_attr(feature = "docs", feature(doc_cfg))] // #![warn(missing_docs)] @@ -210,6 +69,7 @@ mod request; mod response; mod response_builder; mod route; +mod state; #[cfg(not(feature = "__internal__bench"))] mod router; @@ -231,11 +91,12 @@ pub mod sessions; pub use endpoint::Endpoint; pub use middleware::{Middleware, Next}; pub use redirect::Redirect; -pub use request::{ParamError, Request}; +pub use request::Request; pub use response::Response; pub use response_builder::ResponseBuilder; pub use route::Route; pub use server::Server; +pub use state::State; #[doc(inline)] pub use http_types::{self as http, Body, Error, Status, StatusCode}; @@ -249,7 +110,7 @@ pub use http_types::{self as http, Body, Error, Status, StatusCode}; /// # fn main() -> Result<(), std::io::Error> { block_on(async { /// # /// let mut app = tide::new(); -/// app.at("/").get(|_| async { Ok("Hello, world!") }); +/// app.at("/").get(|_, _| async { Ok("Hello, world!") }); /// app.listen("127.0.0.1:8080").await?; /// # /// # Ok(()) }) } @@ -284,16 +145,16 @@ pub fn new() -> server::Server<()> { /// /// // Initialize the application with state. /// let mut app = tide::with_state(state); -/// app.at("/").get(|req: Request| async move { -/// Ok(format!("Hello, {}!", &req.state().name)) +/// app.at("/").get(|req: Request, state: State| async move { +/// Ok(format!("Hello, {}!", &state.name)) /// }); /// app.listen("127.0.0.1:8080").await?; /// # /// # Ok(()) }) } /// ``` -pub fn with_state(state: State) -> server::Server +pub fn with_state(state: ServerState) -> server::Server where - State: Clone + Send + Sync + 'static, + ServerState: Clone + Send + Sync + 'static, { Server::with_state(state) } diff --git a/src/listener/concurrent_listener.rs b/src/listener/concurrent_listener.rs index 1cff9b0d7..b797f4e84 100644 --- a/src/listener/concurrent_listener.rs +++ b/src/listener/concurrent_listener.rs @@ -15,7 +15,7 @@ use futures_util::stream::{futures_unordered::FuturesUnordered, StreamExt}; /// async_std::task::block_on(async { /// tide::log::start(); /// let mut app = tide::new(); -/// app.at("/").get(|_| async { Ok("Hello, world!") }); +/// app.at("/").get(|_, _| async { Ok("Hello, world!") }); /// /// let mut listener = tide::listener::ConcurrentListener::new(); /// listener.add("127.0.0.1:8000")?; @@ -33,9 +33,9 @@ use futures_util::stream::{futures_unordered::FuturesUnordered, StreamExt}; ///``` #[derive(Default)] -pub struct ConcurrentListener(Vec>>); +pub struct ConcurrentListener(Vec>>); -impl ConcurrentListener { +impl ConcurrentListener { /// creates a new ConcurrentListener pub fn new() -> Self { Self(vec![]) @@ -55,7 +55,7 @@ impl ConcurrentListener { /// # std::mem::drop(tide::new().listen(listener)); // for the State generic /// # Ok(()) } /// ``` - pub fn add>(&mut self, listener: TL) -> io::Result<()> { + pub fn add>(&mut self, listener: TL) -> io::Result<()> { self.0.push(Box::new(listener.to_listener()?)); Ok(()) } @@ -71,15 +71,17 @@ impl ConcurrentListener { /// .with_listener(async_std::net::TcpListener::bind("127.0.0.1:8081").await?), /// ).await?; /// # Ok(()) }) } - pub fn with_listener>(mut self, listener: TL) -> Self { + pub fn with_listener>(mut self, listener: TL) -> Self { self.add(listener).expect("Unable to add listener"); self } } #[async_trait::async_trait] -impl Listener for ConcurrentListener { - async fn listen(&mut self, app: Server) -> io::Result<()> { +impl Listener + for ConcurrentListener +{ + async fn listen(&mut self, app: Server) -> io::Result<()> { let mut futures_unordered = FuturesUnordered::new(); for listener in self.0.iter_mut() { @@ -94,13 +96,13 @@ impl Listener for ConcurrentListene } } -impl Debug for ConcurrentListener { +impl Debug for ConcurrentListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{:?}", self.0) } } -impl Display for ConcurrentListener { +impl Display for ConcurrentListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let string = self .0 diff --git a/src/listener/failover_listener.rs b/src/listener/failover_listener.rs index 4ab1bd242..60476df03 100644 --- a/src/listener/failover_listener.rs +++ b/src/listener/failover_listener.rs @@ -15,7 +15,7 @@ use async_std::io; /// async_std::task::block_on(async { /// tide::log::start(); /// let mut app = tide::new(); -/// app.at("/").get(|_| async { Ok("Hello, world!") }); +/// app.at("/").get(|_, _| async { Ok("Hello, world!") }); /// /// let mut listener = tide::listener::FailoverListener::new(); /// listener.add("127.0.0.1:8000")?; @@ -33,9 +33,9 @@ use async_std::io; ///``` #[derive(Default)] -pub struct FailoverListener(Vec>>); +pub struct FailoverListener(Vec>>); -impl FailoverListener { +impl FailoverListener { /// creates a new FailoverListener pub fn new() -> Self { Self(vec![]) @@ -57,7 +57,7 @@ impl FailoverListener { /// # std::mem::drop(tide::new().listen(listener)); // for the State generic /// # Ok(()) } /// ``` - pub fn add>(&mut self, listener: TL) -> io::Result<()> { + pub fn add>(&mut self, listener: TL) -> io::Result<()> { self.0.push(Box::new(listener.to_listener()?)); Ok(()) } @@ -73,15 +73,17 @@ impl FailoverListener { /// .with_listener(("localhost", 8081)), /// ).await?; /// # Ok(()) }) } - pub fn with_listener>(mut self, listener: TL) -> Self { + pub fn with_listener>(mut self, listener: TL) -> Self { self.add(listener).expect("Unable to add listener"); self } } #[async_trait::async_trait] -impl Listener for FailoverListener { - async fn listen(&mut self, app: Server) -> io::Result<()> { +impl Listener + for FailoverListener +{ + async fn listen(&mut self, app: Server) -> io::Result<()> { for listener in self.0.iter_mut() { let app = app.clone(); match listener.listen(app).await { @@ -102,13 +104,13 @@ impl Listener for FailoverListener< } } -impl Debug for FailoverListener { +impl Debug for FailoverListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{:?}", self.0) } } -impl Display for FailoverListener { +impl Display for FailoverListener { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let string = self .0 diff --git a/src/listener/mod.rs b/src/listener/mod.rs index 82033de15..75859c143 100644 --- a/src/listener/mod.rs +++ b/src/listener/mod.rs @@ -31,13 +31,13 @@ pub(crate) use unix_listener::UnixListener; /// you will also need to implement at least one [`ToListener`](crate::listener::ToListener) that /// outputs your Listener type. #[async_trait::async_trait] -pub trait Listener: +pub trait Listener: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static { /// This is the primary entrypoint for the Listener trait. listen /// is called exactly once, and is expected to spawn tasks for /// each incoming connection. - async fn listen(&mut self, app: Server) -> io::Result<()>; + async fn listen(&mut self, app: Server) -> io::Result<()>; } /// crate-internal shared logic used by tcp and unix listeners to diff --git a/src/listener/parsed_listener.rs b/src/listener/parsed_listener.rs index 4b0e186a9..ed4a50bc7 100644 --- a/src/listener/parsed_listener.rs +++ b/src/listener/parsed_listener.rs @@ -31,8 +31,8 @@ impl Display for ParsedListener { } #[async_trait::async_trait] -impl Listener for ParsedListener { - async fn listen(&mut self, app: Server) -> io::Result<()> { +impl Listener for ParsedListener { + async fn listen(&mut self, app: Server) -> io::Result<()> { match self { #[cfg(unix)] Self::Unix(u) => u.listen(app).await, diff --git a/src/listener/tcp_listener.rs b/src/listener/tcp_listener.rs index db68530ee..f2ab47430 100644 --- a/src/listener/tcp_listener.rs +++ b/src/listener/tcp_listener.rs @@ -51,7 +51,10 @@ impl TcpListener { } } -fn handle_tcp(app: Server, stream: TcpStream) { +fn handle_tcp( + app: Server, + stream: TcpStream, +) { task::spawn(async move { let local_addr = stream.local_addr().ok(); let peer_addr = stream.peer_addr().ok(); @@ -69,8 +72,8 @@ fn handle_tcp(app: Server, stream: } #[async_trait::async_trait] -impl Listener for TcpListener { - async fn listen(&mut self, app: Server) -> io::Result<()> { +impl Listener for TcpListener { + async fn listen(&mut self, app: Server) -> io::Result<()> { self.connect().await?; let listener = self.listener()?; crate::log::info!("Server listening on {}", self); diff --git a/src/listener/to_listener.rs b/src/listener/to_listener.rs index b67d068ba..28327fa44 100644 --- a/src/listener/to_listener.rs +++ b/src/listener/to_listener.rs @@ -48,8 +48,8 @@ use async_std::io; /// # Other implementations /// See below for additional provided implementations of ToListener. -pub trait ToListener { - type Listener: Listener; +pub trait ToListener { + type Listener: Listener; /// Transform self into a /// [`Listener`](crate::listener::Listener). Unless self is /// already bound/connected to the underlying io, converting to a diff --git a/src/listener/to_listener_impls.rs b/src/listener/to_listener_impls.rs index a1b406d1f..b62171f7c 100644 --- a/src/listener/to_listener_impls.rs +++ b/src/listener/to_listener_impls.rs @@ -5,7 +5,7 @@ use crate::http::url::Url; use async_std::io; use std::net::ToSocketAddrs; -impl ToListener for Url { +impl ToListener for Url { type Listener = ParsedListener; fn to_listener(self) -> io::Result { @@ -48,14 +48,14 @@ impl ToListener for Url { } } -impl ToListener for String { +impl ToListener for String { type Listener = ParsedListener; fn to_listener(self) -> io::Result { - ToListener::::to_listener(self.as_str()) + ToListener::::to_listener(self.as_str()) } } -impl ToListener for &str { +impl ToListener for &str { type Listener = ParsedListener; fn to_listener(self) -> io::Result { @@ -64,7 +64,7 @@ impl ToListener for &str { socket_addrs.collect(), ))) } else if let Ok(url) = Url::parse(self) { - ToListener::::to_listener(url) + ToListener::::to_listener(url) } else { Err(io::Error::new( io::ErrorKind::InvalidInput, @@ -75,7 +75,9 @@ impl ToListener for &str { } #[cfg(unix)] -impl ToListener for async_std::path::PathBuf { +impl ToListener + for async_std::path::PathBuf +{ type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_path(self)) @@ -83,28 +85,30 @@ impl ToListener for async_std::path } #[cfg(unix)] -impl ToListener for std::path::PathBuf { +impl ToListener for std::path::PathBuf { type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_path(self)) } } -impl ToListener for async_std::net::TcpListener { +impl ToListener + for async_std::net::TcpListener +{ type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_listener(self)) } } -impl ToListener for std::net::TcpListener { +impl ToListener for std::net::TcpListener { type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_listener(self)) } } -impl ToListener for (&str, u16) { +impl ToListener for (&str, u16) { type Listener = TcpListener; fn to_listener(self) -> io::Result { @@ -113,7 +117,7 @@ impl ToListener for (&str, u16) { } #[cfg(unix)] -impl ToListener +impl ToListener for async_std::os::unix::net::UnixListener { type Listener = UnixListener; @@ -123,14 +127,16 @@ impl ToListener } #[cfg(unix)] -impl ToListener for std::os::unix::net::UnixListener { +impl ToListener + for std::os::unix::net::UnixListener +{ type Listener = UnixListener; fn to_listener(self) -> io::Result { Ok(UnixListener::from_listener(self)) } } -impl ToListener for TcpListener { +impl ToListener for TcpListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) @@ -138,43 +144,49 @@ impl ToListener for TcpListener { } #[cfg(unix)] -impl ToListener for UnixListener { +impl ToListener for UnixListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for ConcurrentListener { +impl ToListener + for ConcurrentListener +{ type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for ParsedListener { +impl ToListener for ParsedListener { type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for FailoverListener { +impl ToListener + for FailoverListener +{ type Listener = Self; fn to_listener(self) -> io::Result { Ok(self) } } -impl ToListener for std::net::SocketAddr { +impl ToListener for std::net::SocketAddr { type Listener = TcpListener; fn to_listener(self) -> io::Result { Ok(TcpListener::from_addrs(vec![self])) } } -impl, State: Clone + Send + Sync + 'static> ToListener for Vec { - type Listener = ConcurrentListener; +impl, ServerState: Clone + Send + Sync + 'static> + ToListener for Vec +{ + type Listener = ConcurrentListener; fn to_listener(self) -> io::Result { let mut concurrent_listener = ConcurrentListener::new(); for listener in self { diff --git a/src/listener/unix_listener.rs b/src/listener/unix_listener.rs index 72aff852d..be8b81fd6 100644 --- a/src/listener/unix_listener.rs +++ b/src/listener/unix_listener.rs @@ -64,7 +64,7 @@ fn unix_socket_addr_to_string(result: io::Result) -> Option }) } -fn handle_unix(app: Server, stream: UnixStream) { +fn handle_unix(app: Server, stream: UnixStream) { task::spawn(async move { let local_addr = unix_socket_addr_to_string(stream.local_addr()); let peer_addr = unix_socket_addr_to_string(stream.peer_addr()); @@ -82,8 +82,8 @@ fn handle_unix(app: Server, stream: } #[async_trait::async_trait] -impl Listener for UnixListener { - async fn listen(&mut self, app: Server) -> io::Result<()> { +impl Listener for UnixListener { + async fn listen(&mut self, app: Server) -> io::Result<()> { self.connect().await?; crate::log::info!("Server listening on {}", self); let listener = self.listener()?; diff --git a/src/log/middleware.rs b/src/log/middleware.rs index 41a9167d3..92d185274 100644 --- a/src/log/middleware.rs +++ b/src/log/middleware.rs @@ -1,4 +1,4 @@ -use crate::log; +use crate::{log, State}; use crate::{Middleware, Next, Request}; /// Log all incoming requests and responses. @@ -28,13 +28,14 @@ impl LogMiddleware { } /// Log a request and a response. - async fn log<'a, State: Clone + Send + Sync + 'static>( + async fn log<'a, ServerState: Clone + Send + Sync + 'static>( &'a self, - mut req: Request, - next: Next<'a, State>, + mut req: Request, + state: State, + next: Next<'a, ServerState>, ) -> crate::Result { if req.ext::().is_some() { - return Ok(next.run(req).await); + return Ok(next.run(req, state).await); } req.set_ext(LogMiddlewareHasBeenRun); @@ -45,37 +46,47 @@ impl LogMiddleware { path: path, }); let start = std::time::Instant::now(); - let response = next.run(req).await; + let response = next.run(req, state).await; let status = response.status(); if status.is_server_error() { if let Some(error) = response.error() { log::error!("Internal error --> Response sent", { - message: error.to_string(), + message: format!("\"{}\"", error.to_string()), method: method, path: path, - status: status as u16, + status: format!("{} - {}", status as u16, status.canonical_reason()), duration: format!("{:?}", start.elapsed()), }); } else { log::error!("Internal error --> Response sent", { method: method, path: path, - status: status as u16, + status: format!("{} - {}", status as u16, status.canonical_reason()), duration: format!("{:?}", start.elapsed()), }); } } else if status.is_client_error() { - log::warn!("--> Response sent", { - method: method, - path: path, - status: status as u16, - duration: format!("{:?}", start.elapsed()), - }); + if let Some(error) = response.error() { + log::warn!("Client error --> Response sent", { + message: format!("\"{}\"", error.to_string()), + method: method, + path: path, + status: format!("{} - {}", status as u16, status.canonical_reason()), + duration: format!("{:?}", start.elapsed()), + }); + } else { + log::warn!("Client error --> Response sent", { + method: method, + path: path, + status: format!("{} - {}", status as u16, status.canonical_reason()), + duration: format!("{:?}", start.elapsed()), + }); + } } else { log::info!("--> Response sent", { method: method, path: path, - status: status as u16, + status: format!("{} - {}", status as u16, status.canonical_reason()), duration: format!("{:?}", start.elapsed()), }); } @@ -84,8 +95,13 @@ impl LogMiddleware { } #[async_trait::async_trait] -impl Middleware for LogMiddleware { - async fn handle(&self, req: Request, next: Next<'_, State>) -> crate::Result { - self.log(req, next).await +impl Middleware for LogMiddleware { + async fn handle( + &self, + req: Request, + state: State, + next: Next<'_, ServerState>, + ) -> crate::Result { + self.log(req, state, next).await } } diff --git a/src/log/mod.rs b/src/log/mod.rs index 77ce8ca52..8f74ac2dc 100644 --- a/src/log/mod.rs +++ b/src/log/mod.rs @@ -21,7 +21,9 @@ pub use kv_log_macro::{max_level, Level}; mod middleware; +#[cfg(feature = "logger")] pub use femme::LevelFilter; + pub use middleware::LogMiddleware; /// Start logging. diff --git a/src/middleware.rs b/src/middleware.rs index 7e1ca9a90..59cc5a4f6 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -3,16 +3,21 @@ use std::sync::Arc; use crate::endpoint::DynEndpoint; -use crate::{Request, Response}; +use crate::{Request, Response, State}; use async_trait::async_trait; use std::future::Future; use std::pin::Pin; /// Middleware that wraps around the remaining middleware chain. #[async_trait] -pub trait Middleware: Send + Sync + 'static { +pub trait Middleware: Send + Sync + 'static { /// Asynchronously handle the request, and return a response. - async fn handle(&self, request: Request, next: Next<'_, State>) -> crate::Result; + async fn handle( + &self, + request: Request, + state: State, + next: Next<'_, ServerState>, + ) -> crate::Result; /// Set the middleware's name. By default it uses the type signature. fn name(&self) -> &str { @@ -21,40 +26,49 @@ pub trait Middleware: Send + Sync + 'static { } #[async_trait] -impl Middleware for F +impl Middleware for F where - State: Clone + Send + Sync + 'static, + ServerState: Clone + Send + Sync + 'static, F: Send + Sync + 'static + for<'a> Fn( - Request, - Next<'a, State>, + Request, + State, + Next<'a, ServerState>, ) -> Pin + 'a + Send>>, { - async fn handle(&self, req: Request, next: Next<'_, State>) -> crate::Result { - (self)(req, next).await + async fn handle( + &self, + req: Request, + state: State, + next: Next<'_, ServerState>, + ) -> crate::Result { + (self)(req, state, next).await } } /// The remainder of a middleware chain, including the endpoint. #[allow(missing_debug_implementations)] -pub struct Next<'a, State> { - pub(crate) endpoint: &'a DynEndpoint, - pub(crate) next_middleware: &'a [Arc>], +pub struct Next<'a, ServerState> { + pub(crate) endpoint: &'a DynEndpoint, + pub(crate) next_middleware: &'a [Arc>], } -impl Next<'_, State> { +impl Next<'_, ServerState> +where + ServerState: Clone + Send + Sync + 'static, +{ /// Asynchronously execute the remaining middleware chain. - pub async fn run(mut self, req: Request) -> Response { + pub async fn run(mut self, req: Request, state: State) -> Response { if let Some((current, next)) = self.next_middleware.split_first() { self.next_middleware = next; - match current.handle(req, self).await { + match current.handle(req, state, self).await { Ok(request) => request, Err(err) => err.into(), } } else { - match self.endpoint.call(req).await { + match self.endpoint.call(req, state).await { Ok(request) => request, Err(err) => err.into(), } diff --git a/src/redirect.rs b/src/redirect.rs index 0230279ec..dbad8fd09 100644 --- a/src/redirect.rs +++ b/src/redirect.rs @@ -9,7 +9,7 @@ //! use tide::Redirect; //! //! let mut app = tide::new(); -//! app.at("/").get(|_| async { Ok("meow") }); +//! app.at("/").get(|_, _| async { Ok("meow") }); //! app.at("/nori").get(Redirect::temporary("/")); //! app.listen("127.0.0.1:8080").await?; //! # @@ -17,8 +17,8 @@ //! ``` use crate::http::headers::LOCATION; -use crate::StatusCode; use crate::{Endpoint, Request, Response}; +use crate::{State, StatusCode}; /// A redirection endpoint. /// @@ -28,7 +28,7 @@ use crate::{Endpoint, Request, Response}; /// # use tide::{Response, Redirect, Request, StatusCode}; /// # fn next_product() -> Option { None } /// # #[allow(dead_code)] -/// async fn route_handler(request: Request<()>) -> tide::Result { +/// async fn route_handler(req: Request, state: tide::State<()>) -> tide::Result { /// if let Some(product_url) = next_product() { /// Ok(Redirect::new(product_url).into()) /// } else { @@ -86,12 +86,12 @@ impl> Redirect { } #[async_trait::async_trait] -impl Endpoint for Redirect +impl Endpoint for Redirect where - State: Clone + Send + Sync + 'static, + ServerState: Clone + Send + Sync + 'static, T: AsRef + Send + Sync + 'static, { - async fn call(&self, _req: Request) -> crate::Result { + async fn call(&self, _req: Request, _state: State) -> crate::Result { Ok(self.into()) } } diff --git a/src/request.rs b/src/request.rs index 2bb5553fc..943b8fb46 100644 --- a/src/request.rs +++ b/src/request.rs @@ -4,10 +4,10 @@ use route_recognizer::Params; use std::ops::Index; use std::pin::Pin; -use std::{fmt, str::FromStr}; use crate::cookies::CookieData; use crate::http::cookies::Cookie; +use crate::http::format_err; use crate::http::headers::{self, HeaderName, HeaderValues, ToHeaderValues}; use crate::http::{self, Body, Method, Mime, StatusCode, Url, Version}; use crate::Response; @@ -21,39 +21,17 @@ pin_project_lite::pin_project! { /// Requests also provide *extensions*, a type map primarily used for low-level /// communication between middleware and endpoints. #[derive(Debug)] - pub struct Request { - pub(crate) state: State, + pub struct Request { #[pin] pub(crate) req: http::Request, pub(crate) route_params: Vec, } } -#[derive(Debug)] -pub enum ParamError { - NotFound(String), - ParsingError(E), -} - -impl fmt::Display for ParamError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ParamError::NotFound(name) => write!(f, "Param \"{}\" not found!", name), - ParamError::ParsingError(err) => write!(f, "Param failed to parse: {}", err), - } - } -} - -impl std::error::Error for ParamError {} - -impl Request { +impl Request { /// Create a new `Request`. - pub(crate) fn new(state: State, req: http_types::Request, route_params: Vec) -> Self { - Self { - state, - req, - route_params, - } + pub(crate) fn new(req: http_types::Request, route_params: Vec) -> Self { + Self { req, route_params } } /// Access the request's HTTP method. @@ -67,7 +45,7 @@ impl Request { /// use tide::Request; /// /// let mut app = tide::new(); - /// app.at("/").get(|req: Request<()>| async move { + /// app.at("/").get(|req: Request, _| async move { /// assert_eq!(req.method(), http_types::Method::Get); /// Ok("") /// }); @@ -91,7 +69,7 @@ impl Request { /// use tide::Request; /// /// let mut app = tide::new(); - /// app.at("/").get(|req: Request<()>| async move { + /// app.at("/").get(|req: Request, _| async move { /// assert_eq!(req.url(), &"/".parse::().unwrap()); /// Ok("") /// }); @@ -115,7 +93,7 @@ impl Request { /// use tide::Request; /// /// let mut app = tide::new(); - /// app.at("/").get(|req: Request<()>| async move { + /// app.at("/").get(|req: Request, _| async move { /// assert_eq!(req.version(), Some(http_types::Version::Http1_1)); /// Ok("") /// }); @@ -186,7 +164,7 @@ impl Request { /// use tide::Request; /// /// let mut app = tide::new(); - /// app.at("/").get(|req: Request<()>| async move { + /// app.at("/").get(|req: Request, _| async move { /// assert_eq!(req.header("X-Forwarded-For").unwrap(), "127.0.0.1"); /// Ok("") /// }); @@ -271,12 +249,6 @@ impl Request { self.req.ext_mut().insert(val) } - #[must_use] - /// Access application scoped state. - pub fn state(&self) -> &State { - &self.state - } - /// Extract and parse a route parameter by name. /// /// Returns the results of parsing the parameter according to the inferred @@ -287,10 +259,7 @@ impl Request { /// /// # Errors /// - /// Yields a `ParamError::ParsingError` if the parameter was found but failed to parse as an - /// instance of type `T`. - /// - /// Yields a `ParamError::NotFound` if `key` is not a parameter for the route. + /// An error is returned if `key` is not a valid parameter for the route. /// /// # Examples /// @@ -300,8 +269,8 @@ impl Request { /// # /// use tide::{Request, Result}; /// - /// async fn greet(req: Request<()>) -> Result { - /// let name = req.param("name").unwrap_or("world".to_owned()); + /// async fn greet(req: Request, _state: tide::State<()>) -> Result { + /// let name = req.param("name").unwrap_or("world"); /// Ok(format!("Hello, {}!", name)) /// } /// @@ -312,17 +281,16 @@ impl Request { /// # /// # Ok(()) })} /// ``` - pub fn param(&self, key: &str) -> Result> { + pub fn param(&self, key: &str) -> crate::Result<&str> { self.route_params .iter() .rev() .find_map(|params| params.find(key)) - .ok_or_else(|| ParamError::NotFound(key.to_string())) - .and_then(|param| param.parse().map_err(ParamError::ParsingError)) + .ok_or_else(|| format_err!("Param \"{}\" not found", key.to_string())) } - /// Parse the URL query component into a struct, using [serde_qs](serde_qs). To get the entire - /// query as an unparsed string, use `request.url().query()` + /// Parse the URL query component into a struct, using [serde_qs](https://docs.rs/serde_qs). To + /// get the entire query as an unparsed string, use `request.url().query()` /// /// ```rust /// # fn main() -> Result<(), std::io::Error> { async_std::task::block_on(async { @@ -343,7 +311,7 @@ impl Request { /// } /// } /// } - /// app.at("/pages").post(|req: tide::Request<()>| async move { + /// app.at("/pages").post(|req: tide::Request, _| async move { /// let page: Page = req.query()?; /// Ok(format!("page {}, with {} items", page.offset, page.size)) /// }); @@ -407,7 +375,7 @@ impl Request { /// use tide::Request; /// /// let mut app = tide::new(); - /// app.at("/").get(|mut req: Request<()>| async move { + /// app.at("/").get(|mut req: Request, _state |async move { /// let _body: Vec = req.body_bytes().await.unwrap(); /// Ok("") /// }); @@ -441,7 +409,7 @@ impl Request { /// use tide::Request; /// /// let mut app = tide::new(); - /// app.at("/").get(|mut req: Request<()>| async move { + /// app.at("/").get(|mut req: Request, _state |async move { /// let _body: String = req.body_string().await.unwrap(); /// Ok("") /// }); @@ -481,7 +449,7 @@ impl Request { /// legs: u8 /// } /// - /// app.at("/").post(|mut req: tide::Request<()>| async move { + /// app.at("/").post(|mut req: tide::Request, _| async move { /// let animal: Animal = req.body_form().await?; /// Ok(format!( /// "hello, {}! i've put in an order for {} shoes", @@ -555,31 +523,31 @@ impl Request { } } -impl AsRef for Request { +impl AsRef for Request { fn as_ref(&self) -> &http::Request { &self.req } } -impl AsMut for Request { +impl AsMut for Request { fn as_mut(&mut self) -> &mut http::Request { &mut self.req } } -impl AsRef for Request { +impl AsRef for Request { fn as_ref(&self) -> &http::Headers { self.req.as_ref() } } -impl AsMut for Request { +impl AsMut for Request { fn as_mut(&mut self) -> &mut http::Headers { self.req.as_mut() } } -impl Read for Request { +impl Read for Request { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -589,21 +557,21 @@ impl Read for Request { } } -impl Into for Request { +impl Into for Request { fn into(self) -> http::Request { self.req } } -impl Into> for http_types::Request { - fn into(self) -> Request { - Request::new(State::default(), self, Vec::::new()) +impl Into for http_types::Request { + fn into(self) -> Request { + Request::new(self, vec![]) } } // NOTE: From cannot be implemented for this conversion because `State` needs to // be constrained by a type. -impl Into for Request { +impl Into for Request { fn into(mut self) -> Response { let mut res = Response::new(StatusCode::Ok); res.set_body(self.take_body()); @@ -611,7 +579,7 @@ impl Into for Request { } } -impl IntoIterator for Request { +impl IntoIterator for Request { type Item = (HeaderName, HeaderValues); type IntoIter = http_types::headers::IntoIter; @@ -622,7 +590,7 @@ impl IntoIterator for Request { } } -impl<'a, State> IntoIterator for &'a Request { +impl<'a> IntoIterator for &'a Request { type Item = (&'a HeaderName, &'a HeaderValues); type IntoIter = http_types::headers::Iter<'a>; @@ -632,7 +600,7 @@ impl<'a, State> IntoIterator for &'a Request { } } -impl<'a, State> IntoIterator for &'a mut Request { +impl<'a> IntoIterator for &'a mut Request { type Item = (&'a HeaderName, &'a mut HeaderValues); type IntoIter = http_types::headers::IterMut<'a>; @@ -642,7 +610,7 @@ impl<'a, State> IntoIterator for &'a mut Request { } } -impl Index for Request { +impl Index for Request { type Output = HeaderValues; /// Returns a reference to the value corresponding to the supplied name. @@ -656,7 +624,7 @@ impl Index for Request { } } -impl Index<&str> for Request { +impl Index<&str> for Request { type Output = HeaderValues; /// Returns a reference to the value corresponding to the supplied name. diff --git a/src/route.rs b/src/route.rs index da71b747a..6f6c9e5a3 100644 --- a/src/route.rs +++ b/src/route.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use crate::endpoint::MiddlewareEndpoint; use crate::fs::ServeDir; -use crate::log; +use crate::{log, State}; use crate::{router::Router, Endpoint, Middleware}; /// A handle to a route. @@ -17,10 +17,10 @@ use crate::{router::Router, Endpoint, Middleware}; /// /// [`Server::at`]: ./struct.Server.html#method.at #[allow(missing_debug_implementations)] -pub struct Route<'a, State> { - router: &'a mut Router, +pub struct Route<'a, ServerState> { + router: &'a mut Router, path: String, - middleware: Vec>>, + middleware: Vec>>, /// Indicates whether the path of current route is treated as a prefix. Set by /// [`strip_prefix`]. /// @@ -28,8 +28,8 @@ pub struct Route<'a, State> { prefix: bool, } -impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { - pub(crate) fn new(router: &'a mut Router, path: String) -> Route<'a, State> { +impl<'a, ServerState: Clone + Send + Sync + 'static> Route<'a, ServerState> { + pub(crate) fn new(router: &'a mut Router, path: String) -> Route<'a, ServerState> { Route { router, path, @@ -39,11 +39,11 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { } /// Extend the route with the given `path`. - pub fn at<'b>(&'b mut self, path: &str) -> Route<'b, State> { + pub fn at<'b>(&'b mut self, path: &str) -> Route<'b, ServerState> { let mut p = self.path.clone(); if !p.ends_with('/') && !path.starts_with('/') { - p.push_str("/"); + p.push('/'); } if path != "/" { @@ -79,7 +79,7 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { /// Apply the given middleware to the current route. pub fn with(&mut self, middleware: M) -> &mut Self where - M: Middleware, + M: Middleware, { log::trace!( "Adding middleware {} to route {:?}", @@ -101,12 +101,15 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { /// [`Server`]: struct.Server.html pub fn nest(&mut self, service: crate::Server) -> &mut Self where - State: Clone + Send + Sync + 'static, + ServerState: Clone + Send + Sync + 'static, InnerState: Clone + Send + Sync + 'static, { + let prefix = self.prefix; + self.prefix = true; self.all(service); - self.prefix = false; + self.prefix = prefix; + self } @@ -138,7 +141,11 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { } /// Add an endpoint for the given HTTP method - pub fn method(&mut self, method: http_types::Method, ep: impl Endpoint) -> &mut Self { + pub fn method( + &mut self, + method: http_types::Method, + ep: impl Endpoint, + ) -> &mut Self { if self.prefix { let ep = StripPrefixEndpoint::new(ep); let (ep1, ep2): (Box>, Box>) = @@ -172,7 +179,7 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { /// Add an endpoint for all HTTP methods, as a fallback. /// /// Routes with specific HTTP methods will be tried first. - pub fn all(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn all(&mut self, ep: impl Endpoint) -> &mut Self { if self.prefix { let ep = StripPrefixEndpoint::new(ep); let (ep1, ep2): (Box>, Box>) = @@ -204,55 +211,55 @@ impl<'a, State: Clone + Send + Sync + 'static> Route<'a, State> { } /// Add an endpoint for `GET` requests - pub fn get(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn get(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Get, ep); self } /// Add an endpoint for `HEAD` requests - pub fn head(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn head(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Head, ep); self } /// Add an endpoint for `PUT` requests - pub fn put(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn put(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Put, ep); self } /// Add an endpoint for `POST` requests - pub fn post(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn post(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Post, ep); self } /// Add an endpoint for `DELETE` requests - pub fn delete(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn delete(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Delete, ep); self } /// Add an endpoint for `OPTIONS` requests - pub fn options(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn options(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Options, ep); self } /// Add an endpoint for `CONNECT` requests - pub fn connect(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn connect(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Connect, ep); self } /// Add an endpoint for `PATCH` requests - pub fn patch(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn patch(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Patch, ep); self } /// Add an endpoint for `TRACE` requests - pub fn trace(&mut self, ep: impl Endpoint) -> &mut Self { + pub fn trace(&mut self, ep: impl Endpoint) -> &mut Self { self.method(http_types::Method::Trace, ep); self } @@ -274,27 +281,22 @@ impl Clone for StripPrefixEndpoint { } #[async_trait::async_trait] -impl Endpoint for StripPrefixEndpoint +impl Endpoint for StripPrefixEndpoint where - State: Clone + Send + Sync + 'static, - E: Endpoint, + ServerState: Clone + Send + Sync + 'static, + E: Endpoint, { - async fn call(&self, req: crate::Request) -> crate::Result { + async fn call(&self, req: crate::Request, state: State) -> crate::Result { let crate::Request { - state, mut req, route_params, } = req; - let rest = crate::request::rest(&route_params).unwrap_or_else(|| ""); + let rest = crate::request::rest(&route_params).unwrap_or(""); req.url_mut().set_path(&rest); self.0 - .call(crate::Request { - state, - req, - route_params, - }) + .call(crate::Request::new(req, route_params), state) .await } } diff --git a/src/router.rs b/src/router.rs index b3673ee7d..f7f196311 100644 --- a/src/router.rs +++ b/src/router.rs @@ -9,19 +9,19 @@ use crate::{Request, Response, StatusCode}; /// Internally, we have a separate state machine per http method; indexing /// by the method first allows the table itself to be more efficient. #[allow(missing_debug_implementations)] -pub struct Router { - method_map: HashMap>>>, - all_method_router: MethodRouter>>, +pub struct Router { + method_map: HashMap>>>, + all_method_router: MethodRouter>>, } /// The result of routing a URL #[allow(missing_debug_implementations)] -pub struct Selection<'a, State> { - pub(crate) endpoint: &'a DynEndpoint, +pub struct Selection<'a, ServerState> { + pub(crate) endpoint: &'a DynEndpoint, pub(crate) params: Params, } -impl Router { +impl Router { pub fn new() -> Self { Router { method_map: HashMap::default(), @@ -29,18 +29,23 @@ impl Router { } } - pub fn add(&mut self, path: &str, method: http_types::Method, ep: Box>) { + pub fn add( + &mut self, + path: &str, + method: http_types::Method, + ep: Box>, + ) { self.method_map .entry(method) .or_insert_with(MethodRouter::new) .add(path, ep) } - pub fn add_all(&mut self, path: &str, ep: Box>) { + pub fn add_all(&mut self, path: &str, ep: Box>) { self.all_method_router.add(path, ep) } - pub fn route(&self, path: &str, method: http_types::Method) -> Selection<'_, State> { + pub fn route(&self, path: &str, method: http_types::Method) -> Selection<'_, ServerState> { if let Some(Match { handler, params }) = self .method_map .get(&method) @@ -81,14 +86,16 @@ impl Router { } } -async fn not_found_endpoint( - _req: Request, +async fn not_found_endpoint( + _req: Request, + _: ServerState, ) -> crate::Result { Ok(Response::new(StatusCode::NotFound)) } -async fn method_not_allowed( - _req: Request, +async fn method_not_allowed( + _req: Request, + _: ServerState, ) -> crate::Result { Ok(Response::new(StatusCode::MethodNotAllowed)) } diff --git a/src/security/cors.rs b/src/security/cors.rs index 0fd33bf5b..dd8e9ec53 100644 --- a/src/security/cors.rs +++ b/src/security/cors.rs @@ -2,7 +2,7 @@ use http_types::headers::{HeaderValue, HeaderValues}; use http_types::{headers, Method, StatusCode}; use crate::middleware::{Middleware, Next}; -use crate::{Request, Result}; +use crate::{Request, Result, State}; /// Middleware for CORS /// @@ -133,14 +133,19 @@ impl CorsMiddleware { } #[async_trait::async_trait] -impl Middleware for CorsMiddleware { - async fn handle(&self, req: Request, next: Next<'_, State>) -> Result { +impl Middleware for CorsMiddleware { + async fn handle( + &self, + req: Request, + state: State, + next: Next<'_, ServerState>, + ) -> Result { // TODO: how should multiple origin values be handled? let origins = req.header(&headers::ORIGIN).cloned(); if origins.is_none() { // This is not a CORS request if there is no Origin header - return Ok(next.run(req).await); + return Ok(next.run(req, state).await); } let origins = origins.unwrap(); @@ -155,7 +160,7 @@ impl Middleware for CorsMiddleware return Ok(self.build_preflight_response(&origins).into()); } - let mut response = next.run(req).await; + let mut response = next.run(req, state).await; response.insert_header( headers::ACCESS_CONTROL_ALLOW_ORIGIN, @@ -251,7 +256,7 @@ mod test { fn app() -> crate::Server<()> { let mut app = crate::Server::new(); - app.at(ENDPOINT).get(|_| async { Ok("Hello World") }); + app.at(ENDPOINT).get(|_, _| async { Ok("Hello World") }); app } @@ -367,7 +372,7 @@ mod test { async fn retain_cookies() { let mut app = crate::Server::new(); app.with(CorsMiddleware::new().allow_origin(ALLOW_ORIGIN)); - app.at(ENDPOINT).get(|_| async { + app.at(ENDPOINT).get(|_, _| async { let mut res = crate::Response::new(http_types::StatusCode::Ok); res.insert_cookie(http_types::Cookie::new("foo", "bar")); Ok(res) @@ -383,7 +388,7 @@ mod test { #[async_std::test] async fn set_cors_headers_to_error_responses() { let mut app = crate::Server::new(); - app.at(ENDPOINT).get(|_| async { + app.at(ENDPOINT).get(|_, _| async { Err::<&str, _>(crate::Error::from_str( StatusCode::BadRequest, "bad request", diff --git a/src/server.rs b/src/server.rs index 4bc5b03f9..b32a9ce48 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,12 +3,13 @@ use async_std::io; use async_std::sync::Arc; -use crate::cookies; use crate::listener::{Listener, ToListener}; use crate::log; use crate::middleware::{Middleware, Next}; use crate::router::{Router, Selection}; +use crate::{cookies, State}; use crate::{Endpoint, Request, Route}; + /// An HTTP server. /// /// Servers are built up as a combination of *state*, *endpoints* and *middleware*: @@ -24,11 +25,18 @@ use crate::{Endpoint, Request, Route}; /// - Middleware extends the base Tide framework with additional request or /// response processing, such as compression, default headers, or logging. To /// add middleware to an app, use the [`Server::middleware`] method. -#[allow(missing_debug_implementations)] -pub struct Server { - router: Arc>, - state: State, - middleware: Arc>>>, +pub struct Server { + router: Arc>, + state: ServerState, + /// Holds the middleware stack. + /// + /// Note(Fishrock123): We do actually want this structure. + /// The outer Arc allows us to clone in .respond() without cloning the array. + /// The Vec allows us to add middleware at runtime. + /// The inner Arc-s allow MiddlewareEndpoint-s to be cloned internally. + /// We don't use a Mutex around the Vec here because adding a middleware during execution should be an error. + #[allow(clippy::rc_buffer)] + middleware: Arc>>>, } impl Server<()> { @@ -41,7 +49,7 @@ impl Server<()> { /// # fn main() -> Result<(), std::io::Error> { block_on(async { /// # /// let mut app = tide::new(); - /// app.at("/").get(|_| async { Ok("Hello, world!") }); + /// app.at("/").get(|_, _| async { Ok("Hello, world!") }); /// app.listen("127.0.0.1:8080").await?; /// # /// # Ok(()) }) } @@ -58,7 +66,7 @@ impl Default for Server<()> { } } -impl Server { +impl Server { /// Create a new Tide server with shared application scoped state. /// /// Application scoped state is useful for storing items @@ -84,14 +92,14 @@ impl Server { /// /// // Initialize the application with state. /// let mut app = tide::with_state(state); - /// app.at("/").get(|req: Request| async move { - /// Ok(format!("Hello, {}!", &req.state().name)) + /// app.at("/").get(|req: Request, state: State| async move { + /// Ok(format!("Hello, {}!", &state.name)) /// }); /// app.listen("127.0.0.1:8080").await?; /// # /// # Ok(()) }) } /// ``` - pub fn with_state(state: State) -> Self { + pub fn with_state(state: ServerState) -> Self { let mut server = Self { router: Arc::new(Router::new()), middleware: Arc::new(vec![]), @@ -113,7 +121,7 @@ impl Server { /// /// ```rust,no_run /// # let mut app = tide::Server::new(); - /// app.at("/").get(|_| async { Ok("Hello, world!") }); + /// app.at("/").get(|_, _| async { Ok("Hello, world!") }); /// ``` /// /// A path is comprised of zero or many segments, i.e. non-empty strings @@ -149,7 +157,7 @@ impl Server { /// There is no fallback route matching, i.e. either a resource is a full /// match or not, which means that the order of adding resources has no /// effect. - pub fn at<'a>(&'a mut self, path: &str) -> Route<'a, State> { + pub fn at<'a>(&'a mut self, path: &str) -> Route<'a, ServerState> { let router = Arc::get_mut(&mut self.router) .expect("Registering routes is not possible after the Server has started"); Route::new(router, path.to_owned()) @@ -166,7 +174,7 @@ impl Server { /// order in which it is applied. pub fn with(&mut self, middleware: M) -> &mut Self where - M: Middleware, + M: Middleware, { log::trace!("Adding middleware {}", middleware.name()); let m = Arc::get_mut(&mut self.middleware) @@ -176,7 +184,7 @@ impl Server { } /// Asynchronously serve the app with the supplied listener. For more details, see [Listener] and [ToListener] - pub async fn listen>(self, listener: TL) -> io::Result<()> { + pub async fn listen>(self, listener: TL) -> io::Result<()> { listener.to_listener()?.listen(self).await } @@ -194,7 +202,7 @@ impl Server { /// use tide::http::{Url, Method, Request, Response}; /// /// let mut app = tide::new(); - /// app.at("/").get(|_| async { Ok("hello world") }); + /// app.at("/").get(|_, _| async { Ok("hello world") }); /// /// let req = Request::new(Method::Get, Url::parse("https://example.com")?); /// let res: Response = app.respond(req).await?; @@ -217,14 +225,15 @@ impl Server { let method = req.method().to_owned(); let Selection { endpoint, params } = router.route(&req.url().path(), method); let route_params = vec![params]; - let req = Request::new(state, req, route_params); + let req = Request::new(req, route_params); + let state = State::new(state); let next = Next { endpoint, next_middleware: &middleware, }; - let res = next.run(req).await; + let res = next.run(req, state).await; let res: http_types::Response = res.into(); Ok(res.into()) } @@ -237,15 +246,21 @@ impl Server { /// # #[derive(Clone)] struct SomeAppState; /// let mut app = tide::with_state(SomeAppState); /// let mut admin = tide::with_state(app.state().clone()); - /// admin.at("/").get(|_| async { Ok("nested app with cloned state") }); + /// admin.at("/").get(|_, _| async { Ok("nested app with cloned state") }); /// app.at("/").nest(admin); /// ``` - pub fn state(&self) -> &State { + pub fn state(&self) -> &ServerState { &self.state } } -impl Clone for Server { +impl std::fmt::Debug for Server { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Server").finish() + } +} + +impl Clone for Server { fn clone(&self) -> Self { Self { router: self.router.clone(), @@ -256,31 +271,41 @@ impl Clone for Server { } #[async_trait::async_trait] -impl - Endpoint for Server +impl Endpoint for Server +where + ServerState: Clone + Sync + Send + 'static, + InnerState: Clone + Sync + Send + 'static, { - async fn call(&self, req: Request) -> crate::Result { + async fn call(&self, req: Request, _state: State) -> crate::Result { let Request { req, mut route_params, - .. } = req; let path = req.url().path().to_owned(); let method = req.method().to_owned(); let router = self.router.clone(); let middleware = self.middleware.clone(); - let state = self.state.clone(); + let state = State::new(self.state.clone()); let Selection { endpoint, params } = router.route(&path, method); route_params.push(params); - let req = Request::new(state, req, route_params); + let req = Request::new(req, route_params); let next = Next { endpoint, next_middleware: &middleware, }; - Ok(next.run(req).await) + Ok(next.run(req, state).await) + } +} + +#[crate::utils::async_trait] +impl http_client::HttpClient + for Server +{ + async fn send(&self, req: crate::http::Request) -> crate::http::Result { + self.respond(req).await } } diff --git a/src/sessions/middleware.rs b/src/sessions/middleware.rs index 897a13860..de03ee102 100644 --- a/src/sessions/middleware.rs +++ b/src/sessions/middleware.rs @@ -1,7 +1,10 @@ use super::{Session, SessionStore}; -use crate::http::{ - cookies::{Cookie, Key, SameSite}, - format_err, +use crate::{ + http::{ + cookies::{Cookie, Key, SameSite}, + format_err, + }, + State, }; use crate::{utils::async_trait, Middleware, Next, Request}; use std::time::Duration; @@ -27,20 +30,20 @@ const BASE64_DIGEST_LEN: usize = 44; /// b"we recommend you use std::env::var(\"TIDE_SECRET\").unwrap().as_bytes() instead of a fixed value" /// )); /// -/// app.with(tide::utils::Before(|mut request: tide::Request<()>| async move { -/// let session = request.session_mut(); +/// app.with(tide::utils::Before(|mut req: tide::Request, state: tide::State<()>| async move { +/// let session = req.session_mut(); /// let visits: usize = session.get("visits").unwrap_or_default(); /// session.insert("visits", visits + 1).unwrap(); -/// request +/// (req, state) /// })); /// -/// app.at("/").get(|req: tide::Request<()>| async move { +/// app.at("/").get(|req: tide::Request, _| async move { /// let visits: usize = req.session().get("visits").unwrap(); /// Ok(format!("you have visited this website {} times", visits)) /// }); /// /// app.at("/reset") -/// .get(|mut req: tide::Request<()>| async move { +/// .get(|mut req: tide::Request, _| async move { /// req.session_mut().destroy(); /// Ok(tide::Redirect::new("/")) /// }); @@ -72,13 +75,18 @@ impl std::fmt::Debug for SessionMiddleware { } #[async_trait] -impl Middleware for SessionMiddleware +impl Middleware for SessionMiddleware where Store: SessionStore, - State: Clone + Send + Sync + 'static, + ServerState: Clone + Send + Sync + 'static, { - async fn handle(&self, mut request: Request, next: Next<'_, State>) -> crate::Result { - let cookie = request.cookie(&self.cookie_name); + async fn handle( + &self, + mut req: Request, + state: State, + next: Next<'_, ServerState>, + ) -> crate::Result { + let cookie = req.cookie(&self.cookie_name); let cookie_value = cookie .clone() .and_then(|cookie| self.verify_signature(cookie.value()).ok()); @@ -89,10 +97,10 @@ where session.expire_in(ttl); } - let secure_cookie = request.url().scheme() == "https"; - request.set_ext(session.clone()); + let secure_cookie = req.url().scheme() == "https"; + req.set_ext(session.clone()); - let mut response = next.run(request).await; + let mut response = next.run(req, state).await; if session.is_destroyed() { if let Err(e) = self.store.destroy_session(session).await { diff --git a/src/sse/endpoint.rs b/src/sse/endpoint.rs index fb81c26e8..47f3c057b 100644 --- a/src/sse/endpoint.rs +++ b/src/sse/endpoint.rs @@ -1,6 +1,6 @@ use crate::http::{mime, Body, StatusCode}; -use crate::log; use crate::sse::Sender; +use crate::{log, State}; use crate::{Endpoint, Request, Response, Result}; use async_std::future::Future; @@ -11,10 +11,10 @@ use std::marker::PhantomData; use std::sync::Arc; /// Create an endpoint that can handle SSE connections. -pub fn endpoint(handler: F) -> SseEndpoint +pub fn endpoint(handler: F) -> SseEndpoint where - State: Clone + Send + Sync + 'static, - F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, + ServerState: Clone + Send + Sync + 'static, + F: Fn(Request, State, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { SseEndpoint { @@ -25,29 +25,29 @@ where /// An endpoint that can handle SSE connections. #[derive(Debug)] -pub struct SseEndpoint +pub struct SseEndpoint where - State: Clone + Send + Sync + 'static, - F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, + ServerState: Clone + Send + Sync + 'static, + F: Fn(Request, State, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { handler: Arc, - __state: PhantomData, + __state: PhantomData, } #[async_trait::async_trait] -impl Endpoint for SseEndpoint +impl Endpoint for SseEndpoint where - State: Clone + Send + Sync + 'static, - F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, + ServerState: Clone + Send + Sync + 'static, + F: Fn(Request, State, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { - async fn call(&self, req: Request) -> Result { + async fn call(&self, req: Request, state: State) -> Result { let handler = self.handler.clone(); let (sender, encoder) = async_sse::encode(); task::spawn(async move { let sender = Sender::new(sender); - if let Err(err) = handler(req, sender).await { + if let Err(err) = handler(req, state, sender).await { log::error!("SSE handler error: {:?}", err); } }); diff --git a/src/sse/mod.rs b/src/sse/mod.rs index dc198b19c..29544bd17 100644 --- a/src/sse/mod.rs +++ b/src/sse/mod.rs @@ -17,7 +17,7 @@ //! use tide::sse; //! //! let mut app = tide::new(); -//! app.at("/sse").get(sse::endpoint(|_req, sender| async move { +//! app.at("/sse").get(sse::endpoint(|_req, _state, sender| async move { //! sender.send("fruit", "banana", None).await?; //! sender.send("fruit", "apple", None).await?; //! Ok(()) diff --git a/src/sse/upgrade.rs b/src/sse/upgrade.rs index 6cccda96d..42619263c 100644 --- a/src/sse/upgrade.rs +++ b/src/sse/upgrade.rs @@ -1,5 +1,5 @@ use crate::http::{mime, Body, StatusCode}; -use crate::log; +use crate::{log, State}; use crate::{Request, Response, Result}; use super::Sender; @@ -9,16 +9,16 @@ use async_std::io::BufReader; use async_std::task; /// Upgrade an existing HTTP connection to an SSE connection. -pub fn upgrade(req: Request, handler: F) -> Response +pub fn upgrade(req: Request, state: State, handler: F) -> Response where - State: Clone + Send + Sync + 'static, - F: Fn(Request, Sender) -> Fut + Send + Sync + 'static, + ServerState: Clone + Send + Sync + 'static, + F: Fn(Request, State, Sender) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { let (sender, encoder) = async_sse::encode(); task::spawn(async move { let sender = Sender::new(sender); - if let Err(err) = handler(req, sender).await { + if let Err(err) = handler(req, state, sender).await { log::error!("SSE handler error: {:?}", err); } }); diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 000000000..e51c59a15 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,141 @@ +// Implementation is based on +// - https://github.com/hyperium/http/blob/master/src/extensions.rs +// - https://github.com/kardeiz/type-map/blob/master/src/lib.rs + +use std::any::{Any, TypeId}; +use std::collections::HashMap; +use std::hash::{BuildHasherDefault, Hasher}; + +/// A shared state object. +/// +/// Server state can be categorized into two categories: +/// - Shared state between the whole application (e.g. an active database client). +/// - State local to the current request handler chain (e.g. are we authenticated?). +/// +/// Tide's `State` object provides a uniform interface to both kinds of state. +/// You can think of it as a `TypeMap`/`HashSet` that can be extended with +/// custom methods. +#[derive(Debug)] +pub struct State { + server_state: ServerState, + local_state: Option, BuildHasherDefault>>, +} + +impl std::ops::Deref for State { + type Target = ServerState; + + fn deref(&self) -> &Self::Target { + &self.server_state + } +} + +impl std::ops::DerefMut for State { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.server_state + } +} + +impl State { + /// Create an empty `State`. + #[inline] + pub(crate) fn new(server_state: ServerState) -> Self { + Self { + local_state: None, + server_state, + } + } + + /// Insert a value into this `State`. + /// + /// If a value of this type already exists, it will be returned. + pub fn insert(&mut self, val: T) -> Option { + self.local_state + .get_or_insert_with(Default::default) + .insert(TypeId::of::(), Box::new(val)) + .and_then(|boxed| (boxed as Box).downcast().ok().map(|boxed| *boxed)) + } + + /// Check if container contains value for type + pub fn contains(&self) -> bool { + self.local_state + .as_ref() + .and_then(|m| m.get(&TypeId::of::())) + .is_some() + } + + /// Get a reference to a value previously inserted on this `State`. + pub fn get(&self) -> Option<&T> { + self.local_state + .as_ref() + .and_then(|m| m.get(&TypeId::of::())) + .and_then(|boxed| (&**boxed as &(dyn Any)).downcast_ref()) + } + + /// Get a mutable reference to a value previously inserted on this `State`. + pub fn get_mut(&mut self) -> Option<&mut T> { + self.local_state + .as_mut() + .and_then(|m| m.get_mut(&TypeId::of::())) + .and_then(|boxed| (&mut **boxed as &mut (dyn Any)).downcast_mut()) + } + + /// Remove a value from this `State`. + /// + /// If a value of this type exists, it will be returned. + pub fn remove(&mut self) -> Option { + self.local_state + .as_mut() + .and_then(|m| m.remove(&TypeId::of::())) + .and_then(|boxed| (boxed as Box).downcast().ok().map(|boxed| *boxed)) + } + + /// Clear the `State` of all inserted values. + #[inline] + pub fn clear(&mut self) { + self.local_state = None; + } +} + +// With TypeIds as keys, there's no need to hash them. So we simply use an identy hasher. +#[derive(Default)] +struct IdHasher(u64); + +impl Hasher for IdHasher { + fn write(&mut self, _: &[u8]) { + unreachable!("TypeId calls write_u64"); + } + + #[inline] + fn write_u64(&mut self, id: u64) { + self.0 = id; + } + + #[inline] + fn finish(&self) -> u64 { + self.0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_extensions() { + #[derive(Debug, PartialEq)] + struct MyType(i32); + + let mut map = State::new(()); + + map.insert(5i32); + map.insert(MyType(10)); + + assert_eq!(map.get(), Some(&5i32)); + assert_eq!(map.get_mut(), Some(&mut 5i32)); + + assert_eq!(map.remove::(), Some(5i32)); + assert!(map.get::().is_none()); + + assert_eq!(map.get::(), None); + assert_eq!(map.get(), Some(&MyType(10))); + } +} diff --git a/src/utils.rs b/src/utils.rs index fde11b3ea..67f02b42e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,6 +1,6 @@ //! Miscellaneous utilities. -use crate::{Middleware, Next, Request, Response}; +use crate::{Middleware, Next, Request, Response, State}; pub use async_trait::async_trait; use std::future::Future; @@ -16,24 +16,29 @@ use std::future::Future; /// use std::time::Instant; /// /// let mut app = tide::new(); -/// app.with(utils::Before(|mut request: Request<()>| async move { -/// request.set_ext(Instant::now()); -/// request +/// app.with(utils::Before(|mut req: Request, state: tide::State<()>| async move { +/// req.set_ext(Instant::now()); +/// (req, state) /// })); /// ``` #[derive(Debug)] pub struct Before(pub F); #[async_trait] -impl Middleware for Before +impl Middleware for Before where - State: Clone + Send + Sync + 'static, - F: Fn(Request) -> Fut + Send + Sync + 'static, - Fut: Future> + Send + Sync + 'static, + ServerState: Clone + Send + Sync + 'static, + F: Fn(Request, State) -> Fut + Send + Sync + 'static, + Fut: Future)> + Send + Sync + 'static, { - async fn handle(&self, request: Request, next: Next<'_, State>) -> crate::Result { - let request = (self.0)(request).await; - Ok(next.run(request).await) + async fn handle( + &self, + req: Request, + state: State, + next: Next<'_, ServerState>, + ) -> crate::Result { + let (req, state) = (self.0)(req, state).await; + Ok(next.run(req, state).await) } } @@ -59,14 +64,19 @@ where #[derive(Debug)] pub struct After(pub F); #[async_trait] -impl Middleware for After +impl Middleware for After where - State: Clone + Send + Sync + 'static, + ServerState: Clone + Send + Sync + 'static, F: Fn(Response) -> Fut + Send + Sync + 'static, Fut: Future + Send + Sync + 'static, { - async fn handle(&self, request: Request, next: Next<'_, State>) -> crate::Result { - let response = next.run(request).await; - (self.0)(response).await + async fn handle( + &self, + req: Request, + state: State, + next: Next<'_, ServerState>, + ) -> crate::Result { + let res = next.run(req, state).await; + (self.0)(res).await } } diff --git a/tests/chunked-encode-large.rs b/tests/chunked-encode-large.rs index 428cea7d8..ba53ccd11 100644 --- a/tests/chunked-encode-large.rs +++ b/tests/chunked-encode-large.rs @@ -71,7 +71,7 @@ async fn chunked_large() -> Result<(), http_types::Error> { let server = task::spawn(async move { let mut app = tide::new(); app.at("/") - .get(|_| async { Ok(Body::from_reader(Cursor::new(TEXT), None)) }); + .get(|_, _| async { Ok(Body::from_reader(Cursor::new(TEXT), None)) }); app.listen(("localhost", port)).await?; Result::<(), http_types::Error>::Ok(()) }); diff --git a/tests/chunked-encode-small.rs b/tests/chunked-encode-small.rs index b4f06f198..17b6c5def 100644 --- a/tests/chunked-encode-small.rs +++ b/tests/chunked-encode-small.rs @@ -20,7 +20,7 @@ async fn chunked_large() -> Result<(), http_types::Error> { let server = task::spawn(async move { let mut app = tide::new(); app.at("/") - .get(|_| async { Ok(Body::from_reader(Cursor::new(TEXT), None)) }); + .get(|_, _| async { Ok(Body::from_reader(Cursor::new(TEXT), None)) }); app.listen(("localhost", port)).await?; Result::<(), http_types::Error>::Ok(()) }); diff --git a/tests/cookies.rs b/tests/cookies.rs index 3eed212e2..d2898abdf 100644 --- a/tests/cookies.rs +++ b/tests/cookies.rs @@ -6,7 +6,7 @@ use tide::{Request, Response, Server, StatusCode}; static COOKIE_NAME: &str = "testCookie"; -async fn retrieve_cookie(req: Request<()>) -> tide::Result { +async fn retrieve_cookie(req: Request, _state: tide::State<()>) -> tide::Result { Ok(format!( "{} and also {}", req.cookie(COOKIE_NAME).unwrap().value(), @@ -14,19 +14,19 @@ async fn retrieve_cookie(req: Request<()>) -> tide::Result { )) } -async fn insert_cookie(_req: Request<()>) -> tide::Result { +async fn insert_cookie(_req: Request, _state: tide::State<()>) -> tide::Result { let mut res = Response::new(StatusCode::Ok); res.insert_cookie(Cookie::new(COOKIE_NAME, "NewCookieValue")); Ok(res) } -async fn remove_cookie(_req: Request<()>) -> tide::Result { +async fn remove_cookie(_req: Request, _state: tide::State<()>) -> tide::Result { let mut res = Response::new(StatusCode::Ok); res.remove_cookie(Cookie::named(COOKIE_NAME)); Ok(res) } -async fn set_multiple_cookie(_req: Request<()>) -> tide::Result { +async fn set_multiple_cookie(_req: Request, _state: tide::State<()>) -> tide::Result { let mut res = Response::new(StatusCode::Ok); res.insert_cookie(Cookie::new("C1", "V1")); res.insert_cookie(Cookie::new("C2", "V2")); diff --git a/tests/endpoint.rs b/tests/endpoint.rs new file mode 100644 index 000000000..d4d05eb98 --- /dev/null +++ b/tests/endpoint.rs @@ -0,0 +1,25 @@ +use tide::http::{Method, Request, Url}; +use tide::Response; + +#[async_std::test] +async fn should_accept_boxed_endpoints() { + fn endpoint() -> Box> { + Box::new(|_, _| async { Ok("hello world") }) + } + + let mut app = tide::Server::new(); + app.at("/").get(endpoint()); + + let mut response: Response = app + .respond(Request::new( + Method::Get, + Url::parse("http://example.com/").unwrap(), + )) + .await + .unwrap(); + + assert_eq!( + response.take_body().into_string().await.unwrap(), + "hello world" + ); +} diff --git a/tests/function_middleware.rs b/tests/function_middleware.rs index fd3680a61..8324526da 100644 --- a/tests/function_middleware.rs +++ b/tests/function_middleware.rs @@ -1,28 +1,38 @@ use std::future::Future; use std::pin::Pin; -use tide::http::{self, url::Url, Method}; +use tide::{ + http::{self, url::Url, Method}, + State, +}; mod test_utils; -fn auth_middleware<'a>( - request: tide::Request<()>, - next: tide::Next<'a, ()>, -) -> Pin + 'a + Send>> { - let authenticated = match request.header("X-Auth") { +fn auth_middleware<'a, ServerState>( + req: tide::Request, + state: State, + next: tide::Next<'a, ServerState>, +) -> Pin + 'a + Send>> +where + ServerState: Clone + Send + Sync + 'static, +{ + let authenticated = match req.header("X-Auth") { Some(header) => header == "secret_key", None => false, }; Box::pin(async move { if authenticated { - Ok(next.run(request).await) + Ok(next.run(req, state).await) } else { Ok(tide::Response::new(tide::StatusCode::Unauthorized)) } }) } -async fn echo_path(req: tide::Request) -> tide::Result { +async fn echo_path( + req: tide::Request, + _state: State, +) -> tide::Result { Ok(req.url().path().to_string()) } diff --git a/tests/log.rs b/tests/log.rs index d6ddb9b74..5017d1f71 100644 --- a/tests/log.rs +++ b/tests/log.rs @@ -5,10 +5,11 @@ mod test_utils; use test_utils::ServerTestingExt; #[async_std::test] -async fn log_tests() { +async fn log_tests() -> tide::Result<()> { let mut logger = logtest::start(); test_server_listen(&mut logger).await; - test_only_log_once(&mut logger).await; + test_only_log_once(&mut logger).await?; + Ok(()) } async fn test_server_listen(logger: &mut logtest::Logger) { @@ -29,14 +30,14 @@ async fn test_server_listen(logger: &mut logtest::Logger) { ); } -async fn test_only_log_once(logger: &mut logtest::Logger) { +async fn test_only_log_once(logger: &mut logtest::Logger) -> tide::Result<()> { let mut app = tide::new(); app.at("/").nest({ let mut app = tide::new(); - app.at("/").get(|_| async { Ok("nested") }); + app.at("/").get(|_, _| async { Ok("nested") }); app }); - app.get("/").await; + assert!(app.get("/").await?.status().is_success()); let entries: Vec<_> = logger.collect(); @@ -55,4 +56,5 @@ async fn test_only_log_once(logger: &mut logtest::Logger) { .filter(|entry| entry.args() == "--> Response sent") .count() ); + Ok(()) } diff --git a/tests/nested.rs b/tests/nested.rs index ffbabdb1d..50b8b5389 100644 --- a/tests/nested.rs +++ b/tests/nested.rs @@ -2,22 +2,23 @@ mod test_utils; use test_utils::ServerTestingExt; #[async_std::test] -async fn nested() { +async fn nested() -> tide::Result<()> { let mut inner = tide::new(); - inner.at("/foo").get(|_| async { Ok("foo") }); - inner.at("/bar").get(|_| async { Ok("bar") }); + inner.at("/foo").get(|_, _| async { Ok("foo") }); + inner.at("/bar").get(|_, _| async { Ok("bar") }); let mut outer = tide::new(); // Nest the inner app on /foo outer.at("/foo").nest(inner); - assert_eq!(outer.get_body("/foo/foo").await, "foo"); - assert_eq!(outer.get_body("/foo/bar").await, "bar"); + assert_eq!(outer.get("/foo/foo").recv_string().await?, "foo"); + assert_eq!(outer.get("/foo/bar").recv_string().await?, "bar"); + Ok(()) } #[async_std::test] -async fn nested_middleware() { - let echo_path = |req: tide::Request<()>| async move { Ok(req.url().path().to_string()) }; +async fn nested_middleware() -> tide::Result<()> { + let echo_path = |req: tide::Request, _| async move { Ok(req.url().path().to_string()) }; let mut app = tide::new(); let mut inner_app = tide::new(); inner_app.with(tide::utils::After(|mut res: tide::Response| async move { @@ -29,33 +30,36 @@ async fn nested_middleware() { app.at("/foo").nest(inner_app); app.at("/bar").get(echo_path); - let mut res = app.get("/foo/echo").await; + let mut res = app.get("/foo/echo").await?; assert_eq!(res["X-Tide-Test"], "1"); assert_eq!(res.status(), 200); - assert_eq!(res.body_string().await.unwrap(), "/echo"); + assert_eq!(res.body_string().await?, "/echo"); - let mut res = app.get("/foo/x/bar").await; + let mut res = app.get("/foo/x/bar").await?; assert_eq!(res["X-Tide-Test"], "1"); assert_eq!(res.status(), 200); - assert_eq!(res.body_string().await.unwrap(), "/"); + assert_eq!(res.body_string().await?, "/"); - let mut res = app.get("/bar").await; + let mut res = app.get("/bar").await?; assert!(res.header("X-Tide-Test").is_none()); assert_eq!(res.status(), 200); - assert_eq!(res.body_string().await.unwrap(), "/bar"); + assert_eq!(res.body_string().await?, "/bar"); + Ok(()) } #[async_std::test] -async fn nested_with_different_state() { +async fn nested_with_different_state() -> tide::Result<()> { let mut outer = tide::new(); let mut inner = tide::with_state(42); - inner.at("/").get(|req: tide::Request| async move { - let num = req.state(); - Ok(format!("the number is {}", num)) - }); - outer.at("/").get(|_| async { Ok("Hello, world!") }); + inner + .at("/") + .get( + |_req: tide::Request, state: i32| async move { Ok(format!("the number is {}", state)) }, + ); + outer.at("/").get(|_, _| async { Ok("Hello, world!") }); outer.at("/foo").nest(inner); - assert_eq!(outer.get_body("/foo").await, "the number is 42"); - assert_eq!(outer.get_body("/").await, "Hello, world!"); + assert_eq!(outer.get("/foo").recv_string().await?, "the number is 42"); + assert_eq!(outer.get("/").recv_string().await?, "Hello, world!"); + Ok(()) } diff --git a/tests/params.rs b/tests/params.rs index 6dea9c632..4f89a453a 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -1,69 +1,39 @@ use http_types::{self, Method, Url}; -use tide::{self, Request, Response, Result, StatusCode}; +use tide::{self, Request, Response, Result}; #[async_std::test] -async fn test_param_invalid_type() { - async fn get_by_id(req: Request<()>) -> Result { - assert_eq!( - req.param::("id").unwrap_err().to_string(), - "Param failed to parse: invalid digit found in string" - ); - let _ = req.param::("id")?; - Result::Ok(Response::new(StatusCode::Ok)) +async fn test_missing_param() -> tide::Result<()> { + async fn greet(req: Request, _state: tide::State<()>) -> Result { + assert_eq!(req.param("name")?, "Param \"name\" not found"); + Ok(Response::new(200)) } - let mut server = tide::new(); - server.at("/by_id/:id").get(get_by_id); - - let req = http_types::Request::new( - Method::Get, - Url::parse("http://example.com/by_id/wrong").unwrap(), - ); - let res: http_types::Response = server.respond(req).await.unwrap(); - assert_eq!(res.status(), StatusCode::InternalServerError); -} -#[async_std::test] -async fn test_missing_param() { - async fn greet(req: Request<()>) -> Result { - assert_eq!( - req.param::("name").unwrap_err().to_string(), - "Param \"name\" not found!" - ); - let _: String = req.param("name")?; - Result::Ok(Response::new(StatusCode::Ok)) - } let mut server = tide::new(); server.at("/").get(greet); - let req = http_types::Request::new(Method::Get, Url::parse("http://example.com/").unwrap()); - let res: http_types::Response = server.respond(req).await.unwrap(); - assert_eq!(res.status(), StatusCode::InternalServerError); + let req = http_types::Request::new(Method::Get, Url::parse("http://example.com/")?); + let res: http_types::Response = server.respond(req).await?; + assert_eq!(res.status(), 500); + Ok(()) } #[async_std::test] -async fn hello_world_parametrized() { - async fn greet(req: tide::Request<()>) -> Result { - let name = req.param("name").unwrap_or_else(|_| "nori".to_owned()); - let mut response = tide::Response::new(StatusCode::Ok); - response.set_body(format!("{} says hello", name)); - Ok(response) +async fn hello_world_parametrized() -> Result<()> { + async fn greet(req: tide::Request, _state: tide::State<()>) -> Result> { + let body = format!("{} says hello", req.param("name").unwrap_or("nori")); + Ok(Response::builder(200).body(body)) } let mut server = tide::new(); server.at("/").get(greet); server.at("/:name").get(greet); - let req = http_types::Request::new(Method::Get, Url::parse("http://example.com/").unwrap()); - let mut res: http_types::Response = server.respond(req).await.unwrap(); - assert_eq!( - res.body_string().await.unwrap(), - "nori says hello".to_string() - ); + let req = http_types::Request::new(Method::Get, Url::parse("http://example.com/")?); + let mut res: http_types::Response = server.respond(req).await?; + assert_eq!(res.body_string().await?, "nori says hello"); - let req = http_types::Request::new(Method::Get, Url::parse("http://example.com/iron").unwrap()); - let mut res: http_types::Response = server.respond(req).await.unwrap(); - assert_eq!( - res.body_string().await.unwrap(), - "iron says hello".to_string() - ); + let req = http_types::Request::new(Method::Get, Url::parse("http://example.com/iron")?); + let mut res: http_types::Response = server.respond(req).await?; + assert_eq!(res.body_string().await?, "iron says hello"); + Ok(()) } diff --git a/tests/response.rs b/tests/response.rs index 704c7ff6c..79cd1858c 100644 --- a/tests/response.rs +++ b/tests/response.rs @@ -36,12 +36,12 @@ async fn string_content_type() { } #[async_std::test] -async fn json_content_type() { +async fn json_content_type() -> tide::Result<()> { use std::collections::BTreeMap; use tide::Body; let mut app = tide::new(); - app.at("/json_content_type").get(|_| async { + app.at("/json_content_type").get(|_, _| async { let mut map = BTreeMap::new(); map.insert(Some("a"), 2); map.insert(Some("b"), 4); @@ -51,7 +51,7 @@ async fn json_content_type() { Ok(resp) }); - let mut resp = app.get("/json_content_type").await; + let mut resp = app.get("/json_content_type").await?; assert_eq!(resp.status(), StatusCode::InternalServerError); assert_eq!(resp.body_string().await.unwrap(), ""); @@ -68,6 +68,8 @@ async fn json_content_type() { assert_eq!(resp.status(), StatusCode::Ok); let body = resp.take_body().into_bytes().await.unwrap(); assert_eq!(body, br##"{"a":2,"b":4,"c":6}"##); + + Ok(()) } #[test] diff --git a/tests/route_middleware.rs b/tests/route_middleware.rs index 1ec0083fc..0221106b2 100644 --- a/tests/route_middleware.rs +++ b/tests/route_middleware.rs @@ -2,7 +2,7 @@ mod test_utils; use http_types::headers::HeaderName; use std::convert::TryInto; use test_utils::ServerTestingExt; -use tide::Middleware; +use tide::{Middleware, State}; #[derive(Debug)] struct TestMiddleware(HeaderName, &'static str); @@ -14,24 +14,28 @@ impl TestMiddleware { } #[async_trait::async_trait] -impl Middleware for TestMiddleware { +impl Middleware for TestMiddleware { async fn handle( &self, - req: tide::Request, - next: tide::Next<'_, State>, + req: tide::Request, + state: State, + next: tide::Next<'_, ServerState>, ) -> tide::Result { - let mut res = next.run(req).await; + let mut res = next.run(req, state).await; res.insert_header(self.0.clone(), self.1); Ok(res) } } -async fn echo_path(req: tide::Request) -> tide::Result { +async fn echo_path( + req: tide::Request, + _state: State, +) -> tide::Result { Ok(req.url().path().to_string()) } #[async_std::test] -async fn route_middleware() { +async fn route_middleware() -> tide::Result<()> { let mut app = tide::new(); let mut foo_route = app.at("/foo"); foo_route // /foo @@ -46,17 +50,18 @@ async fn route_middleware() { .reset_middleware() .put(echo_path); - assert_eq!(app.get("/foo").await["X-Foo"], "foo"); - assert_eq!(app.post("/foo").await["X-Foo"], "foo"); - assert!(app.put("/foo").await.header("X-Foo").is_none()); + assert_eq!(app.get("/foo").await?["X-Foo"], "foo"); + assert_eq!(app.post("/foo").await?["X-Foo"], "foo"); + assert!(app.put("/foo").await?.header("X-Foo").is_none()); - let res = app.get("/foo/bar").await; + let res = app.get("/foo/bar").await?; assert_eq!(res["X-Foo"], "foo"); assert_eq!(res["x-bar"], "bar"); + Ok(()) } #[async_std::test] -async fn app_and_route_middleware() { +async fn app_and_route_middleware() -> tide::Result<()> { let mut app = tide::new(); app.with(TestMiddleware::with_header_name("X-Root", "root")); app.at("/foo") @@ -66,19 +71,20 @@ async fn app_and_route_middleware() { .with(TestMiddleware::with_header_name("X-Bar", "bar")) .get(echo_path); - let res = app.get("/foo").await; + let res = app.get("/foo").await?; assert_eq!(res["X-Root"], "root"); assert_eq!(res["x-foo"], "foo"); assert!(res.header("x-bar").is_none()); - let res = app.get("/bar").await; + let res = app.get("/bar").await?; assert_eq!(res["X-Root"], "root"); assert!(res.header("x-foo").is_none()); assert_eq!(res["X-Bar"], "bar"); + Ok(()) } #[async_std::test] -async fn nested_app_with_route_middleware() { +async fn nested_app_with_route_middleware() -> tide::Result<()> { let mut inner = tide::new(); inner.with(TestMiddleware::with_header_name("X-Inner", "inner")); inner @@ -95,23 +101,24 @@ async fn nested_app_with_route_middleware() { .with(TestMiddleware::with_header_name("X-Bar", "bar")) .nest(inner); - let res = app.get("/foo").await; + let res = app.get("/foo").await?; assert_eq!(res["X-Root"], "root"); assert!(res.header("X-Inner").is_none()); assert_eq!(res["X-Foo"], "foo"); assert!(res.header("X-Bar").is_none()); assert!(res.header("X-Baz").is_none()); - let res = app.get("/bar/baz").await; + let res = app.get("/bar/baz").await?; assert_eq!(res["X-Root"], "root"); assert_eq!(res["X-Inner"], "inner"); assert!(res.header("X-Foo").is_none()); assert_eq!(res["X-Bar"], "bar"); assert_eq!(res["X-Baz"], "baz"); + Ok(()) } #[async_std::test] -async fn subroute_not_nested() { +async fn subroute_not_nested() -> tide::Result<()> { let mut app = tide::new(); app.at("/parent") // /parent .with(TestMiddleware::with_header_name("X-Parent", "Parent")) @@ -120,7 +127,8 @@ async fn subroute_not_nested() { .with(TestMiddleware::with_header_name("X-Child", "child")) .get(echo_path); - let res = app.get("/parent/child").await; + let res = app.get("/parent/child").await?; assert!(res.header("X-Parent").is_none()); assert_eq!(res["x-child"], "child"); + Ok(()) } diff --git a/tests/server.rs b/tests/server.rs index 8ffd591bb..8709839ab 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -7,12 +7,12 @@ use serde::{Deserialize, Serialize}; use tide::{Body, Request}; #[test] -fn hello_world() -> Result<(), http_types::Error> { +fn hello_world() -> tide::Result<()> { task::block_on(async { let port = test_utils::find_port().await; let server = task::spawn(async move { let mut app = tide::new(); - app.at("/").get(move |mut req: Request<()>| async move { + app.at("/").get(move |mut req: Request, _state| async move { assert_eq!(req.body_string().await.unwrap(), "nori".to_string()); assert!(req.local_addr().unwrap().contains(&port.to_string())); assert!(req.peer_addr().is_some()); @@ -25,7 +25,7 @@ fn hello_world() -> Result<(), http_types::Error> { let client = task::spawn(async move { task::sleep(Duration::from_millis(100)).await; let string = surf::get(format!("http://localhost:{}", port)) - .body_string("nori".to_string()) + .body(Body::from_string("nori".to_string())) .recv_string() .await .unwrap(); @@ -38,12 +38,12 @@ fn hello_world() -> Result<(), http_types::Error> { } #[test] -fn echo_server() -> Result<(), http_types::Error> { +fn echo_server() -> tide::Result<()> { task::block_on(async { let port = test_utils::find_port().await; let server = task::spawn(async move { let mut app = tide::new(); - app.at("/").get(|req| async move { Ok(req) }); + app.at("/").get(|req, _| async move { Ok(req) }); app.listen(("localhost", port)).await?; Result::<(), http_types::Error>::Ok(()) @@ -52,7 +52,7 @@ fn echo_server() -> Result<(), http_types::Error> { let client = task::spawn(async move { task::sleep(Duration::from_millis(100)).await; let string = surf::get(format!("http://localhost:{}", port)) - .body_string("chashu".to_string()) + .body(Body::from_string("chashu".to_string())) .recv_string() .await .unwrap(); @@ -65,7 +65,7 @@ fn echo_server() -> Result<(), http_types::Error> { } #[test] -fn json() -> Result<(), http_types::Error> { +fn json() -> tide::Result<()> { #[derive(Deserialize, Serialize)] struct Counter { count: usize, @@ -75,7 +75,7 @@ fn json() -> Result<(), http_types::Error> { let port = test_utils::find_port().await; let server = task::spawn(async move { let mut app = tide::new(); - app.at("/").get(|mut req: Request<()>| async move { + app.at("/").get(|mut req: Request, _state| async move { let mut counter: Counter = req.body_json().await.unwrap(); assert_eq!(counter.count, 0); counter.count = 1; @@ -88,7 +88,7 @@ fn json() -> Result<(), http_types::Error> { let client = task::spawn(async move { task::sleep(Duration::from_millis(100)).await; let counter: Counter = surf::get(format!("http://localhost:{}", &port)) - .body_json(&Counter { count: 0 })? + .body(Body::from_json(&Counter { count: 0 })?) .recv_json() .await .unwrap(); diff --git a/tests/sessions.rs b/tests/sessions.rs index 06a4b95be..311111e45 100644 --- a/tests/sessions.rs +++ b/tests/sessions.rs @@ -2,16 +2,13 @@ mod test_utils; use test_utils::ServerTestingExt; use cookie::SameSite; +use http_types::headers::SET_COOKIE; use std::time::Duration; use tide::{ - http::{ - cookies as cookie, - headers::HeaderValue, - Method::{Get, Post}, - Request, Response, Url, - }, + http::{cookies as cookie, headers::HeaderValue, Response}, sessions::{MemoryStore, SessionMiddleware}, utils::Before, + State, }; #[derive(Clone, Debug, Default, PartialEq)] struct SessionData { @@ -19,25 +16,27 @@ struct SessionData { } #[async_std::test] -async fn test_basic_sessions() { +async fn test_basic_sessions() -> tide::Result<()> { let mut app = tide::new(); app.with(SessionMiddleware::new( MemoryStore::new(), b"12345678901234567890123456789012345", )); - app.with(Before(|mut request: tide::Request<()>| async move { - let visits: usize = request.session().get("visits").unwrap_or_default(); - request.session_mut().insert("visits", visits + 1).unwrap(); - request - })); + app.with(Before( + |mut req: tide::Request, state: tide::State<()>| async move { + let visits: usize = req.session().get("visits").unwrap_or_default(); + req.session_mut().insert("visits", visits + 1).unwrap(); + (req, state) + }, + )); - app.at("/").get(|req: tide::Request<()>| async move { + app.at("/").get(|req: tide::Request, _| async move { let visits: usize = req.session().get("visits").unwrap(); Ok(format!("you have visited this website {} times", visits)) }); - let response = app.get("/").await; + let response = app.get("/").await?; let cookies = Cookies::from_response(&response); let cookie = &cookies["tide.sid"]; assert_eq!(cookie.name(), "tide.sid"); @@ -46,21 +45,21 @@ async fn test_basic_sessions() { assert_eq!(cookie.secure(), None); // this request was http:// assert_eq!(cookie.path(), Some("/")); - let mut second_request = Request::new(Get, Url::parse("https://whatever/").unwrap()); - second_request.insert_header("Cookie", &cookies); - let mut second_response: Response = app.respond(second_request).await.unwrap(); - let body = second_response.body_string().await.unwrap(); + let mut second_response = app.get("/").header("Cookie", &cookies).await?; + + let body = second_response.body_string().await?; assert_eq!("you have visited this website 2 times", body); assert!(second_response.header("Set-Cookie").is_none()); - let response = app.get("https://secure/").await; + let response = app.get("https://secure/").await?; let cookies = Cookies::from_response(&response); let cookie = &cookies["tide.sid"]; assert_eq!(cookie.secure(), Some(true)); + Ok(()) } #[async_std::test] -async fn test_customized_sessions() { +async fn test_customized_sessions() -> tide::Result<()> { let mut app = tide::new(); app.with( SessionMiddleware::new(MemoryStore::new(), b"12345678901234567890123456789012345") @@ -71,31 +70,31 @@ async fn test_customized_sessions() { .without_save_unchanged(), ); - app.at("/").get(|_| async { Ok("/") }); - app.at("/nested").get(|req: tide::Request<()>| async move { + app.at("/").get(|_, _| async { Ok("/") }); + app.at("/nested").get(|req: tide::Request, _| async move { Ok(format!( "/nested {}", req.session().get::("visits").unwrap_or_default() )) }); app.at("/nested/incr") - .get(|mut req: tide::Request<()>| async move { + .get(|mut req: tide::Request, _| async move { let mut visits: usize = req.session().get("visits").unwrap_or_default(); visits += 1; - req.session_mut().insert("visits", visits).unwrap(); + req.session_mut().insert("visits", visits)?; Ok(format!("/nested/incr {}", visits)) }); - let response = app.get("/").await; + let response = app.get("/").await?; assert_eq!(Cookies::from_response(&response).len(), 0); - let mut response = app.get("/nested").await; + let mut response = app.get("/nested").await?; assert_eq!(Cookies::from_response(&response).len(), 0); - assert_eq!(response.body_string().await.unwrap(), "/nested 0"); + assert_eq!(response.body_string().await?, "/nested 0"); - let mut response = app.get("/nested/incr").await; + let mut response = app.get("/nested/incr").await?; let cookies = Cookies::from_response(&response); - assert_eq!(response.body_string().await.unwrap(), "/nested/incr 1"); + assert_eq!(response.body_string().await?, "/nested/incr 1"); assert_eq!(cookies.len(), 1); assert!(cookies.get("tide.sid").is_none()); @@ -105,61 +104,67 @@ async fn test_customized_sessions() { assert_eq!(cookie.path(), Some("/nested")); let cookie_value = cookie.value().to_string(); - let mut second_request = Request::new(Get, Url::parse("https://whatever/nested/incr").unwrap()); - second_request.insert_header("Cookie", &cookies); - let mut second_response: Response = app.respond(second_request).await.unwrap(); - let body = second_response.body_string().await.unwrap(); + let mut second_response = app + .get("https://whatever/nested/incr") + .header("Cookie", &cookies) + .await?; + let body = second_response.body_string().await?; assert_eq!("/nested/incr 2", body); assert!(second_response.header("Set-Cookie").is_none()); async_std::task::sleep(Duration::from_secs(5)).await; // wait for expiration - let mut expired_request = - Request::new(Get, Url::parse("https://whatever/nested/incr").unwrap()); - expired_request.insert_header("Cookie", &cookies); - let mut expired_response: Response = app.respond(expired_request).await.unwrap(); + let mut expired_response = app + .get("https://whatever/nested/incr") + .header("Cookie", &cookies) + .await?; let cookies = Cookies::from_response(&expired_response); assert_eq!(cookies.len(), 1); assert!(cookies["custom.cookie.name"].value() != cookie_value); - let body = expired_response.body_string().await.unwrap(); + let body = expired_response.body_string().await?; assert_eq!("/nested/incr 1", body); + Ok(()) } #[async_std::test] -async fn test_session_destruction() { +async fn test_session_destruction() -> tide::Result<()> { let mut app = tide::new(); app.with(SessionMiddleware::new( MemoryStore::new(), b"12345678901234567890123456789012345", )); - app.with(Before(|mut request: tide::Request<()>| async move { - let visits: usize = request.session().get("visits").unwrap_or_default(); - request.session_mut().insert("visits", visits + 1).unwrap(); - request - })); + app.with(Before( + |mut req: tide::Request, state: State<()>| async move { + let visits: usize = req.session().get("visits").unwrap_or_default(); + req.session_mut().insert("visits", visits + 1).unwrap(); + (req, state) + }, + )); - app.at("/").get(|req: tide::Request<()>| async move { + app.at("/").get(|req: tide::Request, _| async move { let visits: usize = req.session().get("visits").unwrap(); Ok(format!("you have visited this website {} times", visits)) }); app.at("/logout") - .post(|mut req: tide::Request<()>| async move { + .post(|mut req: tide::Request, _| async move { req.session_mut().destroy(); Ok(Response::new(200)) }); - let response = app.get("/").await; + let response = app.get("/").await?; let cookies = Cookies::from_response(&response); - let mut second_request = Request::new(Post, Url::parse("https://whatever/logout").unwrap()); - second_request.insert_header("Cookie", &cookies); - let second_response: Response = app.respond(second_request).await.unwrap(); + let second_response = app + .post("https://whatever/logout") + .header("Cookie", &cookies) + .await?; let cookies = Cookies::from_response(&second_response); assert_eq!(cookies["tide.sid"].value(), ""); assert_eq!(cookies.len(), 1); + Ok(()) } #[derive(Debug, Clone)] @@ -169,9 +174,10 @@ impl Cookies { self.0.len() } - fn from_response(response: &http_types::Response) -> Self { + fn from_response(response: &impl AsRef) -> Self { response - .header("Set-Cookie") + .as_ref() + .header(SET_COOKIE) .map(|hv| hv.to_string()) .unwrap_or_else(|| "[]".into()) .parse() diff --git a/tests/test_utils.rs b/tests/test_utils.rs index 42152d14f..c8a84b3d1 100644 --- a/tests/test_utils.rs +++ b/tests/test_utils.rs @@ -1,6 +1,4 @@ use portpicker::pick_unused_port; -use tide::http::{self, url::Url, Method}; -use tide::Server; /// Find an unused port. #[allow(dead_code)] @@ -8,53 +6,65 @@ pub async fn find_port() -> u16 { pick_unused_port().expect("No ports free") } -#[async_trait::async_trait] +use surf::{Client, RequestBuilder}; + +/// Trait that adds test request capabilities to tide [`Server`]s pub trait ServerTestingExt { - async fn request(&self, method: Method, path: &str) -> http::Response; - async fn request_body(&self, method: Method, path: &str) -> String; - async fn get(&self, path: &str) -> http::Response; - async fn get_body(&self, path: &str) -> String; - async fn post(&self, path: &str) -> http::Response; - async fn put(&self, path: &str) -> http::Response; -} + /// Construct a new surf Client + fn client(&self) -> Client; -#[async_trait::async_trait] -impl ServerTestingExt for Server -where - State: Clone + Send + Sync + 'static, -{ - async fn request(&self, method: Method, path: &str) -> http::Response { - let url = if path.starts_with("http:") || path.starts_with("https:") { - Url::parse(path).unwrap() - } else { - Url::parse("http://example.com/") - .unwrap() - .join(path) - .unwrap() - }; + /// Builds a `CONNECT` request. + fn connect(&self, uri: &str) -> RequestBuilder { + self.client().connect(uri) + } + + /// Builds a `DELETE` request. + fn delete(&self, uri: &str) -> RequestBuilder { + self.client().delete(uri) + } - let request = http::Request::new(method, url); - self.respond(request).await.unwrap() + /// Builds a `GET` request. + fn get(&self, uri: &str) -> RequestBuilder { + self.client().get(uri) } - async fn request_body(&self, method: Method, path: &str) -> String { - let mut response = self.request(method, path).await; - response.body_string().await.unwrap() + /// Builds a `HEAD` request. + fn head(&self, uri: &str) -> RequestBuilder { + self.client().head(uri) } - async fn get(&self, path: &str) -> http::Response { - self.request(Method::Get, path).await + /// Builds an `OPTIONS` request. + fn options(&self, uri: &str) -> RequestBuilder { + self.client().options(uri) } - async fn get_body(&self, path: &str) -> String { - self.request_body(Method::Get, path).await + /// Builds a `PATCH` request. + fn patch(&self, uri: &str) -> RequestBuilder { + self.client().patch(uri) } - async fn post(&self, path: &str) -> http::Response { - self.request(Method::Post, path).await + /// Builds a `POST` request. + fn post(&self, uri: &str) -> RequestBuilder { + self.client().post(uri) } - async fn put(&self, path: &str) -> http::Response { - self.request(Method::Put, path).await + /// Builds a `PUT` request. + fn put(&self, uri: &str) -> RequestBuilder { + self.client().put(uri) + } + + /// Builds a `TRACE` request. + fn trace(&self, uri: &str) -> RequestBuilder { + self.client().trace(uri) + } +} + +impl ServerTestingExt + for tide::Server +{ + fn client(&self) -> Client { + let mut client = Client::with_http_client(self.clone()); + client.set_base_url(tide::http::Url::parse("http://example.com").unwrap()); + client } } diff --git a/tests/unix.rs b/tests/unix.rs index bc28a5ed0..a130ff4de 100644 --- a/tests/unix.rs +++ b/tests/unix.rs @@ -16,7 +16,7 @@ mod unix_tests { let server = task::spawn(async move { let mut app = tide::new(); - app.at("/").get(|req: tide::Request<()>| async move { + app.at("/").get(|req: tide::Request, _| async move { Ok(req.local_addr().unwrap().to_string()) }); app.listen(sock_path).await?; diff --git a/tests/wildcard.rs b/tests/wildcard.rs index e2d526e4b..e6ade6797 100644 --- a/tests/wildcard.rs +++ b/tests/wildcard.rs @@ -1,125 +1,149 @@ mod test_utils; use test_utils::ServerTestingExt; -use tide::{Request, StatusCode}; -async fn add_one(req: Request<()>) -> Result { - match req.param::("num") { - Ok(num) => Ok((num + 1).to_string()), - Err(err) => Err(tide::Error::new(StatusCode::BadRequest, err)), - } +use tide::{Error, Request, StatusCode}; + +async fn add_one(req: Request, _state: tide::State<()>) -> Result { + let num: i64 = req + .param("num")? + .parse() + .map_err(|err| Error::new(StatusCode::BadRequest, err))?; + Ok((num + 1).to_string()) } -async fn add_two(req: Request<()>) -> Result { - let one = req - .param::("one") - .map_err(|err| tide::Error::new(StatusCode::BadRequest, err))?; - let two = req - .param::("two") - .map_err(|err| tide::Error::new(StatusCode::BadRequest, err))?; +async fn add_two(req: Request, _state: tide::State<()>) -> Result { + let one: i64 = req + .param("one")? + .parse() + .map_err(|err| Error::new(StatusCode::BadRequest, err))?; + let two: i64 = req + .param("two")? + .parse() + .map_err(|err| Error::new(StatusCode::BadRequest, err))?; Ok((one + two).to_string()) } -async fn echo_path(req: Request<()>) -> Result { - match req.param::("path") { - Ok(path) => Ok(path), - Err(err) => Err(tide::Error::new(StatusCode::BadRequest, err)), +async fn echo_path(req: Request, _state: tide::State<()>) -> Result { + match req.param("path") { + Ok(path) => Ok(path.into()), + Err(mut err) => { + err.set_status(StatusCode::BadRequest); + Err(err) + } } } #[async_std::test] -async fn wildcard() { +async fn wildcard() -> tide::Result<()> { let mut app = tide::Server::new(); app.at("/add_one/:num").get(add_one); - assert_eq!(app.get_body("/add_one/3").await, "4"); - assert_eq!(app.get_body("/add_one/-7").await, "-6"); + assert_eq!(app.get("/add_one/3").recv_string().await?, "4"); + assert_eq!(app.get("/add_one/-7").recv_string().await?, "-6"); + Ok(()) } #[async_std::test] -async fn invalid_segment_error() { +async fn invalid_segment_error() -> tide::Result<()> { let mut app = tide::new(); app.at("/add_one/:num").get(add_one); - assert_eq!(app.get("/add_one/a").await.status(), StatusCode::BadRequest); + assert_eq!( + app.get("/add_one/a").await?.status(), + StatusCode::BadRequest + ); + Ok(()) } #[async_std::test] -async fn not_found_error() { +async fn not_found_error() -> tide::Result<()> { let mut app = tide::new(); app.at("/add_one/:num").get(add_one); - assert_eq!(app.get("/add_one/").await.status(), StatusCode::NotFound); + assert_eq!(app.get("/add_one/").await?.status(), StatusCode::NotFound); + Ok(()) } #[async_std::test] -async fn wild_path() { +async fn wild_path() -> tide::Result<()> { let mut app = tide::new(); app.at("/echo/*path").get(echo_path); - assert_eq!(app.get_body("/echo/some_path").await, "some_path"); + assert_eq!(app.get("/echo/some_path").recv_string().await?, "some_path"); assert_eq!( - app.get_body("/echo/multi/segment/path").await, + app.get("/echo/multi/segment/path").recv_string().await?, "multi/segment/path" ); - assert_eq!(app.get("/echo/").await.status(), StatusCode::NotFound); + assert_eq!(app.get("/echo/").await?.status(), StatusCode::NotFound); + Ok(()) } #[async_std::test] -async fn multi_wildcard() { +async fn multi_wildcard() -> tide::Result<()> { let mut app = tide::new(); app.at("/add_two/:one/:two/").get(add_two); - assert_eq!(app.get_body("/add_two/1/2/").await, "3"); - assert_eq!(app.get_body("/add_two/-1/2/").await, "1"); - assert_eq!(app.get("/add_two/1").await.status(), StatusCode::NotFound); + assert_eq!(app.get("/add_two/1/2/").recv_string().await?, "3"); + assert_eq!(app.get("/add_two/-1/2/").recv_string().await?, "1"); + assert_eq!(app.get("/add_two/1").await?.status(), StatusCode::NotFound); + Ok(()) } #[async_std::test] -async fn wild_last_segment() { +async fn wild_last_segment() -> tide::Result<()> { let mut app = tide::new(); app.at("/echo/:path/*").get(echo_path); - assert_eq!(app.get_body("/echo/one/two").await, "one"); - assert_eq!(app.get_body("/echo/one/two/three/four").await, "one"); + assert_eq!(app.get("/echo/one/two").recv_string().await?, "one"); + assert_eq!( + app.get("/echo/one/two/three/four").recv_string().await?, + "one" + ); + Ok(()) } #[async_std::test] -async fn invalid_wildcard() { +async fn invalid_wildcard() -> tide::Result<()> { let mut app = tide::new(); app.at("/echo/*path/:one/").get(echo_path); assert_eq!( - app.get("/echo/one/two").await.status(), + app.get("/echo/one/two").await?.status(), StatusCode::NotFound ); + Ok(()) } #[async_std::test] -async fn nameless_wildcard() { +async fn nameless_wildcard() -> tide::Result<()> { let mut app = tide::Server::new(); - app.at("/echo/:").get(|_| async { Ok("") }); + app.at("/echo/:").get(|_, _| async { Ok("") }); assert_eq!( - app.get("/echo/one/two").await.status(), + app.get("/echo/one/two").await?.status(), StatusCode::NotFound ); - assert_eq!(app.get("/echo/one").await.status(), StatusCode::Ok); + assert_eq!(app.get("/echo/one").await?.status(), StatusCode::Ok); + Ok(()) } #[async_std::test] -async fn nameless_internal_wildcard() { +async fn nameless_internal_wildcard() -> tide::Result<()> { let mut app = tide::new(); app.at("/echo/:/:path").get(echo_path); - assert_eq!(app.get("/echo/one").await.status(), StatusCode::NotFound); - assert_eq!(app.get_body("/echo/one/two").await, "two"); + assert_eq!(app.get("/echo/one").await?.status(), StatusCode::NotFound); + assert_eq!(app.get("/echo/one/two").recv_string().await?, "two"); + Ok(()) } #[async_std::test] -async fn nameless_internal_wildcard2() { +async fn nameless_internal_wildcard2() -> tide::Result<()> { let mut app = tide::new(); - app.at("/echo/:/:path").get(|req: Request<()>| async move { - assert_eq!(req.param::("path")?, "two"); + app.at("/echo/:/:path").get(|req: Request, _| async move { + assert_eq!(req.param("path")?, "two"); Ok("") }); - app.get("/echo/one/two").await; + assert!(app.get("/echo/one/two").await?.status().is_success()); + Ok(()) } #[async_std::test] -async fn ambiguous_router_wildcard_vs_star() { +async fn ambiguous_router_wildcard_vs_star() -> tide::Result<()> { let mut app = tide::new(); - app.at("/:one/:two").get(|_| async { Ok("one/two") }); - app.at("/posts/*").get(|_| async { Ok("posts/*") }); - assert_eq!(app.get_body("/posts/10").await, "posts/*"); + app.at("/:one/:two").get(|_, _| async { Ok("one/two") }); + app.at("/posts/*").get(|_, _| async { Ok("posts/*") }); + assert_eq!(app.get("/posts/10").recv_string().await?, "posts/*"); + Ok(()) }