diff --git a/axum-core/CHANGELOG.md b/axum-core/CHANGELOG.md index da272e5f8a..a3cbc1fb32 100644 --- a/axum-core/CHANGELOG.md +++ b/axum-core/CHANGELOG.md @@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **breaking:** `FromRequest` and `RequestParts` has a new `S` type param which + represents the state ([#1155]) + +[#1155]: https://github.com/tokio-rs/axum/pull/1155 # 0.2.6 (18. June, 2022) diff --git a/axum-core/src/extract/from_ref.rs b/axum-core/src/extract/from_ref.rs new file mode 100644 index 0000000000..c0124140e5 --- /dev/null +++ b/axum-core/src/extract/from_ref.rs @@ -0,0 +1,23 @@ +/// Used to do reference-to-value conversions thus not consuming the input value. +/// +/// This is mainly used with [`State`] to extract "substates" from a reference to main application +/// state. +/// +/// See [`State`] for more details on how library authors should use this trait. +/// +/// [`State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html +// NOTE: This trait is defined in axum-core, even though it is mainly used with `State` which is +// defined in axum. That allows crate authors to use it when implementing extractors. +pub trait FromRef { + /// Converts to this type from a reference to the input type. + fn from_ref(input: &T) -> Self; +} + +impl FromRef for T +where + T: Clone, +{ + fn from_ref(input: &T) -> Self { + input.clone() + } +} diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index 2316633be5..ade7bf0345 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -12,9 +12,12 @@ use std::convert::Infallible; pub mod rejection; +mod from_ref; mod request_parts; mod tuple; +pub use self::from_ref::FromRef; + /// Types that can be created from requests. /// /// See [`axum::extract`] for more details. @@ -42,13 +45,15 @@ mod tuple; /// struct MyExtractor; /// /// #[async_trait] -/// impl FromRequest for MyExtractor +/// impl FromRequest for MyExtractor /// where -/// B: Send, // required by `async_trait` +/// // these bounds are required by `async_trait` +/// B: Send, +/// S: Send, /// { /// type Rejection = http::StatusCode; /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: &mut RequestParts) -> Result { /// // ... /// # unimplemented!() /// } @@ -60,20 +65,21 @@ mod tuple; /// [`http::Request`]: http::Request /// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html #[async_trait] -pub trait FromRequest: Sized { +pub trait FromRequest: Sized { /// If the extractor fails it'll use this "rejection" type. A rejection is /// a kind of error that can be converted into a response. type Rejection: IntoResponse; /// Perform the extraction. - async fn from_request(req: &mut RequestParts) -> Result; + async fn from_request(req: &mut RequestParts) -> Result; } /// The type used with [`FromRequest`] to extract data from requests. /// /// Has several convenience methods for getting owned parts of the request. #[derive(Debug)] -pub struct RequestParts { +pub struct RequestParts { + state: S, method: Method, uri: Uri, version: Version, @@ -82,8 +88,8 @@ pub struct RequestParts { body: Option, } -impl RequestParts { - /// Create a new `RequestParts`. +impl RequestParts<(), B> { + /// Create a new `RequestParts` without any state. /// /// You generally shouldn't need to construct this type yourself, unless /// using extractors outside of axum for example to implement a @@ -91,6 +97,19 @@ impl RequestParts { /// /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html pub fn new(req: Request) -> Self { + Self::with_state((), req) + } +} + +impl RequestParts { + /// Create a new `RequestParts` with the given state. + /// + /// You generally shouldn't need to construct this type yourself, unless + /// using extractors outside of axum for example to implement a + /// [`tower::Service`]. + /// + /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html + pub fn with_state(state: S, req: Request) -> Self { let ( http::request::Parts { method, @@ -104,6 +123,7 @@ impl RequestParts { ) = req.into_parts(); RequestParts { + state, method, uri, version, @@ -130,10 +150,14 @@ impl RequestParts { /// use http::{Method, Uri}; /// /// #[async_trait] - /// impl FromRequest for MyExtractor { + /// impl FromRequest for MyExtractor + /// where + /// B: Send, + /// S: Send, + /// { /// type Rejection = Infallible; /// - /// async fn from_request(req: &mut RequestParts) -> Result { + /// async fn from_request(req: &mut RequestParts) -> Result { /// let method = req.extract::().await?; /// let path = req.extract::().await?.path().to_owned(); /// @@ -141,7 +165,10 @@ impl RequestParts { /// } /// } /// ``` - pub async fn extract>(&mut self) -> Result { + pub async fn extract(&mut self) -> Result + where + E: FromRequest, + { E::from_request(self).await } @@ -153,6 +180,7 @@ impl RequestParts { /// [`take_body`]: RequestParts::take_body pub fn try_into_request(self) -> Result, BodyAlreadyExtracted> { let Self { + state: _, method, uri, version, @@ -245,30 +273,37 @@ impl RequestParts { pub fn take_body(&mut self) -> Option { self.body.take() } + + /// Get a reference to the state. + pub fn state(&self) -> &S { + &self.state + } } #[async_trait] -impl FromRequest for Option +impl FromRequest for Option where - T: FromRequest, + T: FromRequest, B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result, Self::Rejection> { + async fn from_request(req: &mut RequestParts) -> Result, Self::Rejection> { Ok(T::from_request(req).await.ok()) } } #[async_trait] -impl FromRequest for Result +impl FromRequest for Result where - T: FromRequest, + T: FromRequest, B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(T::from_request(req).await) } } diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index 33383a7d8d..4faaf2d355 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -6,16 +6,18 @@ use http::{Extensions, HeaderMap, Method, Request, Uri, Version}; use std::convert::Infallible; #[async_trait] -impl FromRequest for Request +impl FromRequest for Request where B: Send, + S: Clone + Send, { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let req = std::mem::replace( req, RequestParts { + state: req.state().clone(), method: req.method.clone(), version: req.version, uri: req.uri.clone(), @@ -30,37 +32,40 @@ where } #[async_trait] -impl FromRequest for Method +impl FromRequest for Method where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.method().clone()) } } #[async_trait] -impl FromRequest for Uri +impl FromRequest for Uri where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.uri().clone()) } } #[async_trait] -impl FromRequest for Version +impl FromRequest for Version where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.version()) } } @@ -71,27 +76,29 @@ where /// /// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html #[async_trait] -impl FromRequest for HeaderMap +impl FromRequest for HeaderMap where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.headers().clone()) } } #[async_trait] -impl FromRequest for Bytes +impl FromRequest for Bytes where B: http_body::Body + Send, B::Data: Send, B::Error: Into, + S: Send, { type Rejection = BytesRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; let bytes = crate::body::to_bytes(body) @@ -103,15 +110,16 @@ where } #[async_trait] -impl FromRequest for String +impl FromRequest for String where B: http_body::Body + Send, B::Data: Send, B::Error: Into, + S: Send, { type Rejection = StringRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; let bytes = crate::body::to_bytes(body) @@ -126,13 +134,14 @@ where } #[async_trait] -impl FromRequest for http::request::Parts +impl FromRequest for http::request::Parts where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let method = unwrap_infallible(Method::from_request(req).await); let uri = unwrap_infallible(Uri::from_request(req).await); let version = unwrap_infallible(Version::from_request(req).await); @@ -159,6 +168,6 @@ fn unwrap_infallible(result: Result) -> T { } } -pub(crate) fn take_body(req: &mut RequestParts) -> Result { +pub(crate) fn take_body(req: &mut RequestParts) -> Result { req.take_body().ok_or(BodyAlreadyExtracted) } diff --git a/axum-core/src/extract/tuple.rs b/axum-core/src/extract/tuple.rs index 8c781a8d3e..05e38bf004 100644 --- a/axum-core/src/extract/tuple.rs +++ b/axum-core/src/extract/tuple.rs @@ -4,13 +4,14 @@ use async_trait::async_trait; use std::convert::Infallible; #[async_trait] -impl FromRequest for () +impl FromRequest for () where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(_: &mut RequestParts) -> Result<(), Self::Rejection> { + async fn from_request(_: &mut RequestParts) -> Result<(), Self::Rejection> { Ok(()) } } @@ -21,14 +22,15 @@ macro_rules! impl_from_request { ( $($ty:ident),* $(,)? ) => { #[async_trait] #[allow(non_snake_case)] - impl FromRequest for ($($ty,)*) + impl FromRequest for ($($ty,)*) where - $( $ty: FromRequest + Send, )* + $( $ty: FromRequest + Send, )* B: Send, + S: Send, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { $( let $ty = $ty::from_request(req).await.map_err(|err| err.into_response())?; )* Ok(($($ty,)*)) } diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index dda3e2eb49..8d6568f997 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -17,15 +17,21 @@ and this project adheres to [Semantic Versioning]. literal `Response` - **added:** Support chaining handlers with `HandlerCallWithExtractors::or` ([#1170]) - **change:** axum-extra's MSRV is now 1.60 ([#1239]) +- **breaking:** `SignedCookieJar` and `PrivateCookieJar` now extracts the keys + from the router's state, rather than extensions - **added:** Add Protocol Buffer extractor and response ([#1239]) - **added:** Add `Either*` types for combining extractors and responses into a single type ([#1263]) - **added:** `WithRejection` extractor for customizing other extractors' rejections ([#1262]) - **added:** Add sync constructors to `CookieJar`, `PrivateCookieJar`, and `SignedCookieJar` so they're easier to use in custom middleware +- **breaking:** `Resource` has a new `S` type param which represents the state ([#1155]) +- **breaking:** `RouterExt::route_with_tsr` now only accepts `MethodRouter`s ([#1155]) +- **added:** `RouterExt::route_service_with_tsr` for routing to any `Service` ([#1155]) [#1086]: https://github.com/tokio-rs/axum/pull/1086 [#1119]: https://github.com/tokio-rs/axum/pull/1119 +[#1155]: https://github.com/tokio-rs/axum/pull/1155 [#1170]: https://github.com/tokio-rs/axum/pull/1170 [#1214]: https://github.com/tokio-rs/axum/pull/1214 [#1239]: https://github.com/tokio-rs/axum/pull/1239 diff --git a/axum-extra/src/either.rs b/axum-extra/src/either.rs index d342fc193c..84b2a91f65 100755 --- a/axum-extra/src/either.rs +++ b/axum-extra/src/either.rs @@ -190,15 +190,16 @@ macro_rules! impl_traits_for_either { $last:ident $(,)? ) => { #[async_trait] - impl FromRequest for $either<$($ident),*, $last> + impl FromRequest for $either<$($ident),*, $last> where - $($ident: FromRequest),*, - $last: FromRequest, + $($ident: FromRequest),*, + $last: FromRequest, B: Send, + S: Send, { type Rejection = $last::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { $( if let Ok(value) = req.extract().await { return Ok(Self::$ident(value)); diff --git a/axum-extra/src/extract/cached.rs b/axum-extra/src/extract/cached.rs index 0ada78a888..9545fe16dd 100644 --- a/axum-extra/src/extract/cached.rs +++ b/axum-extra/src/extract/cached.rs @@ -30,13 +30,14 @@ use std::ops::{Deref, DerefMut}; /// struct Session { /* ... */ } /// /// #[async_trait] -/// impl FromRequest for Session +/// impl FromRequest for Session /// where /// B: Send, +/// S: Send, /// { /// type Rejection = (StatusCode, String); /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: &mut RequestParts) -> Result { /// // load session... /// # unimplemented!() /// } @@ -45,13 +46,14 @@ use std::ops::{Deref, DerefMut}; /// struct CurrentUser { /* ... */ } /// /// #[async_trait] -/// impl FromRequest for CurrentUser +/// impl FromRequest for CurrentUser /// where /// B: Send, +/// S: Send, /// { /// type Rejection = Response; /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: &mut RequestParts) -> Result { /// // loading a `CurrentUser` requires first loading the `Session` /// // /// // by using `Cached` we avoid extracting the session more than @@ -88,14 +90,15 @@ pub struct Cached(pub T); struct CachedEntry(T); #[async_trait] -impl FromRequest for Cached +impl FromRequest for Cached where B: Send, - T: FromRequest + Clone + Send + Sync + 'static, + S: Send, + T: FromRequest + Clone + Send + Sync + 'static, { type Rejection = T::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match Extension::>::from_request(req).await { Ok(Extension(CachedEntry(value))) => Ok(Self(value)), Err(_) => { @@ -139,13 +142,14 @@ mod tests { struct Extractor(Instant); #[async_trait] - impl FromRequest for Extractor + impl FromRequest for Extractor where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { COUNTER.fetch_add(1, Ordering::SeqCst); Ok(Self(Instant::now())) } diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index 9190424ec6..44842d039c 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -80,7 +80,7 @@ pub use cookie::Key; /// let app = Router::new() /// .route("/sessions", post(create_session)) /// .route("/me", get(me)); -/// # let app: Router = app; +/// # let app: Router = app; /// ``` #[derive(Debug, Default)] pub struct CookieJar { @@ -88,13 +88,14 @@ pub struct CookieJar { } #[async_trait] -impl FromRequest for CookieJar +impl FromRequest for CookieJar where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(Self::from_headers(req.headers())) } } @@ -226,7 +227,7 @@ fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) { #[cfg(test)] mod tests { use super::*; - use axum::{body::Body, http::Request, routing::get, Extension, Router}; + use axum::{body::Body, extract::FromRef, http::Request, routing::get, Router}; use tower::ServiceExt; macro_rules! cookie_test { @@ -245,12 +246,15 @@ mod tests { jar.remove(Cookie::named("key")) } - let app = Router::::new() + let state = AppState { + key: Key::generate(), + custom_key: CustomKey(Key::generate()), + }; + + let app = Router::<_, Body>::with_state(state) .route("/set", get(set_cookie)) .route("/get", get(get_cookie)) - .route("/remove", get(remove_cookie)) - .layer(Extension(Key::generate())) - .layer(Extension(CustomKey(Key::generate()))); + .route("/remove", get(remove_cookie)); let res = app .clone() @@ -298,6 +302,24 @@ mod tests { cookie_test!(private_cookies, PrivateCookieJar); cookie_test!(private_cookies_with_custom_key, PrivateCookieJar); + #[derive(Clone)] + struct AppState { + key: Key, + custom_key: CustomKey, + } + + impl FromRef for Key { + fn from_ref(state: &AppState) -> Key { + state.key.clone() + } + } + + impl FromRef for CustomKey { + fn from_ref(state: &AppState) -> CustomKey { + state.custom_key.clone() + } + } + #[derive(Clone)] struct CustomKey(Key); @@ -313,9 +335,12 @@ mod tests { format!("{:?}", jar.get("key")) } - let app = Router::::new() - .route("/get", get(get_cookie)) - .layer(Extension(Key::generate())); + let state = AppState { + key: Key::generate(), + custom_key: CustomKey(Key::generate()), + }; + + let app = Router::<_, Body>::with_state(state).route("/get", get(get_cookie)); let res = app .clone() diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index b606a6bd40..d3705fb2be 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -1,9 +1,8 @@ use super::{cookies_from_request, set_cookies, Cookie, Key}; use axum::{ async_trait, - extract::{FromRequest, RequestParts}, + extract::{FromRef, FromRequest, RequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, - Extension, }; use cookie::PrivateJar; use http::HeaderMap; @@ -23,9 +22,8 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// ```rust /// use axum::{ /// Router, -/// Extension, /// routing::{post, get}, -/// extract::TypedHeader, +/// extract::{TypedHeader, FromRef}, /// response::{IntoResponse, Redirect}, /// headers::authorization::{Authorization, Bearer}, /// http::StatusCode, @@ -45,22 +43,36 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// } /// } /// -/// // Generate a secure key -/// // -/// // You probably don't wanna generate a new one each time the app starts though -/// let key = Key::generate(); +/// // our application state +/// #[derive(Clone)] +/// struct AppState { +/// // that holds the key used to sign cookies +/// key: Key, +/// } +/// +/// // this impl tells `SignedCookieJar` how to access the key from our state +/// impl FromRef for Key { +/// fn from_ref(state: &AppState) -> Self { +/// state.key.clone() +/// } +/// } +/// +/// let state = AppState { +/// // Generate a secure key +/// // +/// // You probably don't wanna generate a new one each time the app starts though +/// key: Key::generate(), +/// }; /// -/// let app = Router::new() +/// let app = Router::with_state(state) /// .route("/set", post(set_secret)) -/// .route("/get", get(get_secret)) -/// // add extension with the key so `PrivateCookieJar` can access it -/// .layer(Extension(key)); -/// # let app: Router = app; +/// .route("/get", get(get_secret)); +/// # let app: Router<_> = app; /// ``` pub struct PrivateCookieJar { jar: cookie::CookieJar, key: Key, - // The key used to extract the key extension. Allows users to use multiple keys for different + // The key used to extract the key. Allows users to use multiple keys for different // jars. Maybe a library wants its own key. _marker: PhantomData, } @@ -75,15 +87,17 @@ impl fmt::Debug for PrivateCookieJar { } #[async_trait] -impl FromRequest for PrivateCookieJar +impl FromRequest for PrivateCookieJar where B: Send, - K: Into + Clone + Send + Sync + 'static, + S: Send, + K: FromRef + Into, { - type Rejection = as FromRequest>::Rejection; + type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - let key = req.extract::>().await?.0.into(); + async fn from_request(req: &mut RequestParts) -> Result { + let k = K::from_ref(req.state()); + let key = k.into(); let PrivateCookieJar { jar, key, diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index e27ebc33ce..74da2a11ae 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -1,9 +1,8 @@ use super::{cookies_from_request, set_cookies}; use axum::{ async_trait, - extract::{FromRequest, RequestParts}, + extract::{FromRef, FromRequest, RequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, - Extension, }; use cookie::SignedJar; use cookie::{Cookie, Key}; @@ -24,9 +23,8 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// ```rust /// use axum::{ /// Router, -/// Extension, /// routing::{post, get}, -/// extract::TypedHeader, +/// extract::{TypedHeader, FromRef}, /// response::{IntoResponse, Redirect}, /// headers::authorization::{Authorization, Bearer}, /// http::StatusCode, @@ -63,22 +61,36 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// # todo!() /// } /// -/// // Generate a secure key -/// // -/// // You probably don't wanna generate a new one each time the app starts though -/// let key = Key::generate(); +/// // our application state +/// #[derive(Clone)] +/// struct AppState { +/// // that holds the key used to sign cookies +/// key: Key, +/// } +/// +/// // this impl tells `SignedCookieJar` how to access the key from our state +/// impl FromRef for Key { +/// fn from_ref(state: &AppState) -> Self { +/// state.key.clone() +/// } +/// } +/// +/// let state = AppState { +/// // Generate a secure key +/// // +/// // You probably don't wanna generate a new one each time the app starts though +/// key: Key::generate(), +/// }; /// -/// let app = Router::new() +/// let app = Router::with_state(state) /// .route("/sessions", post(create_session)) -/// .route("/me", get(me)) -/// // add extension with the key so `SignedCookieJar` can access it -/// .layer(Extension(key)); -/// # let app: Router = app; +/// .route("/me", get(me)); +/// # let app: Router<_> = app; /// ``` pub struct SignedCookieJar { jar: cookie::CookieJar, key: Key, - // The key used to extract the key extension. Allows users to use multiple keys for different + // The key used to extract the key. Allows users to use multiple keys for different // jars. Maybe a library wants its own key. _marker: PhantomData, } @@ -93,15 +105,17 @@ impl fmt::Debug for SignedCookieJar { } #[async_trait] -impl FromRequest for SignedCookieJar +impl FromRequest for SignedCookieJar where B: Send, - K: Into + Clone + Send + Sync + 'static, + S: Send, + K: FromRef + Into, { - type Rejection = as FromRequest>::Rejection; + type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { - let key = req.extract::>().await?.0.into(); + async fn from_request(req: &mut RequestParts) -> Result { + let k = K::from_ref(req.state()); + let key = k.into(); let SignedCookieJar { jar, key, diff --git a/axum-extra/src/extract/form.rs b/axum-extra/src/extract/form.rs index 3a20aeb186..593bfba660 100644 --- a/axum-extra/src/extract/form.rs +++ b/axum-extra/src/extract/form.rs @@ -55,16 +55,17 @@ impl Deref for Form { } #[async_trait] -impl FromRequest for Form +impl FromRequest for Form where T: DeserializeOwned, B: HttpBody + Send, B::Data: Send, B::Error: Into, + S: Send, { type Rejection = FormRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if req.method() == Method::GET { let query = req.uri().query().unwrap_or_default(); let value = serde_html_form::from_str(query) @@ -85,7 +86,7 @@ where } // this is duplicated in `axum/src/extract/mod.rs` -fn has_content_type(req: &RequestParts, expected_content_type: &mime::Mime) -> bool { +fn has_content_type(req: &RequestParts, expected_content_type: &mime::Mime) -> bool { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { content_type } else { diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index de49b8aecd..debc6957a3 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -58,14 +58,15 @@ use std::ops::Deref; pub struct Query(pub T); #[async_trait] -impl FromRequest for Query +impl FromRequest for Query where T: DeserializeOwned, B: Send, + S: Send, { type Rejection = QueryRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let query = req.uri().query().unwrap_or_default(); let value = serde_html_form::from_str(query) .map_err(FailedToDeserializeQueryString::__private_new)?; diff --git a/axum-extra/src/extract/with_rejection.rs b/axum-extra/src/extract/with_rejection.rs index e6387f6cfb..e0d2135cc3 100644 --- a/axum-extra/src/extract/with_rejection.rs +++ b/axum-extra/src/extract/with_rejection.rs @@ -107,15 +107,16 @@ impl DerefMut for WithRejection { } #[async_trait] -impl FromRequest for WithRejection +impl FromRequest for WithRejection where B: Send, - E: FromRequest, + S: Send, + E: FromRequest, R: From + IntoResponse, { type Rejection = R; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let extractor = req.extract::().await?; Ok(WithRejection(extractor, PhantomData)) } @@ -134,10 +135,14 @@ mod tests { struct TestRejection; #[async_trait] - impl FromRequest for TestExtractor { + impl FromRequest for TestExtractor + where + B: Send, + S: Send, + { type Rejection = (); - async fn from_request(_: &mut RequestParts) -> Result { + async fn from_request(_: &mut RequestParts) -> Result { Err(()) } } diff --git a/axum-extra/src/handler/mod.rs b/axum-extra/src/handler/mod.rs index 5a5ccf1f2d..6842327a01 100644 --- a/axum-extra/src/handler/mod.rs +++ b/axum-extra/src/handler/mod.rs @@ -19,15 +19,15 @@ pub use self::or::Or; /// /// The drawbacks of this trait is that you cannot apply middleware to individual handlers like you /// can with [`Handler::layer`]. -pub trait HandlerCallWithExtractors: Sized { +pub trait HandlerCallWithExtractors: Sized { /// The type of future calling this handler returns. type Future: Future + Send + 'static; /// Call the handler with the extracted inputs. - fn call(self, extractors: T) -> >::Future; + fn call(self, state: S, extractors: T) -> >::Future; /// Conver this `HandlerCallWithExtractors` into [`Handler`]. - fn into_handler(self) -> IntoHandler { + fn into_handler(self) -> IntoHandler { IntoHandler { handler: self, _marker: PhantomData, @@ -67,10 +67,14 @@ pub trait HandlerCallWithExtractors: Sized { /// struct AdminPermissions {} /// /// #[async_trait] - /// impl FromRequest for AdminPermissions { + /// impl FromRequest for AdminPermissions + /// where + /// B: Send, + /// S: Send, + /// { /// // check for admin permissions... /// # type Rejection = (); - /// # async fn from_request(req: &mut axum::extract::RequestParts) -> Result { + /// # async fn from_request(req: &mut axum::extract::RequestParts) -> Result { /// # todo!() /// # } /// } @@ -78,10 +82,14 @@ pub trait HandlerCallWithExtractors: Sized { /// struct User {} /// /// #[async_trait] - /// impl FromRequest for User { + /// impl FromRequest for User + /// where + /// B: Send, + /// S: Send, + /// { /// // check for a logged in user... /// # type Rejection = (); - /// # async fn from_request(req: &mut axum::extract::RequestParts) -> Result { + /// # async fn from_request(req: &mut axum::extract::RequestParts) -> Result { /// # todo!() /// # } /// } @@ -96,9 +104,9 @@ pub trait HandlerCallWithExtractors: Sized { /// ); /// # let _: Router = app; /// ``` - fn or(self, rhs: R) -> Or + fn or(self, rhs: R) -> Or where - R: HandlerCallWithExtractors, + R: HandlerCallWithExtractors, { Or { lhs: self, @@ -111,7 +119,7 @@ pub trait HandlerCallWithExtractors: Sized { macro_rules! impl_handler_call_with { ( $($ty:ident),* $(,)? ) => { #[allow(non_snake_case)] - impl HandlerCallWithExtractors<($($ty,)*), B> for F + impl HandlerCallWithExtractors<($($ty,)*), S, B> for F where F: FnOnce($($ty,)*) -> Fut, Fut: Future + Send + 'static, @@ -122,8 +130,9 @@ macro_rules! impl_handler_call_with { fn call( self, + _state: S, ($($ty,)*): ($($ty,)*), - ) -> >::Future { + ) -> >::Future { self($($ty,)*).map(IntoResponse::into_response) } } @@ -152,34 +161,35 @@ impl_handler_call_with!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, /// /// Created with [`HandlerCallWithExtractors::into_handler`]. #[allow(missing_debug_implementations)] -pub struct IntoHandler { +pub struct IntoHandler { handler: H, - _marker: PhantomData (T, B)>, + _marker: PhantomData (T, S, B)>, } -impl Handler for IntoHandler +impl Handler for IntoHandler where - H: HandlerCallWithExtractors + Clone + Send + 'static, - T: FromRequest + Send + 'static, + H: HandlerCallWithExtractors + Clone + Send + 'static, + T: FromRequest + Send + 'static, T::Rejection: Send, B: Send + 'static, + S: Clone + Send + 'static, { type Future = BoxFuture<'static, Response>; - fn call(self, req: http::Request) -> Self::Future { + fn call(self, state: S, req: http::Request) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::new(req); + let mut req = RequestParts::with_state(state.clone(), req); match req.extract::().await { - Ok(t) => self.handler.call(t).await, + Ok(t) => self.handler.call(state, t).await, Err(rejection) => rejection.into_response(), } }) } } -impl Copy for IntoHandler where H: Copy {} +impl Copy for IntoHandler where H: Copy {} -impl Clone for IntoHandler +impl Clone for IntoHandler where H: Clone, { diff --git a/axum-extra/src/handler/or.rs b/axum-extra/src/handler/or.rs index ea5eafb932..6478b35d50 100644 --- a/axum-extra/src/handler/or.rs +++ b/axum-extra/src/handler/or.rs @@ -15,16 +15,16 @@ use std::{future::Future, marker::PhantomData}; /// /// Created with [`HandlerCallWithExtractors::or`](super::HandlerCallWithExtractors::or). #[allow(missing_debug_implementations)] -pub struct Or { +pub struct Or { pub(super) lhs: L, pub(super) rhs: R, - pub(super) _marker: PhantomData (Lt, Rt, B)>, + pub(super) _marker: PhantomData (Lt, Rt, S, B)>, } -impl HandlerCallWithExtractors, B> for Or +impl HandlerCallWithExtractors, S, B> for Or where - L: HandlerCallWithExtractors + Send + 'static, - R: HandlerCallWithExtractors + Send + 'static, + L: HandlerCallWithExtractors + Send + 'static, + R: HandlerCallWithExtractors + Send + 'static, Rt: Send + 'static, Lt: Send + 'static, B: Send + 'static, @@ -37,46 +37,48 @@ where fn call( self, + state: S, extractors: Either, - ) -> , B>>::Future { + ) -> , S, B>>::Future { match extractors { Either::E1(lt) => self .lhs - .call(lt) + .call(state, lt) .map(IntoResponse::into_response as _) .left_future(), Either::E2(rt) => self .rhs - .call(rt) + .call(state, rt) .map(IntoResponse::into_response as _) .right_future(), } } } -impl Handler<(Lt, Rt), B> for Or +impl Handler<(Lt, Rt), S, B> for Or where - L: HandlerCallWithExtractors + Clone + Send + 'static, - R: HandlerCallWithExtractors + Clone + Send + 'static, - Lt: FromRequest + Send + 'static, - Rt: FromRequest + Send + 'static, + L: HandlerCallWithExtractors + Clone + Send + 'static, + R: HandlerCallWithExtractors + Clone + Send + 'static, + Lt: FromRequest + Send + 'static, + Rt: FromRequest + Send + 'static, Lt::Rejection: Send, Rt::Rejection: Send, B: Send + 'static, + S: Clone + Send + 'static, { // this puts `futures_util` in our public API but thats fine in axum-extra type Future = BoxFuture<'static, Response>; - fn call(self, req: Request) -> Self::Future { + fn call(self, state: S, req: Request) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::new(req); + let mut req = RequestParts::with_state(state.clone(), req); if let Ok(lt) = req.extract::().await { - return self.lhs.call(lt).await; + return self.lhs.call(state, lt).await; } if let Ok(rt) = req.extract::().await { - return self.rhs.call(rt).await; + return self.rhs.call(state, rt).await; } StatusCode::NOT_FOUND.into_response() @@ -84,14 +86,14 @@ where } } -impl Copy for Or +impl Copy for Or where L: Copy, R: Copy, { } -impl Clone for Or +impl Clone for Or where L: Clone, R: Clone, diff --git a/axum-extra/src/json_lines.rs b/axum-extra/src/json_lines.rs index 1bc8d1a0bd..242b43e70f 100644 --- a/axum-extra/src/json_lines.rs +++ b/axum-extra/src/json_lines.rs @@ -98,16 +98,17 @@ impl JsonLines { } #[async_trait] -impl FromRequest for JsonLines +impl FromRequest for JsonLines where B: HttpBody + Send + 'static, B::Data: Into, B::Error: Into, T: DeserializeOwned, + S: Send, { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { // `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead` // so we can call `AsyncRead::lines` and then convert it back to a `Stream` diff --git a/axum-extra/src/protobuf.rs b/axum-extra/src/protobuf.rs index adad728014..a30421a033 100644 --- a/axum-extra/src/protobuf.rs +++ b/axum-extra/src/protobuf.rs @@ -97,16 +97,17 @@ use std::ops::{Deref, DerefMut}; pub struct ProtoBuf(pub T); #[async_trait] -impl FromRequest for ProtoBuf +impl FromRequest for ProtoBuf where T: Message + Default, B: HttpBody + Send, B::Data: Send, B::Error: Into, + S: Send, { type Rejection = ProtoBufRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let mut bytes = Bytes::from_request(req).await?; match T::decode(&mut bytes) { diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 35993b44d8..0d968b941c 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -1,12 +1,13 @@ //! Additional types for defining routes. use axum::{ - handler::Handler, + handler::{Handler, HandlerWithoutStateExt}, http::Request, response::{IntoResponse, Redirect}, + routing::{any, MethodRouter}, Router, }; -use std::{convert::Infallible, future::ready}; +use std::{convert::Infallible, future::ready, sync::Arc}; use tower_service::Service; mod resource; @@ -29,7 +30,7 @@ pub use self::typed::{FirstElementIs, TypedPath}; pub use self::spa::SpaRouter; /// Extension trait that adds additional methods to [`Router`]. -pub trait RouterExt: sealed::Sealed { +pub trait RouterExt: sealed::Sealed { /// Add a typed `GET` route to the router. /// /// The path will be inferred from the first argument to the handler function which must @@ -39,7 +40,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_get(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -52,7 +53,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_delete(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -65,7 +66,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_head(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -78,7 +79,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_options(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -91,7 +92,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_patch(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -104,7 +105,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_post(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -117,7 +118,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_put(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -130,7 +131,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_trace(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -159,7 +160,14 @@ pub trait RouterExt: sealed::Sealed { /// .route_with_tsr("/bar/", get(|| async {})); /// # let _: Router = app; /// ``` - fn route_with_tsr(self, path: &str, service: T) -> Self + fn route_with_tsr(self, path: &str, method_router: MethodRouter) -> Self + where + Self: Sized; + + /// Add another route to the router with an additional "trailing slash redirect" route. + /// + /// This works like [`RouterExt::route_with_tsr`] but accepts any [`Service`]. + fn route_service_with_tsr(self, path: &str, service: T) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, @@ -167,14 +175,15 @@ pub trait RouterExt: sealed::Sealed { Self: Sized; } -impl RouterExt for Router +impl RouterExt for Router where B: axum::body::HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, { #[cfg(feature = "typed-routing")] fn typed_get(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -184,7 +193,7 @@ where #[cfg(feature = "typed-routing")] fn typed_delete(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -194,7 +203,7 @@ where #[cfg(feature = "typed-routing")] fn typed_head(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -204,7 +213,7 @@ where #[cfg(feature = "typed-routing")] fn typed_options(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -214,7 +223,7 @@ where #[cfg(feature = "typed-routing")] fn typed_patch(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -224,7 +233,7 @@ where #[cfg(feature = "typed-routing")] fn typed_post(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -234,7 +243,7 @@ where #[cfg(feature = "typed-routing")] fn typed_put(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -244,41 +253,56 @@ where #[cfg(feature = "typed-routing")] fn typed_trace(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::trace(handler)) } - fn route_with_tsr(mut self, path: &str, service: T) -> Self + fn route_with_tsr(mut self, path: &str, method_router: MethodRouter) -> Self + where + Self: Sized, + { + self = self.route(path, method_router); + + let redirect_service = { + let path: Arc = path.into(); + (move || ready(Redirect::permanent(&path))).into_service() + }; + + if let Some(path_without_trailing_slash) = path.strip_suffix('/') { + self.route_service(path_without_trailing_slash, redirect_service) + } else { + self.route_service(&format!("{}/", path), redirect_service) + } + } + + fn route_service_with_tsr(mut self, path: &str, service: T) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, Self: Sized, { - self = self.route(path, service); + self = self.route_service(path, service); let redirect = Redirect::permanent(path); if let Some(path_without_trailing_slash) = path.strip_suffix('/') { self.route( path_without_trailing_slash, - (move || ready(redirect.clone())).into_service(), + any(move || ready(redirect.clone())), ) } else { - self.route( - &format!("{}/", path), - (move || ready(redirect.clone())).into_service(), - ) + self.route(&format!("{}/", path), any(move || ready(redirect.clone()))) } } } mod sealed { pub trait Sealed {} - impl Sealed for axum::Router {} + impl Sealed for axum::Router {} } #[cfg(test)] diff --git a/axum-extra/src/routing/resource.rs b/axum-extra/src/routing/resource.rs index 910410fb01..239d6e1895 100644 --- a/axum-extra/src/routing/resource.rs +++ b/axum-extra/src/routing/resource.rs @@ -1,13 +1,9 @@ use axum::{ body::Body, handler::Handler, - http::Request, - response::IntoResponse, - routing::{delete, get, on, post, MethodFilter}, + routing::{delete, get, on, post, MethodFilter, MethodRouter}, Router, }; -use std::{convert::Infallible, fmt}; -use tower_service::Service; /// A resource which defines a set of conventional CRUD routes. /// @@ -34,14 +30,15 @@ use tower_service::Service; /// .destroy(|Path(user_id): Path| async {}); /// /// let app = Router::new().merge(users); -/// # let _: Router = app; +/// # let _: Router = app; /// ``` -pub struct Resource { +#[derive(Debug)] +pub struct Resource { pub(crate) name: String, - pub(crate) router: Router, + pub(crate) router: Router, } -impl Resource +impl Resource<(), B> where B: axum::body::HttpBody + Send + 'static, { @@ -49,16 +46,29 @@ where /// /// All routes will be nested at `/{resource_name}`. pub fn named(resource_name: &str) -> Self { + Self::named_with((), resource_name) + } +} + +impl Resource +where + B: axum::body::HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, +{ + /// Create a `Resource` with the given name and state. + /// + /// All routes will be nested at `/{resource_name}`. + pub fn named_with(state: S, resource_name: &str) -> Self { Self { name: resource_name.to_owned(), - router: Default::default(), + router: Router::with_state(state), } } /// Add a handler at `GET /{resource_name}`. pub fn index(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = self.index_create_path(); @@ -68,7 +78,7 @@ where /// Add a handler at `POST /{resource_name}`. pub fn create(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = self.index_create_path(); @@ -78,7 +88,7 @@ where /// Add a handler at `GET /{resource_name}/new`. pub fn new(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = format!("/{}/new", self.name); @@ -88,7 +98,7 @@ where /// Add a handler at `GET /{resource_name}/:{resource_name}_id`. pub fn show(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = self.show_update_destroy_path(); @@ -98,7 +108,7 @@ where /// Add a handler at `GET /{resource_name}/:{resource_name}_id/edit`. pub fn edit(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = format!("/{0}/:{0}_id/edit", self.name); @@ -108,7 +118,7 @@ where /// Add a handler at `PUT or PATCH /resource_name/:{resource_name}_id`. pub fn update(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = self.show_update_destroy_path(); @@ -118,7 +128,7 @@ where /// Add a handler at `DELETE /{resource_name}/:{resource_name}_id`. pub fn destroy(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = self.show_update_destroy_path(); @@ -133,13 +143,8 @@ where format!("/{0}/:{0}_id", self.name) } - fn route(mut self, path: &str, svc: T) -> Self - where - T: Service, Error = Infallible> + Clone + Send + 'static, - T::Response: IntoResponse, - T::Future: Send + 'static, - { - self.router = self.router.route(path, svc); + fn route(mut self, path: &str, method_router: MethodRouter) -> Self { + self.router = self.router.route(path, method_router); self } } @@ -150,21 +155,13 @@ impl From> for Router { } } -impl fmt::Debug for Resource { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Resource") - .field("name", &self.name) - .field("router", &self.router) - .finish() - } -} - #[cfg(test)] mod tests { #[allow(unused_imports)] use super::*; use axum::{extract::Path, http::Method, Router}; - use tower::ServiceExt; + use http::Request; + use tower::{Service, ServiceExt}; #[tokio::test] async fn works() { @@ -220,7 +217,7 @@ mod tests { ); } - async fn call_route(app: &mut Router, method: Method, uri: &str) -> String { + async fn call_route(app: &mut Router<()>, method: Method, uri: &str) -> String { let res = app .ready() .await diff --git a/axum-extra/src/routing/spa.rs b/axum-extra/src/routing/spa.rs index 7c67d882cf..844c1a5d4c 100644 --- a/axum-extra/src/routing/spa.rs +++ b/axum-extra/src/routing/spa.rs @@ -36,7 +36,7 @@ use tower_service::Service; /// .merge(spa) /// // we can still add other routes /// .route("/api/foo", get(api_foo)); -/// # let _: Router = app; +/// # let _: Router = app; /// /// async fn api_foo() {} /// ``` @@ -101,7 +101,7 @@ impl SpaRouter { /// .index_file("another_file.html"); /// /// let app = Router::new().merge(spa); - /// # let _: Router = app; + /// # let _: Router = app; /// ``` pub fn index_file

(mut self, path: P) -> Self where @@ -136,7 +136,7 @@ impl SpaRouter { /// } /// /// let app = Router::new().merge(spa); - /// # let _: Router = app; + /// # let _: Router = app; /// ``` pub fn handle_error(self, f: F2) -> SpaRouter { SpaRouter { @@ -147,7 +147,7 @@ impl SpaRouter { } } -impl From> for Router +impl From> for Router<(), B> where F: Clone + Send + 'static, HandleError, F, T>: Service, Error = Infallible>, @@ -162,7 +162,7 @@ where Router::new() .nest(&spa.paths.assets_path, assets_service) - .fallback( + .fallback_service( get_service(ServeFile::new(&spa.paths.index_file)).handle_error(spa.handle_error), ) } @@ -264,6 +264,13 @@ mod tests { let spa = SpaRouter::new("/assets", "test_files").handle_error(handle_error); - Router::::new().merge(spa); + Router::<_, Body>::new().merge(spa); + } + + #[allow(dead_code)] + fn works_with_router_with_state() { + let _: Router = Router::with_state(String::new()) + .merge(SpaRouter::new("/assets", "test_files")) + .route("/", get(|_: axum::extract::State| async {})); } } diff --git a/axum-extra/src/routing/typed.rs b/axum-extra/src/routing/typed.rs index 683d68431c..159472a063 100644 --- a/axum-extra/src/routing/typed.rs +++ b/axum-extra/src/routing/typed.rs @@ -60,7 +60,7 @@ use http::Uri; /// async fn users_destroy(_: UsersCollection) { /* ... */ } /// /// # -/// # let app: Router = app; +/// # let app: Router = app; /// ``` /// /// # Using `#[derive(TypedPath)]` diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index cc0a682242..4972fc49d6 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -203,7 +203,7 @@ fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStr #[allow(warnings)] fn #name() where - #ty: ::axum::extract::FromRequest<#body_ty> + Send, + #ty: ::axum::extract::FromRequest<(), #body_ty> + Send, {} } }) diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs index f017af7f5c..9e84bd248f 100644 --- a/axum-macros/src/from_request.rs +++ b/axum-macros/src/from_request.rs @@ -218,16 +218,17 @@ fn impl_struct_by_extracting_each_field( Ok(quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, B::Error: ::std::convert::Into<::axum::BoxError>, + S: Send, { type Rejection = #rejection_ident; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: &mut ::axum::extract::RequestParts, ) -> ::std::result::Result { ::std::result::Result::Ok(Self { #(#extract_fields)* @@ -422,7 +423,7 @@ fn extract_each_field_rejection( Ok(quote_spanned! {ty_span=> #[allow(non_camel_case_types)] - #variant_name(<#extractor_ty as ::axum::extract::FromRequest<::axum::body::Body>>::Rejection), + #variant_name(<#extractor_ty as ::axum::extract::FromRequest<(), ::axum::body::Body>>::Rejection), }) }) .collect::>>()?; @@ -609,26 +610,26 @@ fn impl_struct_by_extracting_all_at_once( quote! { #rejection } } else { quote! { - <#path as ::axum::extract::FromRequest>::Rejection + <#path as ::axum::extract::FromRequest>::Rejection } }; let rejection_bound = rejection.as_ref().map(|rejection| { if generic_ident.is_some() { quote! { - #rejection: ::std::convert::From<<#path as ::axum::extract::FromRequest>::Rejection>, + #rejection: ::std::convert::From<<#path as ::axum::extract::FromRequest>::Rejection>, } } else { quote! { - #rejection: ::std::convert::From<<#path as ::axum::extract::FromRequest>::Rejection>, + #rejection: ::std::convert::From<<#path as ::axum::extract::FromRequest>::Rejection>, } } }).unwrap_or_default(); let impl_generics = if generic_ident.is_some() { - quote! { B, T } + quote! { S, B, T } } else { - quote! { B } + quote! { S, B } }; let type_generics = generic_ident @@ -653,18 +654,19 @@ fn impl_struct_by_extracting_all_at_once( Ok(quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] - impl<#impl_generics> ::axum::extract::FromRequest for #ident #type_generics + impl<#impl_generics> ::axum::extract::FromRequest for #ident #type_generics where - #path<#via_type_generics>: ::axum::extract::FromRequest, + #path<#via_type_generics>: ::axum::extract::FromRequest, #rejection_bound B: ::std::marker::Send, + S: ::std::marker::Send, { type Rejection = #associated_rejection_type; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: &mut ::axum::extract::RequestParts, ) -> ::std::result::Result { - ::axum::extract::FromRequest::::from_request(req) + ::axum::extract::FromRequest::::from_request(req) .await .map(|#path(value)| #value_to_self) .map_err(::std::convert::From::from) @@ -709,7 +711,7 @@ fn impl_enum_by_extracting_all_at_once( quote! { #rejection } } else { quote! { - <#path as ::axum::extract::FromRequest>::Rejection + <#path as ::axum::extract::FromRequest>::Rejection } }; @@ -718,18 +720,19 @@ fn impl_enum_by_extracting_all_at_once( Ok(quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, B::Error: ::std::convert::Into<::axum::BoxError>, + S: ::std::marker::Send, { type Rejection = #associated_rejection_type; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: &mut ::axum::extract::RequestParts, ) -> ::std::result::Result { - ::axum::extract::FromRequest::::from_request(req) + ::axum::extract::FromRequest::::from_request(req) .await .map(|#path(inner)| inner) .map_err(::std::convert::From::from) diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index c76084d695..7b27a4bd83 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -125,7 +125,7 @@ mod typed_path; /// ``` /// pub struct ViaExtractor(pub T); /// -/// // impl FromRequest for ViaExtractor { ... } +/// // impl FromRequest for ViaExtractor { ... } /// ``` /// /// More complex via extractors are not supported and require writing a manual implementation. @@ -223,14 +223,15 @@ mod typed_path; /// struct OtherExtractor; /// /// #[async_trait] -/// impl FromRequest for OtherExtractor +/// impl FromRequest for OtherExtractor /// where -/// B: Send + 'static, +/// B: Send, +/// S: Send, /// { /// // this rejection doesn't implement `Display` and `Error` /// type Rejection = (StatusCode, String); /// -/// async fn from_request(_req: &mut RequestParts) -> Result { +/// async fn from_request(_req: &mut RequestParts) -> Result { /// // ... /// # unimplemented!() /// } diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs index a765437198..6a8f03c170 100644 --- a/axum-macros/src/typed_path.rs +++ b/axum-macros/src/typed_path.rs @@ -127,13 +127,14 @@ fn expand_named_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: Send, + S: Send, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { + async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { ::axum::extract::Path::from_request(req) .await .map(|path| path.0) @@ -229,13 +230,14 @@ fn expand_unnamed_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: Send, + S: Send, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { + async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { ::axum::extract::Path::from_request(req) .await .map(|path| path.0) @@ -310,13 +312,14 @@ fn expand_unit_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: Send, + S: Send, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { + async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { if req.uri().path() == ::PATH { Ok(Self) } else { @@ -387,7 +390,7 @@ enum Segment { fn path_rejection() -> TokenStream { quote! { - <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection + <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection } } diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr index 078b2b0371..265258419e 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr @@ -1,17 +1,17 @@ -error[E0277]: the trait bound `bool: FromRequest` is not satisfied +error[E0277]: the trait bound `bool: FromRequest<(), Body>` is not satisfied --> tests/debug_handler/fail/argument_not_extractor.rs:4:23 | 4 | async fn handler(foo: bool) {} - | ^^^^ the trait `FromRequest` is not implemented for `bool` + | ^^^^ the trait `FromRequest<(), Body>` is not implemented for `bool` | - = help: the following other types implement trait `FromRequest`: - () - (T1, T2) - (T1, T2, T3) - (T1, T2, T3, T4) - (T1, T2, T3, T4, T5) - (T1, T2, T3, T4, T5, T6) - (T1, T2, T3, T4, T5, T6, T7) - (T1, T2, T3, T4, T5, T6, T7, T8) - and 33 others + = help: the following other types implement trait `FromRequest`: + <() as FromRequest> + <(T1, T2) as FromRequest> + <(T1, T2, T3) as FromRequest> + <(T1, T2, T3, T4) as FromRequest> + <(T1, T2, T3, T4, T5) as FromRequest> + <(T1, T2, T3, T4, T5, T6) as FromRequest> + <(T1, T2, T3, T4, T5, T6, T7) as FromRequest> + <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequest> + and 34 others = help: see issue #48214 diff --git a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs index 01eb636bc0..d38d5e0c4d 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs @@ -7,13 +7,14 @@ use axum_macros::debug_handler; struct A; #[async_trait] -impl FromRequest for A +impl FromRequest for A where - B: Send + 'static, + B: Send, + S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr b/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr index 595786bf4e..3d80dffbca 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr +++ b/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr @@ -1,5 +1,5 @@ error: Handlers must only take owned values - --> tests/debug_handler/fail/extract_self_mut.rs:23:22 + --> tests/debug_handler/fail/extract_self_mut.rs:24:22 | -23 | async fn handler(&mut self) {} +24 | async fn handler(&mut self) {} | ^^^^^^^^^ diff --git a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs index d64732cdcc..06b87f0a82 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs @@ -7,13 +7,14 @@ use axum_macros::debug_handler; struct A; #[async_trait] -impl FromRequest for A +impl FromRequest for A where - B: Send + 'static, + B: Send, + S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr b/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr index 4c0b4950c7..82d9a89ff5 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr +++ b/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr @@ -1,5 +1,5 @@ error: Handlers must only take owned values - --> tests/debug_handler/fail/extract_self_ref.rs:23:22 + --> tests/debug_handler/fail/extract_self_ref.rs:24:22 | -23 | async fn handler(&self) {} +24 | async fn handler(&self) {} | ^^^^^ diff --git a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs index 81269a9cb2..762809b62a 100644 --- a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs +++ b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs @@ -120,13 +120,14 @@ impl A { } #[async_trait] -impl FromRequest for A +impl FromRequest for A where - B: Send + 'static, + B: Send, + S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/pass/self_receiver.rs b/axum-macros/tests/debug_handler/pass/self_receiver.rs index f22ccc08cd..a88382cf18 100644 --- a/axum-macros/tests/debug_handler/pass/self_receiver.rs +++ b/axum-macros/tests/debug_handler/pass/self_receiver.rs @@ -7,13 +7,14 @@ use axum_macros::debug_handler; struct A; #[async_trait] -impl FromRequest for A +impl FromRequest for A where - B: Send + 'static, + B: Send, + S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/from_request/fail/double_via_attr.rs b/axum-macros/tests/from_request/fail/double_via_attr.rs index b78cec4653..e65406512c 100644 --- a/axum-macros/tests/from_request/fail/double_via_attr.rs +++ b/axum-macros/tests/from_request/fail/double_via_attr.rs @@ -1,8 +1,7 @@ use axum_macros::FromRequest; -use axum::extract::Extension; #[derive(FromRequest)] -struct Extractor(#[from_request(via(Extension), via(Extension))] State); +struct Extractor(#[from_request(via(axum::Extension), via(axum::Extension))] State); #[derive(Clone)] struct State; diff --git a/axum-macros/tests/from_request/fail/double_via_attr.stderr b/axum-macros/tests/from_request/fail/double_via_attr.stderr index e63d4502d2..9d0ff2490b 100644 --- a/axum-macros/tests/from_request/fail/double_via_attr.stderr +++ b/axum-macros/tests/from_request/fail/double_via_attr.stderr @@ -1,13 +1,5 @@ error: `via` specified more than once - --> tests/from_request/fail/double_via_attr.rs:5:49 + --> tests/from_request/fail/double_via_attr.rs:4:55 | -5 | struct Extractor(#[from_request(via(Extension), via(Extension))] State); - | ^^^ - -warning: unused import: `axum::extract::Extension` - --> tests/from_request/fail/double_via_attr.rs:2:5 - | -2 | use axum::extract::Extension; - | ^^^^^^^^^^^^^^^^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default +4 | struct Extractor(#[from_request(via(axum::Extension), via(axum::Extension))] State); + | ^^^ diff --git a/axum-macros/tests/from_request/fail/generic_without_via.rs b/axum-macros/tests/from_request/fail/generic_without_via.rs index 29ec609566..38eaa437a3 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via.rs +++ b/axum-macros/tests/from_request/fail/generic_without_via.rs @@ -1,4 +1,4 @@ -use axum::{body::Body, routing::get, Extension, Router}; +use axum::{body::Body, routing::get, Router}; use axum_macros::FromRequest; #[derive(FromRequest, Clone)] @@ -7,5 +7,5 @@ struct Extractor(T); async fn foo(_: Extractor<()>) {} fn main() { - Router::::new().route("/", get(foo)); + Router::<(), Body>::new().route("/", get(foo)); } diff --git a/axum-macros/tests/from_request/fail/generic_without_via.stderr b/axum-macros/tests/from_request/fail/generic_without_via.stderr index 950779fefd..4b6b920a31 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via.stderr +++ b/axum-macros/tests/from_request/fail/generic_without_via.stderr @@ -4,23 +4,15 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_reque 5 | struct Extractor(T); | ^ -warning: unused import: `Extension` - --> tests/from_request/fail/generic_without_via.rs:1:38 - | -1 | use axum::{body::Body, routing::get, Extension, Router}; - | ^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default - -error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _>` is not satisfied - --> tests/from_request/fail/generic_without_via.rs:10:42 +error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _, _>` is not satisfied + --> tests/from_request/fail/generic_without_via.rs:10:46 | -10 | Router::::new().route("/", get(foo)); - | --- ^^^ the trait `Handler<_, _>` is not implemented for `fn(Extractor<()>) -> impl Future {foo}` - | | - | required by a bound introduced by this call +10 | Router::<(), Body>::new().route("/", get(foo)); + | --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future {foo}` + | | + | required by a bound introduced by this call | - = help: the trait `Handler` is implemented for `Layered` + = help: the trait `Handler` is implemented for `Layered` note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection.rs b/axum-macros/tests/from_request/fail/generic_without_via_rejection.rs index 99bdf38774..38d6b0910b 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection.rs +++ b/axum-macros/tests/from_request/fail/generic_without_via_rejection.rs @@ -1,4 +1,4 @@ -use axum::{body::Body, routing::get, Extension, Router}; +use axum::{body::Body, routing::get, Router}; use axum_macros::FromRequest; #[derive(FromRequest, Clone)] @@ -8,5 +8,5 @@ struct Extractor(T); async fn foo(_: Extractor<()>) {} fn main() { - Router::::new().route("/", get(foo)); + Router::<(), Body>::new().route("/", get(foo)); } diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr b/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr index 199d3dd2b9..d470986cde 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr +++ b/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr @@ -4,23 +4,15 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_reque 6 | struct Extractor(T); | ^ -warning: unused import: `Extension` - --> tests/from_request/fail/generic_without_via_rejection.rs:1:38 - | -1 | use axum::{body::Body, routing::get, Extension, Router}; - | ^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default - -error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _>` is not satisfied - --> tests/from_request/fail/generic_without_via_rejection.rs:11:42 +error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _, _>` is not satisfied + --> tests/from_request/fail/generic_without_via_rejection.rs:11:46 | -11 | Router::::new().route("/", get(foo)); - | --- ^^^ the trait `Handler<_, _>` is not implemented for `fn(Extractor<()>) -> impl Future {foo}` - | | - | required by a bound introduced by this call +11 | Router::<(), Body>::new().route("/", get(foo)); + | --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future {foo}` + | | + | required by a bound introduced by this call | - = help: the trait `Handler` is implemented for `Layered` + = help: the trait `Handler` is implemented for `Layered` note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.rs b/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.rs index 8ded02831e..ec5bb80099 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.rs +++ b/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.rs @@ -1,4 +1,4 @@ -use axum::{body::Body, routing::get, Extension, Router}; +use axum::{body::Body, routing::get, Router}; use axum_macros::FromRequest; #[derive(FromRequest, Clone)] @@ -8,5 +8,5 @@ struct Extractor(T); async fn foo(_: Extractor<()>) {} fn main() { - Router::::new().route("/", get(foo)); + Router::<(), Body>::new().route("/", get(foo)); } diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.stderr b/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.stderr index 5bc39b62ab..10b674c150 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.stderr +++ b/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.stderr @@ -4,23 +4,15 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_reque 6 | struct Extractor(T); | ^ -warning: unused import: `Extension` - --> tests/from_request/fail/generic_without_via_rejection_derive.rs:1:38 - | -1 | use axum::{body::Body, routing::get, Extension, Router}; - | ^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default - -error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _>` is not satisfied - --> tests/from_request/fail/generic_without_via_rejection_derive.rs:11:42 +error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _, _>` is not satisfied + --> tests/from_request/fail/generic_without_via_rejection_derive.rs:11:46 | -11 | Router::::new().route("/", get(foo)); - | --- ^^^ the trait `Handler<_, _>` is not implemented for `fn(Extractor<()>) -> impl Future {foo}` - | | - | required by a bound introduced by this call +11 | Router::<(), Body>::new().route("/", get(foo)); + | --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future {foo}` + | | + | required by a bound introduced by this call | - = help: the trait `Handler` is implemented for `Layered` + = help: the trait `Handler` is implemented for `Layered` note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | diff --git a/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr b/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr index 1bd8a29a36..6fc8d999ed 100644 --- a/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr +++ b/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr @@ -4,15 +4,15 @@ error: cannot use `rejection` without `via` 18 | #[from_request(rejection(MyRejection))] | ^^^^^^^^^^^ -error[E0277]: the trait bound `fn(MyExtractor) -> impl Future {handler}: Handler<_, _>` is not satisfied +error[E0277]: the trait bound `fn(MyExtractor) -> impl Future {handler}: Handler<_, _, _>` is not satisfied --> tests/from_request/fail/override_rejection_on_enum_without_via.rs:10:50 | 10 | let _: Router = Router::new().route("/", get(handler).post(handler_result)); - | --- ^^^^^^^ the trait `Handler<_, _>` is not implemented for `fn(MyExtractor) -> impl Future {handler}` + | --- ^^^^^^^ the trait `Handler<_, _, _>` is not implemented for `fn(MyExtractor) -> impl Future {handler}` | | | required by a bound introduced by this call | - = help: the trait `Handler` is implemented for `Layered` + = help: the trait `Handler` is implemented for `Layered` note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | @@ -20,18 +20,18 @@ note: required by a bound in `axum::routing::get` | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `axum::routing::get` = note: this error originates in the macro `top_level_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info) -error[E0277]: the trait bound `fn(Result) -> impl Future {handler_result}: Handler<_, _>` is not satisfied +error[E0277]: the trait bound `fn(Result) -> impl Future {handler_result}: Handler<_, _, _>` is not satisfied --> tests/from_request/fail/override_rejection_on_enum_without_via.rs:10:64 | 10 | let _: Router = Router::new().route("/", get(handler).post(handler_result)); - | ---- ^^^^^^^^^^^^^^ the trait `Handler<_, _>` is not implemented for `fn(Result) -> impl Future {handler_result}` + | ---- ^^^^^^^^^^^^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Result) -> impl Future {handler_result}` | | | required by a bound introduced by this call | - = help: the trait `Handler` is implemented for `Layered` -note: required by a bound in `MethodRouter::::post` + = help: the trait `Handler` is implemented for `Layered` +note: required by a bound in `MethodRouter::::post` --> $WORKSPACE/axum/src/routing/method_routing.rs | | chained_handler_fn!(post, POST); - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `MethodRouter::::post` + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `MethodRouter::::post` = note: this error originates in the macro `chained_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/from_request/fail/rejection_derive_and_via.rs b/axum-macros/tests/from_request/fail/rejection_derive_and_via.rs index bb658f1103..a369150d60 100644 --- a/axum-macros/tests/from_request/fail/rejection_derive_and_via.rs +++ b/axum-macros/tests/from_request/fail/rejection_derive_and_via.rs @@ -1,8 +1,7 @@ use axum_macros::FromRequest; -use axum::extract::Extension; #[derive(FromRequest, Clone)] -#[from_request(rejection_derive(!Error), via(Extension))] +#[from_request(rejection_derive(!Error), via(axum::Extension))] struct Extractor { config: String, } diff --git a/axum-macros/tests/from_request/fail/rejection_derive_and_via.stderr b/axum-macros/tests/from_request/fail/rejection_derive_and_via.stderr index 59c6b0bee3..3a50044f83 100644 --- a/axum-macros/tests/from_request/fail/rejection_derive_and_via.stderr +++ b/axum-macros/tests/from_request/fail/rejection_derive_and_via.stderr @@ -1,13 +1,5 @@ error: cannot use both `rejection_derive` and `via` - --> tests/from_request/fail/rejection_derive_and_via.rs:5:42 + --> tests/from_request/fail/rejection_derive_and_via.rs:4:42 | -5 | #[from_request(rejection_derive(!Error), via(Extension))] +4 | #[from_request(rejection_derive(!Error), via(axum::Extension))] | ^^^ - -warning: unused import: `axum::extract::Extension` - --> tests/from_request/fail/rejection_derive_and_via.rs:2:5 - | -2 | use axum::extract::Extension; - | ^^^^^^^^^^^^^^^^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default diff --git a/axum-macros/tests/from_request/fail/via_and_rejection_derive.rs b/axum-macros/tests/from_request/fail/via_and_rejection_derive.rs index 8c183a60d7..5f42ef0cf7 100644 --- a/axum-macros/tests/from_request/fail/via_and_rejection_derive.rs +++ b/axum-macros/tests/from_request/fail/via_and_rejection_derive.rs @@ -1,8 +1,7 @@ use axum_macros::FromRequest; -use axum::extract::Extension; #[derive(FromRequest, Clone)] -#[from_request(via(Extension), rejection_derive(!Error))] +#[from_request(via(axum::Extension), rejection_derive(!Error))] struct Extractor { config: String, } diff --git a/axum-macros/tests/from_request/fail/via_and_rejection_derive.stderr b/axum-macros/tests/from_request/fail/via_and_rejection_derive.stderr index 25f2011b28..af45e8f811 100644 --- a/axum-macros/tests/from_request/fail/via_and_rejection_derive.stderr +++ b/axum-macros/tests/from_request/fail/via_and_rejection_derive.stderr @@ -1,13 +1,5 @@ error: cannot use both `via` and `rejection_derive` - --> tests/from_request/fail/via_and_rejection_derive.rs:5:32 + --> tests/from_request/fail/via_and_rejection_derive.rs:4:38 | -5 | #[from_request(via(Extension), rejection_derive(!Error))] - | ^^^^^^^^^^^^^^^^ - -warning: unused import: `axum::extract::Extension` - --> tests/from_request/fail/via_and_rejection_derive.rs:2:5 - | -2 | use axum::extract::Extension; - | ^^^^^^^^^^^^^^^^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default +4 | #[from_request(via(axum::Extension), rejection_derive(!Error))] + | ^^^^^^^^^^^^^^^^ diff --git a/axum-macros/tests/from_request/fail/via_on_container_and_field.rs b/axum-macros/tests/from_request/fail/via_on_container_and_field.rs index f213857b03..8499659bde 100644 --- a/axum-macros/tests/from_request/fail/via_on_container_and_field.rs +++ b/axum-macros/tests/from_request/fail/via_on_container_and_field.rs @@ -1,9 +1,8 @@ use axum_macros::FromRequest; -use axum::extract::Extension; #[derive(FromRequest)] -#[from_request(via(Extension))] -struct Extractor(#[from_request(via(Extension))] State); +#[from_request(via(axum::Extension))] +struct Extractor(#[from_request(via(axum::Extension))] State); #[derive(Clone)] struct State; diff --git a/axum-macros/tests/from_request/fail/via_on_container_and_field.stderr b/axum-macros/tests/from_request/fail/via_on_container_and_field.stderr index 01b6cac073..ff63e37df0 100644 --- a/axum-macros/tests/from_request/fail/via_on_container_and_field.stderr +++ b/axum-macros/tests/from_request/fail/via_on_container_and_field.stderr @@ -1,13 +1,5 @@ error: `#[from_request(via(...))]` on a field cannot be used together with `#[from_request(...)]` on the container - --> tests/from_request/fail/via_on_container_and_field.rs:6:33 + --> tests/from_request/fail/via_on_container_and_field.rs:5:33 | -6 | struct Extractor(#[from_request(via(Extension))] State); +5 | struct Extractor(#[from_request(via(axum::Extension))] State); | ^^^ - -warning: unused import: `axum::extract::Extension` - --> tests/from_request/fail/via_on_container_and_field.rs:2:5 - | -2 | use axum::extract::Extension; - | ^^^^^^^^^^^^^^^^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default diff --git a/axum-macros/tests/from_request/pass/container.rs b/axum-macros/tests/from_request/pass/container.rs index a6732527e4..e8eaa0a58a 100644 --- a/axum-macros/tests/from_request/pass/container.rs +++ b/axum-macros/tests/from_request/pass/container.rs @@ -15,7 +15,7 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest, + Extractor: FromRequest<(), Body, Rejection = JsonRejection>, { } diff --git a/axum-macros/tests/from_request/pass/derive_opt_out.rs b/axum-macros/tests/from_request/pass/derive_opt_out.rs index 0bf24c7389..e73d5a959c 100644 --- a/axum-macros/tests/from_request/pass/derive_opt_out.rs +++ b/axum-macros/tests/from_request/pass/derive_opt_out.rs @@ -14,13 +14,14 @@ struct Extractor { struct OtherExtractor; #[async_trait] -impl FromRequest for OtherExtractor +impl FromRequest for OtherExtractor where - B: Send + 'static, + B: Send, + S: Send, { type Rejection = OtherExtractorRejection; - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/from_request/pass/empty_named.rs b/axum-macros/tests/from_request/pass/empty_named.rs index 2cc5dda8b6..eec021d0f5 100644 --- a/axum-macros/tests/from_request/pass/empty_named.rs +++ b/axum-macros/tests/from_request/pass/empty_named.rs @@ -5,7 +5,7 @@ struct Extractor {} fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, { } diff --git a/axum-macros/tests/from_request/pass/empty_tuple.rs b/axum-macros/tests/from_request/pass/empty_tuple.rs index bbb525fa1b..3d8bcd25c0 100644 --- a/axum-macros/tests/from_request/pass/empty_tuple.rs +++ b/axum-macros/tests/from_request/pass/empty_tuple.rs @@ -5,7 +5,7 @@ struct Extractor(); fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, { } diff --git a/axum-macros/tests/from_request/pass/enum_via.rs b/axum-macros/tests/from_request/pass/enum_via.rs index d6ba90e277..c68b9796a7 100644 --- a/axum-macros/tests/from_request/pass/enum_via.rs +++ b/axum-macros/tests/from_request/pass/enum_via.rs @@ -8,5 +8,5 @@ enum Extractor {} async fn foo(_: Extractor) {} fn main() { - Router::::new().route("/", get(foo)); + Router::<(), Body>::new().route("/", get(foo)); } diff --git a/axum-macros/tests/from_request/pass/named.rs b/axum-macros/tests/from_request/pass/named.rs index 4ee40a0447..89fb8da004 100644 --- a/axum-macros/tests/from_request/pass/named.rs +++ b/axum-macros/tests/from_request/pass/named.rs @@ -18,7 +18,7 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest, + Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>, { } diff --git a/axum-macros/tests/from_request/pass/named_via.rs b/axum-macros/tests/from_request/pass/named_via.rs index fc03b8c0fb..8a81869d1a 100644 --- a/axum-macros/tests/from_request/pass/named_via.rs +++ b/axum-macros/tests/from_request/pass/named_via.rs @@ -25,7 +25,7 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest, + Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>, { } diff --git a/axum-macros/tests/from_request/pass/override_rejection.rs b/axum-macros/tests/from_request/pass/override_rejection.rs index 431474b895..c308d61521 100644 --- a/axum-macros/tests/from_request/pass/override_rejection.rs +++ b/axum-macros/tests/from_request/pass/override_rejection.rs @@ -28,14 +28,15 @@ struct MyExtractor { struct OtherExtractor; #[async_trait] -impl FromRequest for OtherExtractor +impl FromRequest for OtherExtractor where B: Send + 'static, + S: Send, { // this rejection doesn't implement `Display` and `Error` type Rejection = (StatusCode, String); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { todo!() } } diff --git a/axum-macros/tests/from_request/pass/tuple.rs b/axum-macros/tests/from_request/pass/tuple.rs index 9786285223..2af407d0f9 100644 --- a/axum-macros/tests/from_request/pass/tuple.rs +++ b/axum-macros/tests/from_request/pass/tuple.rs @@ -5,7 +5,7 @@ struct Extractor(axum::http::HeaderMap, String); fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs index 0434bb29f3..00b6dd78df 100644 --- a/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs @@ -13,7 +13,7 @@ struct Payload {} fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs index df7761eab9..0b148ebc50 100644 --- a/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs @@ -27,7 +27,7 @@ struct Payload {} fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } diff --git a/axum-macros/tests/from_request/pass/tuple_via.rs b/axum-macros/tests/from_request/pass/tuple_via.rs index dde9887771..03a9e3610c 100644 --- a/axum-macros/tests/from_request/pass/tuple_via.rs +++ b/axum-macros/tests/from_request/pass/tuple_via.rs @@ -1,5 +1,5 @@ +use axum::Extension; use axum_macros::FromRequest; -use axum::extract::Extension; #[derive(FromRequest)] struct Extractor(#[from_request(via(Extension))] State); @@ -9,7 +9,7 @@ struct State; fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } diff --git a/axum-macros/tests/from_request/pass/unit.rs b/axum-macros/tests/from_request/pass/unit.rs index 57f774d143..3e5d986917 100644 --- a/axum-macros/tests/from_request/pass/unit.rs +++ b/axum-macros/tests/from_request/pass/unit.rs @@ -5,7 +5,7 @@ struct Extractor; fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, { } diff --git a/axum-macros/tests/typed_path/fail/not_deserialize.stderr b/axum-macros/tests/typed_path/fail/not_deserialize.stderr index 7581b3997c..9aabf3625f 100644 --- a/axum-macros/tests/typed_path/fail/not_deserialize.stderr +++ b/axum-macros/tests/typed_path/fail/not_deserialize.stderr @@ -15,5 +15,5 @@ error[E0277]: the trait bound `for<'de> MyPath: serde::de::Deserialize<'de>` is (T0, T1, T2, T3) and 138 others = note: required because of the requirements on the impl of `serde::de::DeserializeOwned` for `MyPath` - = note: required because of the requirements on the impl of `FromRequest` for `axum::extract::Path` + = note: required because of the requirements on the impl of `FromRequest` for `axum::extract::Path` = note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/typed_path/pass/customize_rejection.rs b/axum-macros/tests/typed_path/pass/customize_rejection.rs index 41aa7e614e..40f3ec0ada 100644 --- a/axum-macros/tests/typed_path/pass/customize_rejection.rs +++ b/axum-macros/tests/typed_path/pass/customize_rejection.rs @@ -40,7 +40,7 @@ impl Default for MyRejection { } fn main() { - axum::Router::::new() + axum::Router::<(), axum::body::Body>::new() .typed_get(|_: Result| async {}) .typed_post(|_: Result| async {}) .typed_put(|_: Result| async {}); diff --git a/axum-macros/tests/typed_path/pass/named_fields_struct.rs b/axum-macros/tests/typed_path/pass/named_fields_struct.rs index 6942bd3394..6119304080 100644 --- a/axum-macros/tests/typed_path/pass/named_fields_struct.rs +++ b/axum-macros/tests/typed_path/pass/named_fields_struct.rs @@ -9,7 +9,7 @@ struct MyPath { } fn main() { - axum::Router::::new().route("/", axum::routing::get(|_: MyPath| async {})); + axum::Router::<(), axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {})); assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id"); assert_eq!( diff --git a/axum-macros/tests/typed_path/pass/option_result.rs b/axum-macros/tests/typed_path/pass/option_result.rs index d89dea2dd5..252bde137f 100644 --- a/axum-macros/tests/typed_path/pass/option_result.rs +++ b/axum-macros/tests/typed_path/pass/option_result.rs @@ -20,7 +20,7 @@ struct UsersIndex; async fn result_handler_unit_struct(_: Result) {} fn main() { - axum::Router::::new() + axum::Router::<(), axum::body::Body>::new() .typed_get(option_handler) .typed_post(result_handler) .typed_post(result_handler_unit_struct); diff --git a/axum-macros/tests/typed_path/pass/tuple_struct.rs b/axum-macros/tests/typed_path/pass/tuple_struct.rs index 5e3d27ff40..4f8fa17eeb 100644 --- a/axum-macros/tests/typed_path/pass/tuple_struct.rs +++ b/axum-macros/tests/typed_path/pass/tuple_struct.rs @@ -8,7 +8,7 @@ pub type Result = std::result::Result; struct MyPath(u32, u32); fn main() { - axum::Router::::new().route("/", axum::routing::get(|_: MyPath| async {})); + axum::Router::<(), axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {})); assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id"); assert_eq!(format!("{}", MyPath(1, 2)), "/users/1/teams/2"); diff --git a/axum-macros/tests/typed_path/pass/unit_struct.rs b/axum-macros/tests/typed_path/pass/unit_struct.rs index 9b6a0f6e39..0ba27f81ac 100644 --- a/axum-macros/tests/typed_path/pass/unit_struct.rs +++ b/axum-macros/tests/typed_path/pass/unit_struct.rs @@ -5,7 +5,7 @@ use axum_extra::routing::TypedPath; struct MyPath; fn main() { - axum::Router::::new() + axum::Router::<(), axum::body::Body>::new() .route("/", axum::routing::get(|_: MyPath| async {})); assert_eq!(MyPath::PATH, "/users"); diff --git a/axum-macros/tests/typed_path/pass/wildcards.rs b/axum-macros/tests/typed_path/pass/wildcards.rs index 0c9155f71d..e7794fc895 100644 --- a/axum-macros/tests/typed_path/pass/wildcards.rs +++ b/axum-macros/tests/typed_path/pass/wildcards.rs @@ -8,5 +8,5 @@ struct MyPath { } fn main() { - axum::Router::::new().typed_get(|_: MyPath| async {}); + axum::Router::<(), axum::body::Body>::new().typed_get(|_: MyPath| async {}); } diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 256a8b2556..8321cc98ec 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -35,6 +35,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Added `debug_handler` which is an attribute macro that improves type errors when applied to handler function. It is re-exported from `axum-macros` +- **added:** Added new type safe `State` extractor. This can be used with + `Router::with_state` and gives compile errors for missing states, whereas + `Extension` would result in runtime errors ([#1155]) +- **breaking:** The following types or traits have a new `S` type param + (`()` by default) which represents the state ([#1155]): + - `FromRequest` + - `RequestParts` + - `Router` + - `MethodRouter` + - `Handler` +- **breaking:** `Router::route` now only accepts `MethodRouter`s created with + `get`, `post`, etc ([#1155]) +- **added:** `Router::route_service` for routing to arbitrary `Service`s ([#1155]) - **added:** Support any middleware response that implements `IntoResponse` ([#1152]) - **breaking:** Require middleware added with `Handler::layer` to have `Infallible` as the error type ([#1152]) @@ -54,6 +67,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1130]: https://github.com/tokio-rs/axum/pull/1130 [#1135]: https://github.com/tokio-rs/axum/pull/1135 [#1152]: https://github.com/tokio-rs/axum/pull/1152 +[#1155]: https://github.com/tokio-rs/axum/pull/1155 [#1171]: https://github.com/tokio-rs/axum/pull/1171 [#1239]: https://github.com/tokio-rs/axum/pull/1239 [#1248]: https://github.com/tokio-rs/axum/pull/1248 diff --git a/axum/src/docs/error_handling.md b/axum/src/docs/error_handling.md index 01a0afaa09..45c768c69a 100644 --- a/axum/src/docs/error_handling.md +++ b/axum/src/docs/error_handling.md @@ -69,7 +69,7 @@ let some_fallible_service = tower::service_fn(|_req| async { Ok::<_, anyhow::Error>(Response::new(Body::empty())) }); -let app = Router::new().route( +let app = Router::new().route_service( "/", // we cannot route to `some_fallible_service` directly since it might fail. // we have to use `handle_error` which converts its errors into responses diff --git a/axum/src/docs/extract.md b/axum/src/docs/extract.md index 92784a65e5..9c31d0e4d3 100644 --- a/axum/src/docs/extract.md +++ b/axum/src/docs/extract.md @@ -421,13 +421,14 @@ use http::{StatusCode, header::{HeaderValue, USER_AGENT}}; struct ExtractUserAgent(HeaderValue); #[async_trait] -impl FromRequest for ExtractUserAgent +impl FromRequest for ExtractUserAgent where B: Send, + S: Send, { type Rejection = (StatusCode, &'static str); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if let Some(user_agent) = req.headers().get(USER_AGENT) { Ok(ExtractUserAgent(user_agent.clone())) } else { @@ -472,13 +473,14 @@ struct AuthenticatedUser { } #[async_trait] -impl FromRequest for AuthenticatedUser +impl FromRequest for AuthenticatedUser where B: Send, + S: Send, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let TypedHeader(Authorization(token)) = TypedHeader::>::from_request(req) .await @@ -633,7 +635,7 @@ fn token_is_valid(token: &str) -> bool { } let app = Router::new().layer(middleware::from_fn(auth_middleware)); -# let _: Router = app; +# let _: Router<()> = app; ``` [`body::Body`]: crate::body::Body diff --git a/axum/src/docs/method_routing/fallback.md b/axum/src/docs/method_routing/fallback.md index c6a0d09024..906cbb3b5d 100644 --- a/axum/src/docs/method_routing/fallback.md +++ b/axum/src/docs/method_routing/fallback.md @@ -11,7 +11,7 @@ use axum::{ http::{StatusCode, Method, Uri}, }; -let handler = get(|| async {}).fallback(fallback.into_service()); +let handler = get(|| async {}).fallback(fallback); let app = Router::new().route("/", handler); @@ -36,11 +36,9 @@ use axum::{ http::{StatusCode, Uri}, }; -let one = get(|| async {}) - .fallback(fallback_one.into_service()); +let one = get(|| async {}).fallback(fallback_one); -let two = post(|| async {}) - .fallback(fallback_two.into_service()); +let two = post(|| async {}).fallback(fallback_two); let method_route = one.merge(two); diff --git a/axum/src/docs/middleware.md b/axum/src/docs/middleware.md index 3266634ba8..8d0e970c4f 100644 --- a/axum/src/docs/middleware.md +++ b/axum/src/docs/middleware.md @@ -6,7 +6,8 @@ - [Ordering](#ordering) - [Writing middleware](#writing-middleware) - [Routing to services/middleware and backpressure](#routing-to-servicesmiddleware-and-backpressure) -- [Sharing state between handlers and middleware](#sharing-state-between-handlers-and-middleware) +- [Accessing state in middleware](#accessing-state-in-middleware) +- [Passing state from middleware to handlers](#passing-state-from-middleware-to-handlers) # Intro @@ -95,7 +96,7 @@ let app = Router::new() .layer(layer_one) .layer(layer_two) .layer(layer_three); -# let app: Router = app; +# let _: Router<(), axum::body::Body> = app; ``` Think of the middleware as being layered like an onion where each new layer @@ -154,7 +155,7 @@ let app = Router::new() .layer(layer_two) .layer(layer_three), ); -# let app: Router = app; +# let _: Router<(), axum::body::Body> = app; ``` `ServiceBuilder` works by composing all layers into one such that they run top @@ -386,9 +387,119 @@ Also note that handlers created from async functions don't care about backpressure and are always ready. So if you're not using any Tower middleware you don't have to worry about any of this. -# Sharing state between handlers and middleware +# Accessing state in middleware -State can be shared between middleware and handlers using [request extensions]: +Handlers can access state using the [`State`] extractor but this isn't available +to middleware. Instead you have to pass the state directly to middleware using +either closure captures (for [`axum::middleware::from_fn`]) or regular struct +fields (if you're implementing a [`tower::Layer`]) + +## Accessing state in `axum::middleware::from_fn` + +```rust +use axum::{ + Router, + routing::get, + middleware::{self, Next}, + response::Response, + extract::State, + http::Request, +}; + +#[derive(Clone)] +struct AppState {} + +async fn my_middleware( + state: AppState, + req: Request, + next: Next, +) -> Response { + next.run(req).await +} + +async fn handler(_: State) {} + +let state = AppState {}; + +let app = Router::with_state(state.clone()) + .route("/", get(handler)) + .layer(middleware::from_fn(move |req, next| { + my_middleware(state.clone(), req, next) + })); +# let _: Router<_> = app; +``` + +## Accessing state in custom `tower::Layer`s + +```rust +use axum::{ + Router, + routing::get, + middleware::{self, Next}, + response::Response, + extract::State, + http::Request, +}; +use tower::{Layer, Service}; +use std::task::{Context, Poll}; + +#[derive(Clone)] +struct AppState {} + +#[derive(Clone)] +struct MyLayer { + state: AppState, +} + +impl Layer for MyLayer { + type Service = MyService; + + fn layer(&self, inner: S) -> Self::Service { + MyService { + inner, + state: self.state.clone(), + } + } +} + +#[derive(Clone)] +struct MyService { + inner: S, + state: AppState, +} + +impl Service> for MyService +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + // do something with `self.state` + + self.inner.call(req) + } +} + +async fn handler(_: State) {} + +let state = AppState {}; + +let app = Router::with_state(state.clone()) + .route("/", get(handler)) + .layer(MyLayer { state }); +# let _: Router<_> = app; +``` + +# Passing state from middleware to handlers + +State can be passed from middleware to handlers using [request extensions]: ```rust use axum::{ @@ -415,6 +526,8 @@ async fn auth(mut req: Request, next: Next) -> Result = app; ``` [Response extensions] can also be used but note that request extensions are not @@ -462,3 +575,4 @@ extensions you need. [`MethodRouter::route_layer`]: crate::routing::MethodRouter::route_layer [request extensions]: https://docs.rs/http/latest/http/request/struct.Request.html#method.extensions [Response extensions]: https://docs.rs/http/latest/http/response/struct.Response.html#method.extensions +[`State`]: crate::extract::State diff --git a/axum/src/docs/routing/fallback.md b/axum/src/docs/routing/fallback.md index a5c7467ed3..f2b5d3331e 100644 --- a/axum/src/docs/routing/fallback.md +++ b/axum/src/docs/routing/fallback.md @@ -1,4 +1,4 @@ -Add a fallback service to the router. +Add a fallback [`Handler`] to the router. This service will be called if no routes matches the incoming request. @@ -13,7 +13,7 @@ use axum::{ let app = Router::new() .route("/foo", get(|| async { /* ... */ })) - .fallback(fallback.into_service()); + .fallback(fallback); async fn fallback(uri: Uri) -> (StatusCode, String) { (StatusCode::NOT_FOUND, format!("No route for {}", uri)) diff --git a/axum/src/docs/routing/nest.md b/axum/src/docs/routing/nest.md index c1d14907ef..77ce306fa4 100644 --- a/axum/src/docs/routing/nest.md +++ b/axum/src/docs/routing/nest.md @@ -104,7 +104,7 @@ let api_routes = Router::new().nest("/users", get(|| async {})); let app = Router::new() .nest("/api", api_routes) - .fallback(fallback.into_service()); + .fallback(fallback); # let _: Router = app; ``` @@ -132,11 +132,11 @@ async fn api_fallback() -> (StatusCode, Json) { let api_routes = Router::new() .nest("/users", get(|| async {})) // add dedicated fallback for requests starting with `/api` - .fallback(api_fallback.into_service()); + .fallback(api_fallback); let app = Router::new() .nest("/api", api_routes) - .fallback(fallback.into_service()); + .fallback(fallback); # let _: Router = app; ``` diff --git a/axum/src/docs/routing/route.md b/axum/src/docs/routing/route.md index cd9d703e1b..fcc68cc46b 100644 --- a/axum/src/docs/routing/route.md +++ b/axum/src/docs/routing/route.md @@ -3,10 +3,10 @@ Add another route to the router. `path` is a string of path segments separated by `/`. Each segment can be either static, a capture, or a wildcard. -`service` is the [`Service`] that should receive the request if the path matches -`path`. `service` will commonly be a handler wrapped in a method router like -[`get`](crate::routing::get). See [`handler`](crate::handler) for more details -on handlers. +`method_router` is the [`MethodRouter`] that should receive the request if the +path matches `path`. `method_router` will commonly be a handler wrapped in a method +router like [`get`](crate::routing::get). See [`handler`](crate::handler) for +more details on handlers. # Static paths @@ -105,69 +105,6 @@ async fn serve_asset(Path(path): Path) {} # }; ``` -# Routing to any [`Service`] - -axum also supports routing to general [`Service`]s: - -```rust,no_run -use axum::{ - Router, - body::Body, - routing::{any_service, get_service}, - http::{Request, StatusCode}, - error_handling::HandleErrorLayer, -}; -use tower_http::services::ServeFile; -use http::Response; -use std::{convert::Infallible, io}; -use tower::service_fn; - -let app = Router::new() - .route( - // Any request to `/` goes to a service - "/", - // Services whose response body is not `axum::body::BoxBody` - // can be wrapped in `axum::routing::any_service` (or one of the other routing filters) - // to have the response body mapped - any_service(service_fn(|_: Request| async { - let res = Response::new(Body::from("Hi from `GET /`")); - Ok::<_, Infallible>(res) - })) - ) - .route( - "/foo", - // This service's response body is `axum::body::BoxBody` so - // it can be routed to directly. - service_fn(|req: Request| async move { - let body = Body::from(format!("Hi from `{} /foo`", req.method())); - let body = axum::body::boxed(body); - let res = Response::new(body); - Ok::<_, Infallible>(res) - }) - ) - .route( - // GET `/static/Cargo.toml` goes to a service from tower-http - "/static/Cargo.toml", - get_service(ServeFile::new("Cargo.toml")) - // though we must handle any potential errors - .handle_error(|error: io::Error| async move { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Unhandled internal error: {}", error), - ) - }) - ); -# async { -# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -# }; -``` - -Routing to arbitrary services in this way has complications for backpressure -([`Service::poll_ready`]). See the [Routing to services and backpressure] module -for more details. - -[Routing to services and backpressure]: middleware/index.html#routing-to-servicesmiddleware-and-backpressure - # Panics Panics if the route overlaps with another route: @@ -187,21 +124,3 @@ The static route `/foo` and the dynamic route `/:key` are not considered to overlap and `/foo` will take precedence. Also panics if `path` is empty. - -## Nesting - -`route` cannot be used to nest `Router`s. Instead use [`Router::nest`]. - -Attempting to will result in a panic: - -```rust,should_panic -use axum::{routing::get, Router}; - -let app = Router::new().route( - "/", - Router::new().route("/foo", get(|| async {})), -); -# async { -# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -# }; -``` diff --git a/axum/src/docs/routing/route_service.md b/axum/src/docs/routing/route_service.md new file mode 100644 index 0000000000..a14323a933 --- /dev/null +++ b/axum/src/docs/routing/route_service.md @@ -0,0 +1,81 @@ +Add another route to the router that calls a [`Service`]. + +# Example + +```rust,no_run +use axum::{ + Router, + body::Body, + routing::{any_service, get_service}, + http::{Request, StatusCode}, + error_handling::HandleErrorLayer, +}; +use tower_http::services::ServeFile; +use http::Response; +use std::{convert::Infallible, io}; +use tower::service_fn; + +let app = Router::new() + .route( + // Any request to `/` goes to a service + "/", + // Services whose response body is not `axum::body::BoxBody` + // can be wrapped in `axum::routing::any_service` (or one of the other routing filters) + // to have the response body mapped + any_service(service_fn(|_: Request| async { + let res = Response::new(Body::from("Hi from `GET /`")); + Ok::<_, Infallible>(res) + })) + ) + .route_service( + "/foo", + // This service's response body is `axum::body::BoxBody` so + // it can be routed to directly. + service_fn(|req: Request| async move { + let body = Body::from(format!("Hi from `{} /foo`", req.method())); + let body = axum::body::boxed(body); + let res = Response::new(body); + Ok::<_, Infallible>(res) + }) + ) + .route( + // GET `/static/Cargo.toml` goes to a service from tower-http + "/static/Cargo.toml", + get_service(ServeFile::new("Cargo.toml")) + // though we must handle any potential errors + .handle_error(|error: io::Error| async move { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Unhandled internal error: {}", error), + ) + }) + ); +# async { +# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +# }; +``` + +Routing to arbitrary services in this way has complications for backpressure +([`Service::poll_ready`]). See the [Routing to services and backpressure] module +for more details. + +# Panics + +Panics for the same reasons as [`Router::route`] or if you attempt to route to a +`Router`: + +```rust,should_panic +use axum::{routing::get, Router}; + +let app = Router::new().route_service( + "/", + Router::new().route("/foo", get(|| async {})), +); +# async { +# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +# }; +``` + +Use [`Router::nest`] instead. + +[Routing to services and backpressure]: middleware/index.html#routing-to-servicesmiddleware-and-backpressure diff --git a/axum/src/error_handling/mod.rs b/axum/src/error_handling/mod.rs index 76c3d02531..6a72d82069 100644 --- a/axum/src/error_handling/mod.rs +++ b/axum/src/error_handling/mod.rs @@ -1,7 +1,6 @@ #![doc = include_str!("../docs/error_handling.md")] use crate::{ - body::boxed, extract::{FromRequest, RequestParts}, http::{Request, StatusCode}, response::{IntoResponse, Response}, @@ -113,16 +112,16 @@ where } } -impl Service> for HandleError +impl Service> for HandleError where - S: Service> + Clone + Send + 'static, + S: Service> + Clone + Send + 'static, S::Response: IntoResponse + Send, S::Error: Send, S::Future: Send, F: FnOnce(S::Error) -> Fut + Clone + Send + 'static, Fut: Future + Send, Res: IntoResponse, - ReqBody: Send + 'static, + B: Send + 'static, { type Response = Response; type Error = Infallible; @@ -132,7 +131,7 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { let f = self.f.clone(); let clone = self.inner.clone(); @@ -152,18 +151,18 @@ where #[allow(unused_macros)] macro_rules! impl_service { ( $($ty:ident),* $(,)? ) => { - impl Service> + impl Service> for HandleError where - S: Service> + Clone + Send + 'static, + S: Service> + Clone + Send + 'static, S::Response: IntoResponse + Send, S::Error: Send, S::Future: Send, F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static, Fut: Future + Send, Res: IntoResponse, - $( $ty: FromRequest + Send,)* - ReqBody: Send + 'static, + $( $ty: FromRequest<(), B> + Send,)* + B: Send + 'static, { type Response = Response; type Error = Infallible; @@ -175,7 +174,7 @@ macro_rules! impl_service { } #[allow(non_snake_case)] - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { let f = self.f.clone(); let clone = self.inner.clone(); @@ -187,7 +186,7 @@ macro_rules! impl_service { $( let $ty = match $ty::from_request(&mut req).await { Ok(value) => value, - Err(rejection) => return Ok(rejection.into_response().map(boxed)), + Err(rejection) => return Ok(rejection.into_response()), }; )* @@ -200,7 +199,7 @@ macro_rules! impl_service { match inner.oneshot(req).await { Ok(res) => Ok(res.into_response()), - Err(err) => Ok(f($($ty),*, err).await.into_response().map(boxed)), + Err(err) => Ok(f($($ty),*, err).await.into_response()), } }); diff --git a/axum/src/extension.rs b/axum/src/extension.rs index 390f29e031..d040b9e4e5 100644 --- a/axum/src/extension.rs +++ b/axum/src/extension.rs @@ -73,14 +73,15 @@ use tower_service::Service; pub struct Extension(pub T); #[async_trait] -impl FromRequest for Extension +impl FromRequest for Extension where T: Clone + Send + Sync + 'static, B: Send, + S: Send, { type Rejection = ExtensionRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let value = req .extensions() .get::() diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index 8363a25ebf..3aa7684c81 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -128,14 +128,15 @@ opaque_future! { pub struct ConnectInfo(pub T); #[async_trait] -impl FromRequest for ConnectInfo +impl FromRequest for ConnectInfo where B: Send, + S: Send, T: Clone + Send + Sync + 'static, { - type Rejection = as FromRequest>::Rejection; + type Rejection = as FromRequest>::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let Extension(connect_info) = Extension::::from_request(req).await?; Ok(connect_info) } diff --git a/axum/src/extract/content_length_limit.rs b/axum/src/extract/content_length_limit.rs index 62148a8ea7..f4c475437f 100644 --- a/axum/src/extract/content_length_limit.rs +++ b/axum/src/extract/content_length_limit.rs @@ -36,15 +36,16 @@ use std::ops::Deref; pub struct ContentLengthLimit(pub T); #[async_trait] -impl FromRequest for ContentLengthLimit +impl FromRequest for ContentLengthLimit where - T: FromRequest, + T: FromRequest, T::Rejection: IntoResponse, B: Send, + S: Send, { type Rejection = ContentLengthLimitRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let content_length = req .headers() .get(http::header::CONTENT_LENGTH) diff --git a/axum/src/extract/host.rs b/axum/src/extract/host.rs index 85ab8a9dcd..79ae13fc28 100644 --- a/axum/src/extract/host.rs +++ b/axum/src/extract/host.rs @@ -21,13 +21,14 @@ const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; pub struct Host(pub String); #[async_trait] -impl FromRequest for Host +impl FromRequest for Host where B: Send, + S: Send, { type Rejection = HostRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if let Some(host) = parse_forwarded(req.headers()) { return Ok(Host(host.to_owned())); } diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index d0cb01afd6..6413cf7714 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -64,13 +64,14 @@ impl MatchedPath { } #[async_trait] -impl FromRequest for MatchedPath +impl FromRequest for MatchedPath where B: Send, + S: Send, { type Rejection = MatchedPathRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let matched_path = req .extensions() .get::() @@ -84,7 +85,9 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{extract::Extension, handler::Handler, routing::get, test_helpers::*, Router}; + use crate::{ + extract::Extension, handler::HandlerWithoutStateExt, routing::get, test_helpers::*, Router, + }; use http::{Request, StatusCode}; use std::task::{Context, Poll}; use tower::layer::layer_fn; @@ -93,7 +96,7 @@ mod tests { #[derive(Clone)] struct SetMatchedPathExtension(S); - impl Service> for SetMatchedPathExtension + impl Service> for SetMatchedPathExtension where S: Service>, { diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 6f0435f6e9..081793a83c 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -14,9 +14,10 @@ mod content_length_limit; mod host; mod raw_query; mod request_parts; +mod state; #[doc(inline)] -pub use axum_core::extract::{FromRequest, RequestParts}; +pub use axum_core::extract::{FromRef, FromRequest, RequestParts}; #[doc(inline)] #[allow(deprecated)] @@ -27,6 +28,7 @@ pub use self::{ path::Path, raw_query::RawQuery, request_parts::{BodyStream, RawBody}, + state::State, }; #[doc(no_inline)] @@ -73,13 +75,13 @@ pub use self::ws::WebSocketUpgrade; #[doc(no_inline)] pub use crate::TypedHeader; -pub(crate) fn take_body(req: &mut RequestParts) -> Result { +pub(crate) fn take_body(req: &mut RequestParts) -> Result { req.take_body().ok_or_else(BodyAlreadyExtracted::default) } // this is duplicated in `axum-extra/src/extract/form.rs` -pub(super) fn has_content_type( - req: &RequestParts, +pub(super) fn has_content_type( + req: &RequestParts, expected_content_type: &mime::Mime, ) -> bool { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index 8f58455fb9..076f4db106 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -50,14 +50,15 @@ pub struct Multipart { } #[async_trait] -impl FromRequest for Multipart +impl FromRequest for Multipart where B: HttpBody + Default + Unpin + Send + 'static, B::Error: Into, + S: Send, { type Rejection = MultipartRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let stream = BodyStream::from_request(req).await?; let headers = req.headers(); let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?; @@ -179,7 +180,7 @@ impl<'a> Field<'a> { /// } /// /// let app = Router::new().route("/upload", post(upload)); - /// # let _: Router = app; + /// # let _: Router = app; /// ``` pub async fn chunk(&mut self) -> Result, MultipartError> { self.inner diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index ab298143fb..ca9e9fb605 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -163,14 +163,15 @@ impl DerefMut for Path { } #[async_trait] -impl FromRequest for Path +impl FromRequest for Path where T: DeserializeOwned + Send, B: Send, + S: Send, { type Rejection = PathRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let params = match req.extensions_mut().get::() { Some(UrlParams::Params(params)) => params, Some(UrlParams::InvalidUtf8InPathParam { key }) => { diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index dc024366b2..ce1f747cc1 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -49,14 +49,15 @@ use std::ops::Deref; pub struct Query(pub T); #[async_trait] -impl FromRequest for Query +impl FromRequest for Query where T: DeserializeOwned, B: Send, + S: Send, { type Rejection = QueryRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let query = req.uri().query().unwrap_or_default(); let value = serde_urlencoded::from_str(query) .map_err(FailedToDeserializeQueryString::__private_new)?; @@ -81,7 +82,8 @@ mod tests { use std::fmt::Debug; async fn check(uri: impl AsRef, value: T) { - let mut req = RequestParts::new(Request::builder().uri(uri.as_ref()).body(()).unwrap()); + let req = Request::builder().uri(uri.as_ref()).body(()).unwrap(); + let mut req = RequestParts::new(req); assert_eq!(Query::::from_request(&mut req).await.unwrap().0, value); } diff --git a/axum/src/extract/raw_query.rs b/axum/src/extract/raw_query.rs index 463c31d88a..faf8df6e4c 100644 --- a/axum/src/extract/raw_query.rs +++ b/axum/src/extract/raw_query.rs @@ -27,13 +27,14 @@ use std::convert::Infallible; pub struct RawQuery(pub Option); #[async_trait] -impl FromRequest for RawQuery +impl FromRequest for RawQuery where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let query = req.uri().query().map(|query| query.to_owned()); Ok(Self(query)) } diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs index 02b044c50c..c04ff5349b 100644 --- a/axum/src/extract/request_parts.rs +++ b/axum/src/extract/request_parts.rs @@ -86,13 +86,14 @@ pub struct OriginalUri(pub Uri); #[cfg(feature = "original-uri")] #[async_trait] -impl FromRequest for OriginalUri +impl FromRequest for OriginalUri where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let uri = Extension::::from_request(req) .await .unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone()))) @@ -140,15 +141,16 @@ impl Stream for BodyStream { } #[async_trait] -impl FromRequest for BodyStream +impl FromRequest for BodyStream where B: HttpBody + Send + 'static, B::Data: Into, B::Error: Into, + S: Send, { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)? .map_data(Into::into) .map_err(|err| Error::new(err.into())); @@ -196,13 +198,14 @@ fn body_stream_traits() { pub struct RawBody(pub B); #[async_trait] -impl FromRequest for RawBody +impl FromRequest for RawBody where B: Send, + S: Send, { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; Ok(Self(body)) } diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs new file mode 100644 index 0000000000..94ccf5b1a6 --- /dev/null +++ b/axum/src/extract/state.rs @@ -0,0 +1,207 @@ +use async_trait::async_trait; +use axum_core::extract::{FromRef, FromRequest, RequestParts}; +use std::{ + convert::Infallible, + ops::{Deref, DerefMut}, +}; + +/// Extractor for state. +/// +/// Note this extractor is not available to middleware. See ["Accessing state in +/// middleware"][state-from-middleware] for how to access state in middleware. +/// +/// [state-from-middleware]: ../middleware/index.html#accessing-state-in-middleware +/// +/// # With `Router` +/// +/// ``` +/// use axum::{Router, routing::get, extract::State}; +/// +/// // the application state +/// // +/// // here you can put configuration, database connection pools, or whatever +/// // state you need +/// #[derive(Clone)] +/// struct AppState {} +/// +/// let state = AppState {}; +/// +/// // create a `Router` that holds our state +/// let app = Router::with_state(state).route("/", get(handler)); +/// +/// async fn handler( +/// // access the state via the `State` extractor +/// // extracting a state of the wrong type results in a compile error +/// State(state): State, +/// ) { +/// // use `state`... +/// } +/// # let _: Router = app; +/// ``` +/// +/// # With `MethodRouter` +/// +/// ``` +/// use axum::{routing::get, extract::State}; +/// +/// #[derive(Clone)] +/// struct AppState {} +/// +/// let state = AppState {}; +/// +/// let method_router_with_state = get(handler) +/// // provide the state so the handler can access it +/// .with_state(state); +/// +/// async fn handler(State(state): State) { +/// // use `state`... +/// } +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(method_router_with_state.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// # With `Handler` +/// +/// ``` +/// use axum::{routing::get, handler::Handler, extract::State}; +/// +/// #[derive(Clone)] +/// struct AppState {} +/// +/// let state = AppState {}; +/// +/// async fn handler(State(state): State) { +/// // use `state`... +/// } +/// +/// // provide the state so the handler can access it +/// let handler_with_state = handler.with_state(state); +/// +/// # async { +/// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) +/// .serve(handler_with_state.into_make_service()) +/// .await +/// .expect("server failed"); +/// # }; +/// ``` +/// +/// # Substates +/// +/// [`State`] only allows a single state type but you can use [`From`] to extract "substates": +/// +/// ``` +/// use axum::{Router, routing::get, extract::{State, FromRef}}; +/// +/// // the application state +/// #[derive(Clone)] +/// struct AppState { +/// // that holds some api specific state +/// api_state: ApiState, +/// } +/// +/// // the api specific state +/// #[derive(Clone)] +/// struct ApiState {} +/// +/// // support converting an `AppState` in an `ApiState` +/// impl FromRef for ApiState { +/// fn from_ref(app_state: &AppState) -> ApiState { +/// app_state.api_state.clone() +/// } +/// } +/// +/// let state = AppState { +/// api_state: ApiState {}, +/// }; +/// +/// let app = Router::with_state(state) +/// .route("/", get(handler)) +/// .route("/api/users", get(api_users)); +/// +/// async fn api_users( +/// // access the api specific state +/// State(api_state): State, +/// ) { +/// } +/// +/// async fn handler( +/// // we can still access to top level state +/// State(state): State, +/// ) { +/// } +/// # let _: Router = app; +/// ``` +/// +/// # For library authors +/// +/// If you're writing a library that has an extractor that needs state, this is the recommended way +/// to do it: +/// +/// ```rust +/// use axum_core::extract::{FromRequest, RequestParts, FromRef}; +/// use async_trait::async_trait; +/// use std::convert::Infallible; +/// +/// // the extractor your library provides +/// struct MyLibraryExtractor; +/// +/// #[async_trait] +/// impl FromRequest for MyLibraryExtractor +/// where +/// B: Send, +/// // keep `S` generic but require that it can produce a `MyLibraryState` +/// // this means users will have to implement `FromRef for MyLibraryState` +/// MyLibraryState: FromRef, +/// S: Send, +/// { +/// type Rejection = Infallible; +/// +/// async fn from_request(req: &mut RequestParts) -> Result { +/// // get a `MyLibraryState` from a reference to the state +/// let state = MyLibraryState::from_ref(req.state()); +/// +/// // ... +/// # todo!() +/// } +/// } +/// +/// // the state your library needs +/// struct MyLibraryState { +/// // ... +/// } +/// ``` +/// +/// Note that you don't need to use the `State` extractor since you can access the state directly +/// from [`RequestParts`]. +#[derive(Debug, Default, Clone, Copy)] +pub struct State(pub S); + +#[async_trait] +impl FromRequest for State +where + B: Send, + InnerState: FromRef, + OuterState: Send, +{ + type Rejection = Infallible; + + async fn from_request(req: &mut RequestParts) -> Result { + let inner_state = InnerState::from_ref(req.state()); + Ok(Self(inner_state)) + } +} + +impl Deref for State { + type Target = S; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for State { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 5643c819b1..952ea13636 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -275,13 +275,14 @@ impl WebSocketUpgrade { } #[async_trait] -impl FromRequest for WebSocketUpgrade +impl FromRequest for WebSocketUpgrade where B: Send, + S: Send, { type Rejection = WebSocketUpgradeRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if req.method() != Method::GET { return Err(MethodNotGet.into()); } @@ -320,7 +321,7 @@ where } } -fn header_eq(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { +fn header_eq(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { if let Some(header) = req.headers().get(&key) { header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) } else { @@ -328,7 +329,7 @@ fn header_eq(req: &RequestParts, key: HeaderName, value: &'static str) -> } } -fn header_contains(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { +fn header_contains(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { let header = if let Some(header) = req.headers().get(&key) { header } else { diff --git a/axum/src/form.rs b/axum/src/form.rs index d1a16f7b6f..8267b8efe5 100644 --- a/axum/src/form.rs +++ b/axum/src/form.rs @@ -56,16 +56,17 @@ use std::ops::Deref; pub struct Form(pub T); #[async_trait] -impl FromRequest for Form +impl FromRequest for Form where T: DeserializeOwned, B: HttpBody + Send, B::Data: Send, B::Error: Into, + S: Send, { type Rejection = FormRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if req.method() == Method::GET { let query = req.uri().query().unwrap_or_default(); let value = serde_urlencoded::from_str(query) @@ -125,29 +126,27 @@ mod tests { } async fn check_query(uri: impl AsRef, value: T) { - let mut req = RequestParts::new( - Request::builder() - .uri(uri.as_ref()) - .body(Empty::::new()) - .unwrap(), - ); + let req = Request::builder() + .uri(uri.as_ref()) + .body(Empty::::new()) + .unwrap(); + let mut req = RequestParts::new(req); assert_eq!(Form::::from_request(&mut req).await.unwrap().0, value); } async fn check_body(value: T) { - let mut req = RequestParts::new( - Request::builder() - .uri("http://example.com/test") - .method(Method::POST) - .header( - http::header::CONTENT_TYPE, - mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), - ) - .body(Full::::new( - serde_urlencoded::to_string(&value).unwrap().into(), - )) - .unwrap(), - ); + let req = Request::builder() + .uri("http://example.com/test") + .method(Method::POST) + .header( + http::header::CONTENT_TYPE, + mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), + ) + .body(Full::::new( + serde_urlencoded::to_string(&value).unwrap().into(), + )) + .unwrap(); + let mut req = RequestParts::new(req); assert_eq!(Form::::from_request(&mut req).await.unwrap().0, value); } @@ -204,21 +203,20 @@ mod tests { #[tokio::test] async fn test_incorrect_content_type() { - let mut req = RequestParts::new( - Request::builder() - .uri("http://example.com/test") - .method(Method::POST) - .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) - .body(Full::::new( - serde_urlencoded::to_string(&Pagination { - size: Some(10), - page: None, - }) - .unwrap() - .into(), - )) - .unwrap(), - ); + let req = Request::builder() + .uri("http://example.com/test") + .method(Method::POST) + .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref()) + .body(Full::::new( + serde_urlencoded::to_string(&Pagination { + size: Some(10), + page: None, + }) + .unwrap() + .into(), + )) + .unwrap(); + let mut req = RequestParts::new(req); assert!(matches!( Form::::from_request(&mut req) .await diff --git a/axum/src/handler/future.rs b/axum/src/handler/future.rs index b8b3e4e867..59487c31b2 100644 --- a/axum/src/handler/future.rs +++ b/axum/src/handler/future.rs @@ -19,29 +19,29 @@ opaque_future! { pin_project! { /// The response future for [`Layered`](super::Layered). - pub struct LayeredFuture + pub struct LayeredFuture where - S: Service>, + S: Service>, { #[pin] - inner: Map>, fn(Result) -> Response>, + inner: Map>, fn(Result) -> Response>, } } -impl LayeredFuture +impl LayeredFuture where - S: Service>, + S: Service>, { pub(super) fn new( - inner: Map>, fn(Result) -> Response>, + inner: Map>, fn(Result) -> Response>, ) -> Self { Self { inner } } } -impl Future for LayeredFuture +impl Future for LayeredFuture where - S: Service>, + S: Service>, { type Output = Response; diff --git a/axum/src/handler/into_service.rs b/axum/src/handler/into_service.rs index 34f36b2d21..c8e51a243c 100644 --- a/axum/src/handler/into_service.rs +++ b/axum/src/handler/into_service.rs @@ -11,29 +11,40 @@ use tower_service::Service; /// An adapter that makes a [`Handler`] into a [`Service`]. /// -/// Created with [`Handler::into_service`]. -pub struct IntoService { +/// Created with [`HandlerWithoutStateExt::into_service`]. +/// +/// [`HandlerWithoutStateExt::into_service`]: super::HandlerWithoutStateExt::into_service +pub struct IntoService { handler: H, + state: S, _marker: PhantomData (T, B)>, } +impl IntoService { + /// Get a reference to the state. + pub fn state(&self) -> &S { + &self.state + } +} + #[test] fn traits() { use crate::test_helpers::*; - assert_send::>(); - assert_sync::>(); + assert_send::>(); + assert_sync::>(); } -impl IntoService { - pub(super) fn new(handler: H) -> Self { +impl IntoService { + pub(super) fn new(handler: H, state: S) -> Self { Self { handler, + state, _marker: PhantomData, } } } -impl fmt::Debug for IntoService { +impl fmt::Debug for IntoService { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("IntoService") .field(&format_args!("...")) @@ -41,22 +52,25 @@ impl fmt::Debug for IntoService { } } -impl Clone for IntoService +impl Clone for IntoService where H: Clone, + S: Clone, { fn clone(&self) -> Self { Self { handler: self.handler.clone(), + state: self.state.clone(), _marker: PhantomData, } } } -impl Service> for IntoService +impl Service> for IntoService where - H: Handler + Clone + Send + 'static, + H: Handler + Clone + Send + 'static, B: Send + 'static, + S: Clone, { type Response = Response; type Error = Infallible; @@ -74,7 +88,7 @@ where use futures_util::future::FutureExt; let handler = self.handler.clone(); - let future = Handler::call(handler, req); + let future = Handler::call(handler, self.state.clone(), req); let future = future.map(Ok as _); super::future::IntoServiceFuture::new(future) diff --git a/axum/src/handler/into_service_state_in_extension.rs b/axum/src/handler/into_service_state_in_extension.rs new file mode 100644 index 0000000000..011161d93a --- /dev/null +++ b/axum/src/handler/into_service_state_in_extension.rs @@ -0,0 +1,85 @@ +use super::Handler; +use crate::response::Response; +use http::Request; +use std::{ + convert::Infallible, + fmt, + marker::PhantomData, + task::{Context, Poll}, +}; +use tower_service::Service; + +pub(crate) struct IntoServiceStateInExtension { + handler: H, + _marker: PhantomData (T, S, B)>, +} + +#[test] +fn traits() { + use crate::test_helpers::*; + assert_send::>(); + assert_sync::>(); +} + +impl IntoServiceStateInExtension { + pub(crate) fn new(handler: H) -> Self { + Self { + handler, + _marker: PhantomData, + } + } +} + +impl fmt::Debug for IntoServiceStateInExtension { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("IntoServiceStateInExtension") + .field(&format_args!("...")) + .finish() + } +} + +impl Clone for IntoServiceStateInExtension +where + H: Clone, +{ + fn clone(&self) -> Self { + Self { + handler: self.handler.clone(), + _marker: PhantomData, + } + } +} + +impl Service> for IntoServiceStateInExtension +where + H: Handler + Clone + Send + 'static, + B: Send + 'static, + S: Clone + Send + Sync + 'static, +{ + type Response = Response; + type Error = Infallible; + type Future = super::future::IntoServiceFuture; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + // `IntoServiceStateInExtension` can only be constructed from async functions which are always ready, or + // from `Layered` which bufferes in `::call` and is therefore + // also always ready. + Poll::Ready(Ok(())) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + use futures_util::future::FutureExt; + + let state = req + .extensions_mut() + .remove::() + .expect("state extension missing. This is a bug in axum, please file an issue"); + + let handler = self.handler.clone(); + let future = Handler::call(handler, state, req); + let future = future.map(Ok as _); + + super::future::IntoServiceFuture::new(future) + } +} diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 9002685352..d51423468a 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -49,8 +49,11 @@ use tower_service::Service; pub mod future; mod into_service; +mod into_service_state_in_extension; +mod with_state; -pub use self::into_service::IntoService; +pub(crate) use self::into_service_state_in_extension::IntoServiceStateInExtension; +pub use self::{into_service::IntoService, with_state::WithState}; /// Trait for async functions that can be used to handle requests. /// @@ -59,13 +62,45 @@ pub use self::into_service::IntoService; /// /// See the [module docs](crate::handler) for more details. /// +/// # Converting `Handler`s into [`Service`]s +/// +/// To convert `Handler`s into [`Service`]s you have to call either +/// [`HandlerWithoutStateExt::into_service`] or [`Handler::with_state`]: +/// +/// ``` +/// use tower::Service; +/// use axum::{ +/// extract::State, +/// body::Body, +/// http::Request, +/// handler::{HandlerWithoutStateExt, Handler}, +/// }; +/// +/// // this handler doesn't require any state +/// async fn one() {} +/// // so it can be converted to a service with `HandlerWithoutStateExt::into_service` +/// assert_service(one.into_service()); +/// +/// // this handler requires state +/// async fn two(_: State) {} +/// // so we have to provide it +/// let handler_with_state = two.with_state(String::new()); +/// // which gives us a `Service` +/// assert_service(handler_with_state); +/// +/// // helper to check that a value implements `Service` +/// fn assert_service(service: S) +/// where +/// S: Service>, +/// {} +/// ``` #[doc = include_str!("../docs/debugging_handler_type_errors.md")] -pub trait Handler: Clone + Send + Sized + 'static { +pub trait Handler: Clone + Send + Sized + 'static { /// The type of future calling this handler returns. type Future: Future + Send + 'static; /// Call the handler with the given request. - fn call(self, req: Request) -> Self::Future; + fn call(self, state: S, req: Request) -> Self::Future; /// Apply a [`tower::Layer`] to the handler. /// @@ -103,112 +138,26 @@ pub trait Handler: Clone + Send + Sized + 'static { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - fn layer(self, layer: L) -> Layered + fn layer(self, layer: L) -> Layered where - L: Layer>, + L: Layer>, { - Layered::new(layer.layer(self.into_service())) - } - - /// Convert the handler into a [`Service`]. - /// - /// This is commonly used together with [`Router::fallback`]: - /// - /// ```rust - /// use axum::{ - /// Server, - /// handler::Handler, - /// http::{Uri, Method, StatusCode}, - /// response::IntoResponse, - /// routing::{get, Router}, - /// }; - /// use tower::make::Shared; - /// use std::net::SocketAddr; - /// - /// async fn handler(method: Method, uri: Uri) -> (StatusCode, String) { - /// (StatusCode::NOT_FOUND, format!("Nothing to see at {} {}", method, uri)) - /// } - /// - /// let app = Router::new() - /// .route("/", get(|| async {})) - /// .fallback(handler.into_service()); - /// - /// # async { - /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) - /// .serve(app.into_make_service()) - /// .await?; - /// # Ok::<_, hyper::Error>(()) - /// # }; - /// ``` - /// - /// [`Router::fallback`]: crate::routing::Router::fallback - fn into_service(self) -> IntoService { - IntoService::new(self) - } - - /// Convert the handler into a [`MakeService`]. - /// - /// This allows you to serve a single handler if you don't need any routing: - /// - /// ```rust - /// use axum::{ - /// Server, handler::Handler, http::{Uri, Method}, response::IntoResponse, - /// }; - /// use std::net::SocketAddr; - /// - /// async fn handler(method: Method, uri: Uri, body: String) -> String { - /// format!("received `{} {}` with body `{:?}`", method, uri, body) - /// } - /// - /// # async { - /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) - /// .serve(handler.into_make_service()) - /// .await?; - /// # Ok::<_, hyper::Error>(()) - /// # }; - /// ``` - /// - /// [`MakeService`]: tower::make::MakeService - fn into_make_service(self) -> IntoMakeService> { - IntoMakeService::new(self.into_service()) + Layered { + layer, + handler: self, + _marker: PhantomData, + } } - /// Convert the handler into a [`MakeService`] which stores information - /// about the incoming connection. - /// - /// See [`Router::into_make_service_with_connect_info`] for more details. - /// - /// ```rust - /// use axum::{ - /// Server, - /// handler::Handler, - /// response::IntoResponse, - /// extract::ConnectInfo, - /// }; - /// use std::net::SocketAddr; - /// - /// async fn handler(ConnectInfo(addr): ConnectInfo) -> String { - /// format!("Hello {}", addr) - /// } - /// - /// # async { - /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) - /// .serve(handler.into_make_service_with_connect_info::()) - /// .await?; - /// # Ok::<_, hyper::Error>(()) - /// # }; - /// ``` - /// - /// [`MakeService`]: tower::make::MakeService - /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info - fn into_make_service_with_connect_info( - self, - ) -> IntoMakeServiceWithConnectInfo, C> { - IntoMakeServiceWithConnectInfo::new(self.into_service()) + /// Convert the handler into a [`Service`] by providing the state + fn with_state(self, state: S) -> WithState { + WithState { + service: IntoService::new(self, state), + } } } -impl Handler<(), B> for F +impl Handler<(), S, B> for F where F: FnOnce() -> Fut + Clone + Send + 'static, Fut: Future + Send, @@ -217,7 +166,7 @@ where { type Future = Pin + Send>>; - fn call(self, _req: Request) -> Self::Future { + fn call(self, _state: S, _req: Request) -> Self::Future { Box::pin(async move { self().await.into_response() }) } } @@ -225,19 +174,20 @@ where macro_rules! impl_handler { ( $($ty:ident),* $(,)? ) => { #[allow(non_snake_case)] - impl Handler<($($ty,)*), B> for F + impl Handler<($($ty,)*), S, B> for F where F: FnOnce($($ty,)*) -> Fut + Clone + Send + 'static, Fut: Future + Send, B: Send + 'static, + S: Send + 'static, Res: IntoResponse, - $( $ty: FromRequest + Send,)* + $( $ty: FromRequest + Send,)* { type Future = Pin + Send>>; - fn call(self, req: Request) -> Self::Future { + fn call(self, state: S, req: Request) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::new(req); + let mut req = RequestParts::with_state(state, req); $( let $ty = match $ty::from_request(&mut req).await { @@ -260,58 +210,116 @@ all_the_tuples!(impl_handler); /// A [`Service`] created from a [`Handler`] by applying a Tower middleware. /// /// Created with [`Handler::layer`]. See that method for more details. -pub struct Layered { - svc: S, - _input: PhantomData T>, +pub struct Layered { + layer: L, + handler: H, + _marker: PhantomData (T, S, B)>, } -impl fmt::Debug for Layered +impl fmt::Debug for Layered where - S: fmt::Debug, + L: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Layered").field("svc", &self.svc).finish() + f.debug_struct("Layered") + .field("layer", &self.layer) + .finish() } } -impl Clone for Layered +impl Clone for Layered where - S: Clone, + L: Clone, + H: Clone, { fn clone(&self) -> Self { - Self::new(self.svc.clone()) + Self { + layer: self.layer.clone(), + handler: self.handler.clone(), + _marker: PhantomData, + } } } -impl Handler for Layered +impl Handler for Layered where - S: Service, Error = Infallible> + Clone + Send + 'static, - S::Response: IntoResponse, - S::Future: Send, + L: Layer> + Clone + Send + 'static, + H: Handler, + L::Service: Service, Error = Infallible> + Clone + Send + 'static, + >>::Response: IntoResponse, + >>::Future: Send, T: 'static, - ReqBody: Send + 'static, + S: 'static, + B: Send + 'static, { - type Future = future::LayeredFuture; + type Future = future::LayeredFuture; - fn call(self, req: Request) -> Self::Future { + fn call(self, state: S, req: Request) -> Self::Future { use futures_util::future::{FutureExt, Map}; - let future: Map<_, fn(Result) -> _> = - self.svc.oneshot(req).map(|result| match result { - Ok(res) => res.into_response(), - Err(err) => match err {}, - }); + let svc = self.handler.with_state(state); + let svc = self.layer.layer(svc); + + let future: Map< + _, + fn( + Result< + >>::Response, + >>::Error, + >, + ) -> _, + > = svc.oneshot(req).map(|result| match result { + Ok(res) => res.into_response(), + Err(err) => match err {}, + }); future::LayeredFuture::new(future) } } -impl Layered { - pub(crate) fn new(svc: S) -> Self { - Self { - svc, - _input: PhantomData, - } +/// Extension trait for [`Handler`]s that don't have state. +/// +/// This provides convenience methods to convert the [`Handler`] into a [`Service`] or [`MakeService`]. +/// +/// [`MakeService`]: tower::make::MakeService +pub trait HandlerWithoutStateExt: Handler { + /// Convert the handler into a [`Service`] and no state. + fn into_service(self) -> WithState; + + /// Convert the handler into a [`MakeService`] and no state. + /// + /// See [`WithState::into_make_service`] for more details. + /// + /// [`MakeService`]: tower::make::MakeService + fn into_make_service(self) -> IntoMakeService>; + + /// Convert the handler into a [`MakeService`] which stores information + /// about the incoming connection and has no state. + /// + /// See [`WithState::into_make_service_with_connect_info`] for more details. + /// + /// [`MakeService`]: tower::make::MakeService + fn into_make_service_with_connect_info( + self, + ) -> IntoMakeServiceWithConnectInfo, C>; +} + +impl HandlerWithoutStateExt for H +where + H: Handler, +{ + fn into_service(self) -> WithState { + self.with_state(()) + } + + fn into_make_service(self) -> IntoMakeService> { + self.with_state(()).into_make_service() + } + + fn into_make_service_with_connect_info( + self, + ) -> IntoMakeServiceWithConnectInfo, C> { + self.with_state(()).into_make_service_with_connect_info() } } diff --git a/axum/src/handler/with_state.rs b/axum/src/handler/with_state.rs new file mode 100644 index 0000000000..4afc9b106a --- /dev/null +++ b/axum/src/handler/with_state.rs @@ -0,0 +1,144 @@ +use super::{Handler, IntoService}; +use crate::{extract::connect_info::IntoMakeServiceWithConnectInfo, routing::IntoMakeService}; +use http::Request; +use std::task::{Context, Poll}; +use tower_service::Service; + +/// A [`Handler`] which has access to some state. +/// +/// Implements [`Service`]. +/// +/// The state can be extracted with [`State`](crate::extract::State). +/// +/// Created with [`Handler::with_state`]. +pub struct WithState { + pub(super) service: IntoService, +} + +impl WithState { + /// Get a reference to the state. + pub fn state(&self) -> &S { + self.service.state() + } +} + +impl WithState { + /// Convert the handler into a [`MakeService`]. + /// + /// This allows you to serve a single handler if you don't need any routing: + /// + /// ```rust + /// use axum::{ + /// Server, + /// handler::Handler, + /// extract::State, + /// http::{Uri, Method}, + /// response::IntoResponse, + /// }; + /// use std::net::SocketAddr; + /// + /// #[derive(Clone)] + /// struct AppState {} + /// + /// async fn handler(State(state): State) { + /// // ... + /// } + /// + /// let app = handler.with_state(AppState {}); + /// + /// # async { + /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) + /// .serve(app.into_make_service()) + /// .await?; + /// # Ok::<_, hyper::Error>(()) + /// # }; + /// ``` + /// + /// [`MakeService`]: tower::make::MakeService + pub fn into_make_service(self) -> IntoMakeService> { + IntoMakeService::new(self.service) + } + + /// Convert the handler into a [`MakeService`] which stores information + /// about the incoming connection. + /// + /// See [`Router::into_make_service_with_connect_info`] for more details. + /// + /// ```rust + /// use axum::{ + /// Server, + /// handler::Handler, + /// response::IntoResponse, + /// extract::{ConnectInfo, State}, + /// }; + /// use std::net::SocketAddr; + /// + /// #[derive(Clone)] + /// struct AppState {}; + /// + /// async fn handler( + /// ConnectInfo(addr): ConnectInfo, + /// State(state): State, + /// ) -> String { + /// format!("Hello {}", addr) + /// } + /// + /// let app = handler.with_state(AppState {}); + /// + /// # async { + /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) + /// .serve(app.into_make_service_with_connect_info::()) + /// .await?; + /// # Ok::<_, hyper::Error>(()) + /// # }; + /// ``` + /// + /// [`MakeService`]: tower::make::MakeService + /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info + pub fn into_make_service_with_connect_info( + self, + ) -> IntoMakeServiceWithConnectInfo, C> { + IntoMakeServiceWithConnectInfo::new(self.service) + } +} + +impl Service> for WithState +where + H: Handler + Clone + Send + 'static, + B: Send + 'static, + S: Clone, +{ + type Response = as Service>>::Response; + type Error = as Service>>::Error; + type Future = as Service>>::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + #[inline] + fn call(&mut self, req: Request) -> Self::Future { + self.service.call(req) + } +} + +impl std::fmt::Debug for WithState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WithState") + .field("service", &self.service) + .finish() + } +} + +impl Clone for WithState +where + H: Clone, + S: Clone, +{ + fn clone(&self) -> Self { + Self { + service: self.service.clone(), + } + } +} diff --git a/axum/src/json.rs b/axum/src/json.rs index 3fd83a4bbe..e35a1623ca 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -94,16 +94,17 @@ use std::ops::{Deref, DerefMut}; pub struct Json(pub T); #[async_trait] -impl FromRequest for Json +impl FromRequest for Json where T: DeserializeOwned, B: HttpBody + Send, B::Data: Send, B::Error: Into, + S: Send, { type Rejection = JsonRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if json_content_type(req) { let bytes = Bytes::from_request(req).await?; @@ -136,7 +137,7 @@ where } } -fn json_content_type(req: &RequestParts) -> bool { +fn json_content_type(req: &RequestParts) -> bool { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { content_type } else { diff --git a/axum/src/lib.rs b/axum/src/lib.rs index b355d6dc9f..a9b1bfb144 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -168,13 +168,48 @@ //! pool of database connections or clients to other services. //! //! The two most common ways of doing that are: +//! - Using the [`State`] extractor. //! - Using request extensions //! - Using closure captures //! +//! ## Using the [`State`] extractor +//! +//! ```rust,no_run +//! use axum::{ +//! extract::State, +//! routing::get, +//! Router, +//! }; +//! use std::sync::Arc; +//! +//! struct AppState { +//! // ... +//! } +//! +//! let shared_state = Arc::new(AppState { /* ... */ }); +//! +//! let app = Router::with_state(shared_state) +//! .route("/", get(handler)); +//! +//! async fn handler( +//! State(state): State>, +//! ) { +//! // ... +//! } +//! # async { +//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` +//! +//! You should prefer using [`State`] if possible since it's more type safe. The downside is that +//! its less dynamic than request extensions. +//! +//! See [`State`] for more details about accessing state. +//! //! ## Using request extensions //! -//! The easiest way to extract state in handlers is using [`Extension`](crate::extract::Extension) -//! as layer and extractor: +//! Another way to extract state in handlers is using [`Extension`](crate::extract::Extension) as +//! layer and extractor: //! //! ```rust,no_run //! use axum::{ @@ -184,18 +219,18 @@ //! }; //! use std::sync::Arc; //! -//! struct State { +//! struct AppState { //! // ... //! } //! -//! let shared_state = Arc::new(State { /* ... */ }); +//! let shared_state = Arc::new(AppState { /* ... */ }); //! //! let app = Router::new() //! .route("/", get(handler)) //! .layer(Extension(shared_state)); //! //! async fn handler( -//! Extension(state): Extension>, +//! Extension(state): Extension>, //! ) { //! // ... //! } @@ -223,11 +258,11 @@ //! use std::sync::Arc; //! use serde::Deserialize; //! -//! struct State { +//! struct AppState { //! // ... //! } //! -//! let shared_state = Arc::new(State { /* ... */ }); +//! let shared_state = Arc::new(AppState { /* ... */ }); //! //! let app = Router::new() //! .route( @@ -245,11 +280,11 @@ //! }), //! ); //! -//! async fn get_user(Path(user_id): Path, state: Arc) { +//! async fn get_user(Path(user_id): Path, state: Arc) { //! // ... //! } //! -//! async fn create_user(Json(payload): Json, state: Arc) { +//! async fn create_user(Json(payload): Json, state: Arc) { //! // ... //! } //! @@ -263,7 +298,7 @@ //! ``` //! //! The downside to this approach is that it's a little more verbose than using -//! extensions. +//! [`State`] or extensions. //! //! # Building integrations for axum //! @@ -350,6 +385,7 @@ //! [`Infallible`]: std::convert::Infallible //! [load shed]: tower::load_shed //! [`axum-core`]: http://crates.io/crates/axum-core +//! [`State`]: crate::extract::State #![warn( clippy::all, diff --git a/axum/src/middleware/from_extractor.rs b/axum/src/middleware/from_extractor.rs index 8399d8c52b..dfa3dfec82 100644 --- a/axum/src/middleware/from_extractor.rs +++ b/axum/src/middleware/from_extractor.rs @@ -45,13 +45,14 @@ use tower_service::Service; /// struct RequireAuth; /// /// #[async_trait] -/// impl FromRequest for RequireAuth +/// impl FromRequest for RequireAuth /// where /// B: Send, +/// S: Send, /// { /// type Rejection = StatusCode; /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: &mut RequestParts) -> Result { /// let auth_header = req /// .headers() /// .get(header::AUTHORIZATION) @@ -166,23 +167,23 @@ where } } -impl Service> for FromExtractor +impl Service> for FromExtractor where - E: FromRequest + 'static, - ReqBody: Default + Send + 'static, - S: Service> + Clone, + E: FromRequest<(), B> + 'static, + B: Default + Send + 'static, + S: Service> + Clone, S::Response: IntoResponse, { type Response = Response; type Error = S::Error; - type Future = ResponseFuture; + type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { let extract_future = Box::pin(async move { let mut req = RequestParts::new(req); let extracted = E::from_request(&mut req).await; @@ -201,35 +202,37 @@ where pin_project! { /// Response future for [`FromExtractor`]. #[allow(missing_debug_implementations)] - pub struct ResponseFuture + pub struct ResponseFuture where - E: FromRequest, - S: Service>, + E: FromRequest<(), B>, + S: Service>, { #[pin] - state: State, + state: State, svc: Option, } } pin_project! { #[project = StateProj] - enum State + enum State where - E: FromRequest, - S: Service>, + E: FromRequest<(), B>, + S: Service>, { - Extracting { future: BoxFuture<'static, (RequestParts, Result)> }, + Extracting { + future: BoxFuture<'static, (RequestParts<(), B>, Result)>, + }, Call { #[pin] future: S::Future }, } } -impl Future for ResponseFuture +impl Future for ResponseFuture where - E: FromRequest, - S: Service>, + E: FromRequest<(), B>, + S: Service>, S::Response: IntoResponse, - ReqBody: Default, + B: Default, { type Output = Result; @@ -277,13 +280,14 @@ mod tests { struct RequireAuth; #[async_trait::async_trait] - impl FromRequest for RequireAuth + impl FromRequest for RequireAuth where B: Send, + S: Send, { type Rejection = StatusCode; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if let Some(auth) = req .headers() .get(header::AUTHORIZATION) diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index 5322adfb18..0d37c61863 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -251,19 +251,19 @@ where macro_rules! impl_service { ( $($ty:ident),* $(,)? ) => { #[allow(non_snake_case)] - impl Service> for FromFn + impl Service> for FromFn where - F: FnMut($($ty),*, Next) -> Fut + Clone + Send + 'static, - $( $ty: FromRequest + Send, )* + F: FnMut($($ty),*, Next) -> Fut + Clone + Send + 'static, + $( $ty: FromRequest<(), B> + Send, )* Fut: Future + Send + 'static, Out: IntoResponse + 'static, - S: Service, Error = Infallible> + S: Service, Error = Infallible> + Clone + Send + 'static, S::Response: IntoResponse, S::Future: Send + 'static, - ReqBody: Send + 'static, + B: Send + 'static, { type Response = Response; type Error = Infallible; @@ -273,7 +273,7 @@ macro_rules! impl_service { self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { let not_ready_inner = self.inner.clone(); let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); @@ -320,13 +320,13 @@ where } /// The remainder of a middleware stack, including the handler. -pub struct Next { - inner: BoxCloneService, Response, Infallible>, +pub struct Next { + inner: BoxCloneService, Response, Infallible>, } -impl Next { +impl Next { /// Execute the remaining middleware stack. - pub async fn run(mut self, req: Request) -> Response { + pub async fn run(mut self, req: Request) -> Response { match self.inner.call(req).await { Ok(res) => res, Err(err) => match err {}, @@ -334,7 +334,7 @@ impl Next { } } -impl fmt::Debug for Next { +impl fmt::Debug for Next { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromFnLayer") .field("inner", &self.inner) diff --git a/axum/src/response/mod.rs b/axum/src/response/mod.rs index 695b931c22..5cf19c04ee 100644 --- a/axum/src/response/mod.rs +++ b/axum/src/response/mod.rs @@ -93,7 +93,7 @@ mod tests { } } - Router::::new() + Router::<_, Body>::new() .route("/", get(impl_trait_ok)) .route("/", get(impl_trait_err)) .route("/", get(impl_trait_both)) @@ -203,7 +203,7 @@ mod tests { ) } - Router::::new() + Router::<_, Body>::new() .route("/", get(status)) .route("/", get(status_headermap)) .route("/", get(status_header_array)) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 0aca7daffa..a71dc2fd47 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1,9 +1,11 @@ +//! Route to services and handlers based on HTTP methods. + use super::IntoMakeService; use crate::{ - body::{boxed, Body, Bytes, Empty, HttpBody}, + body::{Body, Bytes, HttpBody}, error_handling::{HandleError, HandleErrorLayer}, extract::connect_info::IntoMakeServiceWithConnectInfo, - handler::Handler, + handler::{Handler, IntoServiceStateInExtension}, http::{Method, Request, StatusCode}, response::Response, routing::{future::RouteFuture, Fallback, MethodFilter, Route}, @@ -13,6 +15,7 @@ use bytes::BytesMut; use std::{ convert::Infallible, fmt, + marker::PhantomData, task::{Context, Poll}, }; use tower::{service_fn, util::MapResponseLayer}; @@ -74,11 +77,12 @@ macro_rules! top_level_service_fn { $name:ident, $method:ident ) => { $(#[$m])+ - pub fn $name(svc: S) -> MethodRouter + pub fn $name(svc: T) -> MethodRouter where - S: Service> + Clone + Send + 'static, - S::Response: IntoResponse + 'static, - S::Future: Send + 'static, + T: Service> + Clone + Send + 'static, + T::Response: IntoResponse + 'static, + T::Future: Send + 'static, + B: Send + 'static, { on_service(MethodFilter::$method, svc) } @@ -134,11 +138,12 @@ macro_rules! top_level_handler_fn { $name:ident, $method:ident ) => { $(#[$m])+ - pub fn $name(handler: H) -> MethodRouter + pub fn $name(handler: H) -> MethodRouter where - H: Handler, + H: Handler, B: Send + 'static, T: 'static, + S: Clone + Send + Sync + 'static, { on(MethodFilter::$method, handler) } @@ -206,14 +211,14 @@ macro_rules! chained_service_fn { ) => { $(#[$m])+ #[track_caller] - pub fn $name(self, svc: S) -> Self + pub fn $name(self, svc: T) -> Self where - S: Service, Error = E> + T: Service, Error = E> + Clone + Send + 'static, - S::Response: IntoResponse + 'static, - S::Future: Send + 'static, + T::Response: IntoResponse + 'static, + T::Future: Send + 'static, { self.on_service(MethodFilter::$method, svc) } @@ -272,8 +277,9 @@ macro_rules! chained_handler_fn { #[track_caller] pub fn $name(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, + S: Clone + Send + Sync + 'static, { self.on(MethodFilter::$method, handler) } @@ -314,11 +320,12 @@ top_level_service_fn!(trace_service, TRACE); /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn on_service(filter: MethodFilter, svc: S) -> MethodRouter +pub fn on_service(filter: MethodFilter, svc: T) -> MethodRouter where - S: Service> + Clone + Send + 'static, - S::Response: IntoResponse + 'static, - S::Future: Send + 'static, + T: Service> + Clone + Send + 'static, + T::Response: IntoResponse + 'static, + T::Future: Send + 'static, + B: Send + 'static, { MethodRouter::new().on_service(filter, svc) } @@ -376,13 +383,16 @@ where /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn any_service(svc: S) -> MethodRouter +pub fn any_service(svc: T) -> MethodRouter where - S: Service> + Clone + Send + 'static, - S::Response: IntoResponse + 'static, - S::Future: Send + 'static, + T: Service> + Clone + Send + 'static, + T::Response: IntoResponse + 'static, + T::Future: Send + 'static, + B: Send + 'static, { - MethodRouter::new().fallback(svc).skip_allow_header() + MethodRouter::new() + .fallback_service(svc) + .skip_allow_header() } top_level_handler_fn!(delete, DELETE); @@ -413,11 +423,12 @@ top_level_handler_fn!(trace, TRACE); /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn on(filter: MethodFilter, handler: H) -> MethodRouter +pub fn on(filter: MethodFilter, handler: H) -> MethodRouter where - H: Handler, + H: Handler, B: Send + 'static, T: 'static, + S: Clone + Send + Sync + 'static, { MethodRouter::new().on(filter, handler) } @@ -459,20 +470,48 @@ where /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn any(handler: H) -> MethodRouter +pub fn any(handler: H) -> MethodRouter where - H: Handler, + H: Handler, B: Send + 'static, T: 'static, + S: Clone + Send + Sync + 'static, { MethodRouter::new() - .fallback_boxed_response_body(handler.into_service()) + .fallback_boxed_response_body(IntoServiceStateInExtension::new(handler)) .skip_allow_header() } /// A [`Service`] that accepts requests based on a [`MethodFilter`] and /// allows chaining additional handlers and services. -pub struct MethodRouter { +/// +/// # When does `MethodRouter` implement [`Service`]? +/// +/// Whether or not `MethodRouter` implements [`Service`] depends on the state type it requires. +/// +/// ``` +/// use tower::Service; +/// use axum::{routing::get, extract::State, body::Body, http::Request}; +/// +/// // this `MethodRouter` doesn't require any state, i.e. the state is `()`, +/// let method_router = get(|| async {}); +/// // and thus it implements `Service` +/// assert_service(method_router); +/// +/// // this requires a `String` and doesn't implement `Service` +/// let method_router = get(|_: State| async {}); +/// // until you provide the `String` with `.with_state(...)` +/// let method_router_with_state = method_router.with_state(String::new()); +/// // and then it implements `Service` +/// assert_service(method_router_with_state); +/// +/// // helper to check that a value implements `Service` +/// fn assert_service(service: S) +/// where +/// S: Service>, +/// {} +/// ``` +pub struct MethodRouter { get: Option>, head: Option>, delete: Option>, @@ -483,6 +522,7 @@ pub struct MethodRouter { trace: Option>, fallback: Fallback, allow_header: AllowHeader, + _marker: PhantomData S>, } #[derive(Clone)] @@ -511,7 +551,7 @@ impl AllowHeader { } } -impl fmt::Debug for MethodRouter { +impl fmt::Debug for MethodRouter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MethodRouter") .field("get", &self.get) @@ -527,32 +567,7 @@ impl fmt::Debug for MethodRouter { } } -impl MethodRouter { - /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all - /// requests. - pub fn new() -> Self { - let fallback = Route::new(service_fn(|_: Request| async { - let mut response = Response::new(boxed(Empty::new())); - *response.status_mut() = StatusCode::METHOD_NOT_ALLOWED; - Ok(response) - })); - - Self { - get: None, - head: None, - delete: None, - options: None, - patch: None, - post: None, - put: None, - trace: None, - allow_header: AllowHeader::None, - fallback: Fallback::Default(fallback), - } - } -} - -impl MethodRouter +impl MethodRouter where B: Send + 'static, { @@ -582,10 +597,11 @@ where #[track_caller] pub fn on(self, filter: MethodFilter, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, + S: Clone + Send + Sync + 'static, { - self.on_service_boxed_response_body(filter, handler.into_service()) + self.on_service_boxed_response_body(filter, IntoServiceStateInExtension::new(handler)) } chained_handler_fn!(delete, DELETE); @@ -597,6 +613,21 @@ where chained_handler_fn!(put, PUT); chained_handler_fn!(trace, TRACE); + /// Add a fallback [`Handler`] to the router. + pub fn fallback(self, handler: H) -> Self + where + H: Handler, + T: 'static, + S: Clone + Send + Sync + 'static, + { + self.fallback_service(IntoServiceStateInExtension::new(handler)) + } +} + +impl MethodRouter<(), B, Infallible> +where + B: Send + 'static, +{ /// Convert the handler into a [`MakeService`]. /// /// This allows you to serve a single handler if you don't need any routing: @@ -666,7 +697,58 @@ where } } -impl MethodRouter { +impl MethodRouter +where + B: Send + 'static, +{ + /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all + /// requests. + pub fn new() -> Self { + let fallback = Route::new(service_fn(|_: Request| async { + Ok(StatusCode::METHOD_NOT_ALLOWED.into_response()) + })); + + Self { + get: None, + head: None, + delete: None, + options: None, + patch: None, + post: None, + put: None, + trace: None, + allow_header: AllowHeader::None, + fallback: Fallback::Default(fallback), + _marker: PhantomData, + } + } + + /// Provide the state. + /// + /// See [`State`](crate::extract::State) for more details about accessing state. + pub fn with_state(self, state: S) -> WithState { + WithState { + method_router: self, + state, + } + } + + pub(crate) fn downcast_state(self) -> MethodRouter { + MethodRouter { + get: self.get, + head: self.head, + delete: self.delete, + options: self.options, + patch: self.patch, + post: self.post, + put: self.put, + trace: self.trace, + fallback: self.fallback, + allow_header: self.allow_header, + _marker: PhantomData, + } + } + /// Chain an additional service that will accept requests matching the given /// `MethodFilter`. /// @@ -693,11 +775,11 @@ impl MethodRouter { /// # }; /// ``` #[track_caller] - pub fn on_service(self, filter: MethodFilter, svc: S) -> Self + pub fn on_service(self, filter: MethodFilter, svc: T) -> Self where - S: Service, Error = E> + Clone + Send + 'static, - S::Response: IntoResponse + 'static, - S::Future: Send + 'static, + T: Service, Error = E> + Clone + Send + 'static, + T::Response: IntoResponse + 'static, + T::Future: Send + 'static, { self.on_service_boxed_response_body(filter, svc) } @@ -712,30 +794,30 @@ impl MethodRouter { chained_service_fn!(trace_service, TRACE); #[doc = include_str!("../docs/method_routing/fallback.md")] - pub fn fallback(mut self, svc: S) -> Self + pub fn fallback_service(mut self, svc: T) -> Self where - S: Service, Error = E> + Clone + Send + 'static, - S::Response: IntoResponse + 'static, - S::Future: Send + 'static, + T: Service, Error = E> + Clone + Send + 'static, + T::Response: IntoResponse + 'static, + T::Future: Send + 'static, { self.fallback = Fallback::Custom(Route::new(svc)); self } - fn fallback_boxed_response_body(mut self, svc: S) -> Self + fn fallback_boxed_response_body(mut self, svc: T) -> Self where - S: Service, Error = E> + Clone + Send + 'static, - S::Response: IntoResponse + 'static, - S::Future: Send + 'static, + T: Service, Error = E> + Clone + Send + 'static, + T::Response: IntoResponse + 'static, + T::Future: Send + 'static, { self.fallback = Fallback::Custom(Route::new(svc)); self } #[doc = include_str!("../docs/method_routing/layer.md")] - pub fn layer(self, layer: L) -> MethodRouter + pub fn layer(self, layer: L) -> MethodRouter where - L: Layer>, + L: Layer>, L::Service: Service, Error = NewError> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, >>::Future: Send + 'static, @@ -757,16 +839,17 @@ impl MethodRouter { trace: self.trace.map(layer_fn), fallback: self.fallback.map(layer_fn), allow_header: self.allow_header, + _marker: self._marker, } } #[doc = include_str!("../docs/method_routing/route_layer.md")] - pub fn route_layer(mut self, layer: L) -> MethodRouter + pub fn route_layer(mut self, layer: L) -> MethodRouter where - L: Layer>, - L::Service: Service, Error = E> + Clone + Send + 'static, - >>::Response: IntoResponse + 'static, - >>::Future: Send + 'static, + L: Layer>, + L::Service: Service, Error = E> + Clone + Send + 'static, + >>::Response: IntoResponse + 'static, + >>::Future: Send + 'static, { let layer_fn = |svc| { let svc = layer.layer(svc); @@ -788,7 +871,7 @@ impl MethodRouter { #[doc = include_str!("../docs/method_routing/merge.md")] #[track_caller] - pub fn merge(mut self, other: MethodRouter) -> Self { + pub fn merge(mut self, other: MethodRouter) -> Self { // written using inner functions to generate less IR #[track_caller] fn merge_inner(name: &str, first: Option, second: Option) -> Option { @@ -836,26 +919,25 @@ impl MethodRouter { /// Apply a [`HandleErrorLayer`]. /// /// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`. - pub fn handle_error(self, f: F) -> MethodRouter + pub fn handle_error(self, f: F) -> MethodRouter where F: Clone + Send + 'static, - HandleError, F, T>: Service, Error = Infallible>, - , F, T> as Service>>::Future: Send, - , F, T> as Service>>::Response: - IntoResponse + Send, + HandleError, F, T>: Service, Error = Infallible>, + , F, T> as Service>>::Future: Send, + , F, T> as Service>>::Response: IntoResponse + Send, T: 'static, E: 'static, - ReqBody: 'static, + B: 'static, { self.layer(HandleErrorLayer::new(f)) } #[track_caller] - fn on_service_boxed_response_body(mut self, filter: MethodFilter, svc: S) -> Self + fn on_service_boxed_response_body(mut self, filter: MethodFilter, svc: T) -> Self where - S: Service, Error = E> + Clone + Send + 'static, - S::Response: IntoResponse + 'static, - S::Future: Send + 'static, + T: Service, Error = E> + Clone + Send + 'static, + T::Response: IntoResponse + 'static, + T::Future: Send + 'static, { // written using an inner function to generate less IR fn set_service( @@ -991,7 +1073,25 @@ fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) { } } -impl Clone for MethodRouter { +impl Service> for MethodRouter<(), B, E> +where + B: HttpBody + Send + 'static, +{ + type Response = Response; + type Error = E; + type Future = RouteFuture; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + self.clone().with_state(()).call(req) + } +} + +impl Clone for MethodRouter { fn clone(&self) -> Self { Self { get: self.get.clone(), @@ -1004,11 +1104,12 @@ impl Clone for MethodRouter { trace: self.trace.clone(), fallback: self.fallback.clone(), allow_header: self.allow_header.clone(), + _marker: self._marker, } } } -impl Default for MethodRouter +impl Default for MethodRouter where B: Send + 'static, { @@ -1017,9 +1118,72 @@ where } } -impl Service> for MethodRouter +/// A [`MethodRouter`] which has access to some state. +/// +/// Implements [`Service`]. +/// +/// The state can be extracted with [`State`](crate::extract::State). +/// +/// Created with [`MethodRouter::with_state`] +pub struct WithState { + method_router: MethodRouter, + state: S, +} + +impl WithState { + /// Get a reference to the state. + pub fn state(&self) -> &S { + &self.state + } + + /// Convert the handler into a [`MakeService`]. + /// + /// See [`MethodRouter::into_make_service`] for more details. + /// + /// [`MakeService`]: tower::make::MakeService + pub fn into_make_service(self) -> IntoMakeService { + IntoMakeService::new(self) + } + + /// Convert the router into a [`MakeService`] which stores information + /// about the incoming connection. + /// + /// See [`MethodRouter::into_make_service_with_connect_info`] for more details. + /// + /// [`MakeService`]: tower::make::MakeService + pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { + IntoMakeServiceWithConnectInfo::new(self) + } +} + +impl Clone for WithState +where + S: Clone, +{ + fn clone(&self) -> Self { + Self { + method_router: self.method_router.clone(), + state: self.state.clone(), + } + } +} + +impl fmt::Debug for WithState +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WithState") + .field("method_router", &self.method_router) + .field("state", &self.state) + .finish() + } +} + +impl Service> for WithState where B: HttpBody, + S: Clone + Send + Sync + 'static, { type Response = Response; type Error = E; @@ -1030,7 +1194,7 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { macro_rules! call { ( $req:expr, @@ -1051,18 +1215,25 @@ where // written with a pattern match like this to ensure we call all routes let Self { - get, - head, - delete, - options, - patch, - post, - put, - trace, - fallback, - allow_header, + state, + method_router: + MethodRouter { + get, + head, + delete, + options, + patch, + post, + put, + trace, + fallback, + allow_header, + _marker: _, + }, } = self; + req.extensions_mut().insert(state.clone()); + call!(req, method, HEAD, head); call!(req, method, HEAD, get); call!(req, method, GET, get); @@ -1091,7 +1262,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{body::Body, error_handling::HandleErrorLayer}; + use crate::{body::Body, error_handling::HandleErrorLayer, extract::State}; use axum_core::response::IntoResponse; use http::{header::ALLOW, HeaderMap}; use std::time::Duration; @@ -1106,6 +1277,19 @@ mod tests { assert!(body.is_empty()); } + #[tokio::test] + async fn get_service_fn() { + async fn handle(_req: Request) -> Result, Infallible> { + Ok(Response::new(Body::from("ok"))) + } + + let mut svc = get_service(service_fn(handle)); + + let (status, _, body) = call(Method::GET, &mut svc).await; + assert_eq!(status, StatusCode::OK); + assert_eq!(body, "ok"); + } + #[tokio::test] async fn get_handler() { let mut svc = MethodRouter::new().get(ok); @@ -1183,7 +1367,7 @@ mod tests { delete_service(ServeDir::new(".")) .handle_error(|_| async { StatusCode::NOT_FOUND }), ) - .fallback((|| async { StatusCode::NOT_FOUND }).into_service()) + .fallback(|| async { StatusCode::NOT_FOUND }) .put(ok) .layer( ServiceBuilder::new() @@ -1243,9 +1427,9 @@ mod tests { #[tokio::test] async fn allow_header_with_fallback() { - let mut svc = MethodRouter::new().get(ok).fallback( - (|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") }).into_service(), - ); + let mut svc = MethodRouter::new() + .get(ok) + .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") }); let (status, headers, _) = call(Method::DELETE, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); @@ -1267,9 +1451,7 @@ mod tests { } } - let mut svc = MethodRouter::new() - .get(ok) - .fallback(fallback.into_service()); + let mut svc = MethodRouter::new().get(ok).fallback(fallback); let (status, _, _) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); @@ -1287,7 +1469,7 @@ mod tests { expected = "Overlapping method route. Cannot add two method routes that both handle `GET`" )] async fn handler_overlaps() { - let _: MethodRouter = get(ok).get(ok); + let _: MethodRouter<()> = get(ok).get(ok); } #[tokio::test] @@ -1295,17 +1477,58 @@ mod tests { expected = "Overlapping method route. Cannot add two method routes that both handle `POST`" )] async fn service_overlaps() { - let _: MethodRouter = post_service(ok.into_service()).post_service(ok.into_service()); + let _: MethodRouter<()> = post_service(IntoServiceStateInExtension::<_, _, (), _>::new(ok)) + .post_service(IntoServiceStateInExtension::<_, _, (), _>::new(ok)); } #[tokio::test] async fn get_head_does_not_overlap() { - let _: MethodRouter = get(ok).head(ok); + let _: MethodRouter<()> = get(ok).head(ok); } #[tokio::test] async fn head_get_does_not_overlap() { - let _: MethodRouter = head(ok).get(ok); + let _: MethodRouter<()> = head(ok).get(ok); + } + + #[tokio::test] + async fn accessing_state() { + let mut svc = MethodRouter::new() + .get(|State(state): State<&'static str>| async move { state }) + .with_state("state"); + + let (status, _, text) = call(Method::GET, &mut svc).await; + + assert_eq!(status, StatusCode::OK); + assert_eq!(text, "state"); + } + + #[tokio::test] + async fn fallback_accessing_state() { + let mut svc = MethodRouter::new() + .fallback(|State(state): State<&'static str>| async move { state }) + .with_state("state"); + + let (status, _, text) = call(Method::GET, &mut svc).await; + + assert_eq!(status, StatusCode::OK); + assert_eq!(text, "state"); + } + + #[tokio::test] + async fn merge_accessing_state() { + let one = get(|State(state): State<&'static str>| async move { state }); + let two = post(|State(state): State<&'static str>| async move { state }); + + let mut svc = one.merge(two).with_state("state"); + + let (status, _, text) = call(Method::GET, &mut svc).await; + assert_eq!(status, StatusCode::OK); + assert_eq!(text, "state"); + + let (status, _, _) = call(Method::POST, &mut svc).await; + assert_eq!(status, StatusCode::OK); + assert_eq!(text, "state"); } async fn call(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String) diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 8c4559d299..62d9e50700 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -4,8 +4,10 @@ use self::{future::RouteFuture, not_found::NotFound}; use crate::{ body::{Body, HttpBody}, extract::connect_info::IntoMakeServiceWithConnectInfo, + handler::Handler, response::Response, util::try_downcast, + Extension, }; use axum_core::response::IntoResponse; use http::Request; @@ -22,10 +24,10 @@ use tower_layer::Layer; use tower_service::Service; pub mod future; +pub mod method_routing; mod into_make_service; mod method_filter; -mod method_routing; mod not_found; mod route; mod strip_prefix; @@ -59,15 +61,20 @@ impl RouteId { } /// The router type for composing handlers and services. -pub struct Router { - routes: HashMap>, +pub struct Router { + state: S, + routes: HashMap>, node: Arc, fallback: Fallback, } -impl Clone for Router { +impl Clone for Router +where + S: Clone, +{ fn clone(&self) -> Self { Self { + state: self.state.clone(), routes: self.routes.clone(), node: Arc::clone(&self.node), fallback: self.fallback.clone(), @@ -75,18 +82,23 @@ impl Clone for Router { } } -impl Default for Router +impl Default for Router where B: HttpBody + Send + 'static, + S: Default + Clone + Send + Sync + 'static, { fn default() -> Self { - Self::new() + Self::with_state(S::default()) } } -impl fmt::Debug for Router { +impl fmt::Debug for Router +where + S: fmt::Debug, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Router") + .field("state", &self.state) .field("routes", &self.routes) .field("node", &self.node) .field("fallback", &self.fallback) @@ -97,7 +109,7 @@ impl fmt::Debug for Router { pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param"; const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param"; -impl Router +impl Router<(), B> where B: HttpBody + Send + 'static, { @@ -106,7 +118,24 @@ where /// Unless you add additional routes this will respond with `404 Not Found` to /// all requests. pub fn new() -> Self { + Self::with_state(()) + } +} + +impl Router +where + B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, +{ + /// Create a new `Router` with the given state. + /// + /// See [`State`](crate::extract::State) for more details about accessing state. + /// + /// Unless you add additional routes this will respond with `404 Not Found` to + /// all requests. + pub fn with_state(state: S) -> Self { Self { + state, routes: Default::default(), node: Default::default(), fallback: Fallback::Default(Route::new(NotFound)), @@ -115,12 +144,8 @@ where #[doc = include_str!("../docs/routing/route.md")] #[track_caller] - pub fn route(mut self, path: &str, service: T) -> Self - where - T: Service, Error = Infallible> + Clone + Send + 'static, - T::Response: IntoResponse, - T::Future: Send + 'static, - { + pub fn route(mut self, path: &str, method_router: MethodRouter) -> Self { + #[track_caller] fn validate_path(path: &str) { if path.is_empty() { panic!("Paths must start with a `/`. Use \"/\" for root routes"); @@ -131,39 +156,53 @@ where validate_path(path); - let service = match try_downcast::, _>(service) { - Ok(_) => { - panic!("Invalid route: `Router::route` cannot be used with `Router`s. Use `Router::nest` instead") - } - Err(svc) => svc, + let id = RouteId::next(); + + let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self + .node + .path_to_route_id + .get(path) + .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc))) + { + // if we're adding a new `MethodRouter` to a route that already has one just + // merge them. This makes `.route("/", get(_)).route("/", post(_))` work + let service = Endpoint::MethodRouter(prev_method_router.clone().merge(method_router)); + self.routes.insert(route_id, service); + return self; + } else { + Endpoint::MethodRouter(method_router) }; - let id = RouteId::next(); + self.set_node(path, id); + self.routes.insert(id, endpoint); - let service = match try_downcast::, _>(service) { - Ok(method_router) => { - if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self - .node - .path_to_route_id - .get(path) - .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc))) - { - // if we're adding a new `MethodRouter` to a route that already has one just - // merge them. This makes `.route("/", get(_)).route("/", post(_))` work - let service = - Endpoint::MethodRouter(prev_method_router.clone().merge(method_router)); - self.routes.insert(route_id, service); - return self; - } else { - Endpoint::MethodRouter(method_router) - } + self + } + + #[doc = include_str!("../docs/routing/route_service.md")] + pub fn route_service(mut self, path: &str, service: T) -> Self + where + T: Service, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, + T::Future: Send + 'static, + { + if path.is_empty() { + panic!("Paths must start with a `/`. Use \"/\" for root routes"); + } else if !path.starts_with('/') { + panic!("Paths must start with a `/`"); + } + + let service = match try_downcast::, _>(service) { + Ok(_) => { + panic!("Invalid route: `Router::route_service` cannot be used with `Router`s. Use `Router::nest` instead") } - Err(service) => Endpoint::Route(Route::new(service)), + Err(svc) => svc, }; + let id = RouteId::next(); + let endpoint = Endpoint::Route(Route::new(service)); self.set_node(path, id); - - self.routes.insert(id, service); + self.routes.insert(id, endpoint); self } @@ -204,15 +243,15 @@ where }; let svc = strip_prefix::StripPrefix::new(svc, prefix); - self = self.route(&path, svc.clone()); + self = self.route_service(&path, svc.clone()); // `/*rest` is not matched by `/` so we need to also register a router at the // prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself // wouldn't match, which it should - self = self.route(prefix, svc.clone()); + self = self.route_service(prefix, svc.clone()); if !prefix.ends_with('/') { // same goes for `/foo/`, that should also match - self = self.route(&format!("{prefix}/"), svc); + self = self.route_service(&format!("{prefix}/"), svc); } self @@ -220,11 +259,13 @@ where #[doc = include_str!("../docs/routing/merge.md")] #[track_caller] - pub fn merge(mut self, other: R) -> Self + pub fn merge(mut self, other: R) -> Self where - R: Into>, + R: Into>, + S2: Clone + Send + Sync + 'static, { let Router { + state, routes, node, fallback, @@ -236,8 +277,15 @@ where .get(&id) .expect("no path for route id. This is a bug in axum. Please file an issue"); self = match route { - Endpoint::MethodRouter(route) => self.route(path, route), - Endpoint::Route(route) => self.route(path, route), + Endpoint::MethodRouter(method_router) => self.route( + path, + method_router + // this will set the state for each route + // such we don't override the inner state later in `MethodRouterWithState` + .layer(Extension(state.clone())) + .downcast_state(), + ), + Endpoint::Route(route) => self.route_service(path, route), }; } @@ -254,7 +302,7 @@ where } #[doc = include_str!("../docs/routing/layer.md")] - pub fn layer(self, layer: L) -> Router + pub fn layer(self, layer: L) -> Router where L: Layer>, L::Service: Service> + Clone + Send + 'static, @@ -285,6 +333,7 @@ where let fallback = self.fallback.map(|svc| Route::new(layer.layer(svc))); Router { + state: self.state, routes, node: self.node, fallback, @@ -321,6 +370,7 @@ where .collect(); Router { + state: self.state, routes, node: self.node, fallback: self.fallback, @@ -328,7 +378,19 @@ where } #[doc = include_str!("../docs/routing/fallback.md")] - pub fn fallback(mut self, svc: T) -> Self + pub fn fallback(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let state = self.state.clone(); + self.fallback_service(handler.with_state(state)) + } + + /// Add a fallback [`Service`] to the router. + /// + /// See [`Router::fallback`] for more details. + pub fn fallback_service(mut self, svc: T) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, @@ -422,15 +484,21 @@ where .clone(); match &mut route { - Endpoint::MethodRouter(inner) => inner.call(req), + Endpoint::MethodRouter(inner) => inner.clone().with_state(self.state.clone()).call(req), Endpoint::Route(inner) => inner.call(req), } } + + /// Get a reference to the state. + pub fn state(&self) -> &S { + &self.state + } } -impl Service> for Router +impl Service> for Router where B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, { type Response = Response; type Error = Infallible; @@ -545,12 +613,15 @@ impl Fallback { } } -enum Endpoint { - MethodRouter(MethodRouter), +enum Endpoint { + MethodRouter(MethodRouter), Route(Route), } -impl Clone for Endpoint { +impl Clone for Endpoint +where + S: Clone, +{ fn clone(&self) -> Self { match self { Endpoint::MethodRouter(inner) => Endpoint::MethodRouter(inner.clone()), @@ -559,7 +630,10 @@ impl Clone for Endpoint { } } -impl fmt::Debug for Endpoint { +impl fmt::Debug for Endpoint +where + S: fmt::Debug, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::MethodRouter(inner) => inner.fmt(f), @@ -572,5 +646,5 @@ impl fmt::Debug for Endpoint { #[allow(warnings)] fn traits() { use crate::test_helpers::*; - assert_send::>(); + assert_send::>(); } diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index e116c57aa7..8b3959df1f 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -48,13 +48,13 @@ impl Route { } } -impl Clone for Route { +impl Clone for Route { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl fmt::Debug for Route { +impl fmt::Debug for Route { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Route").finish() } diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index 498740772c..4da166baea 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -1,11 +1,10 @@ use super::*; -use crate::handler::Handler; #[tokio::test] async fn basic() { let app = Router::new() .route("/foo", get(|| async {})) - .fallback((|| async { "fallback" }).into_service()); + .fallback(|| async { "fallback" }); let client = TestClient::new(app); @@ -20,7 +19,7 @@ async fn basic() { async fn nest() { let app = Router::new() .nest("/foo", Router::new().route("/bar", get(|| async {}))) - .fallback((|| async { "fallback" }).into_service()); + .fallback(|| async { "fallback" }); let client = TestClient::new(app); @@ -36,9 +35,7 @@ async fn or() { let one = Router::new().route("/one", get(|| async {})); let two = Router::new().route("/two", get(|| async {})); - let app = one - .merge(two) - .fallback((|| async { "fallback" }).into_service()); + let app = one.merge(two).fallback(|| async { "fallback" }); let client = TestClient::new(app); @@ -49,3 +46,15 @@ async fn or() { assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "fallback"); } + +#[tokio::test] +async fn fallback_accessing_state() { + let app = Router::with_state("state") + .fallback(|State(state): State<&'static str>| async move { state }); + + let client = TestClient::new(app); + + let res = client.get("/does-not-exist").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "state"); +} diff --git a/axum/src/routing/tests/merge.rs b/axum/src/routing/tests/merge.rs index afb0bcb94f..5f67dcd784 100644 --- a/axum/src/routing/tests/merge.rs +++ b/axum/src/routing/tests/merge.rs @@ -408,3 +408,72 @@ async fn middleware_that_return_early() { ); assert_eq!(client.get("/public").send().await.status(), StatusCode::OK); } + +#[tokio::test] +async fn merge_with_different_state_type() { + let inner = Router::with_state("inner".to_owned()).route( + "/foo", + get(|State(state): State| async move { state }), + ); + + let app = Router::with_state("outer").merge(inner).route( + "/bar", + get(|State(state): State<&'static str>| async move { state }), + ); + + let client = TestClient::new(app); + + let res = client.get("/foo").send().await; + assert_eq!(res.text().await, "inner"); + + let res = client.get("/bar").send().await; + assert_eq!(res.text().await, "outer"); +} + +#[tokio::test] +async fn merging_routes_different_method_different_states() { + let get = Router::with_state("get state").route( + "/", + get(|State(state): State<&'static str>| async move { state }), + ); + + let post = Router::with_state("post state").route( + "/", + post(|State(state): State<&'static str>| async move { state }), + ); + + let app = Router::new().merge(get).merge(post); + + let client = TestClient::new(app); + + let res = client.get("/").send().await; + assert_eq!(res.text().await, "get state"); + + let res = client.post("/").send().await; + assert_eq!(res.text().await, "post state"); +} + +#[tokio::test] +async fn merging_routes_different_paths_different_states() { + let foo = Router::with_state("foo state").route( + "/foo", + get(|State(state): State<&'static str>| async move { state }), + ); + + let bar = Router::with_state("bar state").route( + "/bar", + get(|State(state): State<&'static str>| async move { state }), + ); + + let app = Router::new().merge(foo).merge(bar); + + let client = TestClient::new(app); + + let res = client.get("/foo").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "foo state"); + + let res = client.get("/bar").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "bar state"); +} diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index b7f53cb163..68746a2fa1 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -1,8 +1,8 @@ use crate::{ body::{Bytes, Empty}, error_handling::HandleErrorLayer, - extract::{self, Path}, - handler::Handler, + extract::{self, FromRef, Path, State}, + handler::{Handler, HandlerWithoutStateExt}, response::IntoResponse, routing::{delete, get, get_service, on, on_service, patch, patch_service, post, MethodFilter}, test_helpers::*, @@ -444,10 +444,10 @@ async fn middleware_still_run_for_unmatched_requests() { #[tokio::test] #[should_panic( - expected = "Invalid route: `Router::route` cannot be used with `Router`s. Use `Router::nest` instead" + expected = "Invalid route: `Router::route_service` cannot be used with `Router`s. Use `Router::nest` instead" )] async fn routing_to_router_panics() { - TestClient::new(Router::new().route("/", Router::new())); + TestClient::new(Router::new().route_service("/", Router::new())); } #[tokio::test] @@ -499,8 +499,8 @@ async fn different_methods_added_in_different_routes() { #[should_panic(expected = "Cannot merge two `Router`s that both have a fallback")] async fn merging_routers_with_fallbacks_panics() { async fn fallback() {} - let one = Router::new().fallback(fallback.into_service()); - let two = Router::new().fallback(fallback.into_service()); + let one = Router::new().fallback(fallback); + let two = Router::new().fallback(fallback); TestClient::new(one.merge(two)); } @@ -539,7 +539,7 @@ async fn head_content_length_through_hyper_server() { #[tokio::test] async fn head_content_length_through_hyper_server_that_hits_fallback() { - let app = Router::new().fallback((|| async { "foo" }).into_service()); + let app = Router::new().fallback(|| async { "foo" }); let client = TestClient::new(app); @@ -641,6 +641,54 @@ async fn limited_body_with_streaming_body() { assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } +#[tokio::test] +async fn extract_state() { + #[derive(Clone)] + struct AppState { + value: i32, + inner: InnerState, + } + + #[derive(Clone)] + struct InnerState { + value: i32, + } + + impl FromRef for InnerState { + fn from_ref(state: &AppState) -> Self { + state.inner.clone() + } + } + + async fn handler(State(outer): State, State(inner): State) { + assert_eq!(outer.value, 1); + assert_eq!(inner.value, 2); + } + + let state = AppState { + value: 1, + inner: InnerState { value: 2 }, + }; + + let app = Router::with_state(state).route("/", get(handler)); + let client = TestClient::new(app); + + let res = client.get("/").send().await; + assert_eq!(res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn explicitly_set_state() { + let app = Router::with_state("...").route_service( + "/", + get(|State(state): State<&'static str>| async move { state }).with_state("foo"), + ); + + let client = TestClient::new(app); + let res = client.get("/").send().await; + assert_eq!(res.text().await, "foo"); +} + #[tokio::test] async fn layer_response_into_response() { fn map_response(_res: Response) -> Result, impl IntoResponse> { diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index cabadaf545..f856b6e8bc 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -182,7 +182,7 @@ async fn nested_service_sees_stripped_uri() { "/foo", Router::new().nest( "/bar", - Router::new().route( + Router::new().route_service( "/baz", service_fn(|req: Request| async move { let body = boxed(Body::from(req.uri().to_string())); @@ -264,7 +264,7 @@ async fn multiple_top_level_nests() { #[tokio::test] #[should_panic(expected = "Invalid route: nested routes cannot contain wildcards (*)")] async fn nest_cannot_contain_wildcards() { - Router::::new().nest("/one/*rest", Router::new()); + Router::<_, Body>::new().nest("/one/*rest", Router::new()); } #[tokio::test] @@ -275,7 +275,7 @@ async fn outer_middleware_still_see_whole_url() { #[derive(Clone)] struct Uri(http::Uri); - impl Service> for SetUriExtension + impl Service> for SetUriExtension where S: Service>, { @@ -303,7 +303,7 @@ async fn outer_middleware_still_see_whole_url() { .route("/foo", get(handler)) .route("/foo/bar", get(handler)) .nest("/one", Router::new().route("/two", get(handler))) - .fallback(handler.into_service()) + .fallback(handler) .layer(tower::layer::layer_fn(SetUriExtension)); let client = TestClient::new(app); @@ -356,7 +356,7 @@ async fn nest_with_and_without_trailing() { async fn doesnt_call_outer_fallback() { let app = Router::new() .nest("/foo", Router::new().route("/", get(|| async {}))) - .fallback((|| async { (StatusCode::NOT_FOUND, "outer fallback") }).into_service()); + .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") }); let client = TestClient::new(app); @@ -396,9 +396,9 @@ async fn fallback_on_inner() { "/foo", Router::new() .route("/", get(|| async {})) - .fallback((|| async { (StatusCode::NOT_FOUND, "inner fallback") }).into_service()), + .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }), ) - .fallback((|| async { (StatusCode::NOT_FOUND, "outer fallback") }).into_service()); + .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") }); let client = TestClient::new(app); @@ -442,3 +442,33 @@ nested_route_test!(nest_9, nest = "/a", route = "/a/", expected = "/a/a/"); nested_route_test!(nest_11, nest = "/a/", route = "/", expected = "/a/"); nested_route_test!(nest_12, nest = "/a/", route = "/a", expected = "/a/a"); nested_route_test!(nest_13, nest = "/a/", route = "/a/", expected = "/a/a/"); + +#[tokio::test] +async fn nesting_with_different_state() { + let inner = Router::with_state("inner".to_owned()).route( + "/foo", + get(|State(state): State| async move { state }), + ); + + let outer = Router::with_state("outer") + .route( + "/foo", + get(|State(state): State<&'static str>| async move { state }), + ) + .nest("/nested", inner) + .route( + "/bar", + get(|State(state): State<&'static str>| async move { state }), + ); + + let client = TestClient::new(outer); + + let res = client.get("/foo").send().await; + assert_eq!(res.text().await, "outer"); + + let res = client.get("/nested/foo").send().await; + assert_eq!(res.text().await, "inner"); + + let res = client.get("/bar").send().await; + assert_eq!(res.text().await, "outer"); +} diff --git a/axum/src/typed_header.rs b/axum/src/typed_header.rs index 1ce72d04f8..c28a24a81a 100644 --- a/axum/src/typed_header.rs +++ b/axum/src/typed_header.rs @@ -52,14 +52,15 @@ use std::{convert::Infallible, ops::Deref}; pub struct TypedHeader(pub T); #[async_trait] -impl FromRequest for TypedHeader +impl FromRequest for TypedHeader where T: headers::Header, B: Send, + S: Send, { type Rejection = TypedHeaderRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match req.headers().typed_try_get::() { Ok(Some(value)) => Ok(Self(value)), Ok(None) => Err(TypedHeaderRejection { diff --git a/examples/async-graphql/src/main.rs b/examples/async-graphql/src/main.rs new file mode 100644 index 0000000000..a8d84cb9a1 --- /dev/null +++ b/examples/async-graphql/src/main.rs @@ -0,0 +1,45 @@ +//! Example async-graphql application. +//! +//! Run with +//! +//! ```not_rust +//! cd examples && cargo run -p example-async-graphql +//! ``` + +mod starwars; + +use async_graphql::{ + http::{playground_source, GraphQLPlaygroundConfig}, + EmptyMutation, EmptySubscription, Request, Response, Schema, +}; +use axum::{ + extract::State, + response::{Html, IntoResponse}, + routing::get, + Json, Router, +}; +use starwars::{QueryRoot, StarWars, StarWarsSchema}; + +async fn graphql_handler(schema: State, req: Json) -> Json { + schema.execute(req.0).await.into() +} + +async fn graphql_playground() -> impl IntoResponse { + Html(playground_source(GraphQLPlaygroundConfig::new("/"))) +} + +#[tokio::main] +async fn main() { + let schema = Schema::build(QueryRoot, EmptyMutation, EmptySubscription) + .data(StarWars::new()) + .finish(); + + let app = Router::with_state(schema).route("/", get(graphql_playground).post(graphql_handler)); + + println!("Playground: http://localhost:3000"); + + axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) + .serve(app.into_make_service()) + .await + .unwrap(); +} diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs index 5107323ea8..092a7d96de 100644 --- a/examples/chat/src/main.rs +++ b/examples/chat/src/main.rs @@ -9,7 +9,7 @@ use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, - Extension, + State, }, response::{Html, IntoResponse}, routing::get, @@ -44,10 +44,9 @@ async fn main() { let app_state = Arc::new(AppState { user_set, tx }); - let app = Router::new() + let app = Router::with_state(app_state) .route("/", get(index)) - .route("/websocket", get(websocket_handler)) - .layer(Extension(app_state)); + .route("/websocket", get(websocket_handler)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); @@ -59,7 +58,7 @@ async fn main() { async fn websocket_handler( ws: WebSocketUpgrade, - Extension(state): Extension>, + State(state): State>, ) -> impl IntoResponse { ws.on_upgrade(|socket| websocket(socket, state)) } diff --git a/examples/consume-body-in-extractor-or-middleware/src/main.rs b/examples/consume-body-in-extractor-or-middleware/src/main.rs index 1fdd90221a..be948375d4 100644 --- a/examples/consume-body-in-extractor-or-middleware/src/main.rs +++ b/examples/consume-body-in-extractor-or-middleware/src/main.rs @@ -80,17 +80,22 @@ async fn handler(_: PrintRequestBody, body: Bytes) { struct PrintRequestBody; #[async_trait] -impl FromRequest for PrintRequestBody { +impl FromRequest for PrintRequestBody +where + S: Send + Clone, +{ type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { + let state = req.state().clone(); + let request = Request::from_request(req) .await .map_err(|err| err.into_response())?; let request = buffer_request_body(request).await?; - *req = RequestParts::new(request); + *req = RequestParts::with_state(state, request); Ok(Self) } diff --git a/examples/customize-extractor-error/src/main.rs b/examples/customize-extractor-error/src/main.rs index be5ef59513..20e3b4d482 100644 --- a/examples/customize-extractor-error/src/main.rs +++ b/examples/customize-extractor-error/src/main.rs @@ -56,8 +56,9 @@ struct User { struct Json(T); #[async_trait] -impl FromRequest for Json +impl FromRequest for Json where + S: Send, // these trait bounds are copied from `impl FromRequest for axum::Json` T: DeserializeOwned, B: axum::body::HttpBody + Send, @@ -66,7 +67,7 @@ where { type Rejection = (StatusCode, axum::Json); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match axum::Json::::from_request(req).await { Ok(value) => Ok(Self(value.0)), Err(rejection) => { diff --git a/examples/customize-path-rejection/src/main.rs b/examples/customize-path-rejection/src/main.rs index 4e26894959..8330b95a93 100644 --- a/examples/customize-path-rejection/src/main.rs +++ b/examples/customize-path-rejection/src/main.rs @@ -52,15 +52,16 @@ struct Params { struct Path(T); #[async_trait] -impl FromRequest for Path +impl FromRequest for Path where // these trait bounds are copied from `impl FromRequest for axum::extract::path::Path` T: DeserializeOwned + Send, B: Send, + S: Send, { type Rejection = (StatusCode, axum::Json); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match axum::extract::Path::::from_request(req).await { Ok(value) => Ok(Self(value.0)), Err(rejection) => { diff --git a/examples/error-handling-and-dependency-injection/src/main.rs b/examples/error-handling-and-dependency-injection/src/main.rs index d92b43bf49..914ae18155 100644 --- a/examples/error-handling-and-dependency-injection/src/main.rs +++ b/examples/error-handling-and-dependency-injection/src/main.rs @@ -9,7 +9,7 @@ use axum::{ async_trait, - extract::{Extension, Path}, + extract::{Path, State}, http::StatusCode, response::{IntoResponse, Response}, routing::{get, post}, @@ -36,12 +36,9 @@ async fn main() { let user_repo = Arc::new(ExampleUserRepo) as DynUserRepo; // Build our application with some routes - let app = Router::new() + let app = Router::with_state(user_repo) .route("/users/:id", get(users_show)) - .route("/users", post(users_create)) - // Add our `user_repo` to all request's extensions so handlers can access - // it. - .layer(Extension(user_repo)); + .route("/users", post(users_create)); // Run our application let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -59,7 +56,7 @@ async fn main() { /// so it can be returned from handlers directly. async fn users_show( Path(user_id): Path, - Extension(user_repo): Extension, + State(user_repo): State, ) -> Result, AppError> { let user = user_repo.find(user_id).await?; @@ -69,7 +66,7 @@ async fn users_show( /// Handler for `POST /users`. async fn users_create( Json(params): Json, - Extension(user_repo): Extension, + State(user_repo): State, ) -> Result, AppError> { let user = user_repo.create(params).await?; diff --git a/examples/global-404-handler/src/main.rs b/examples/global-404-handler/src/main.rs index 385a0e21e3..a3a5ea15dc 100644 --- a/examples/global-404-handler/src/main.rs +++ b/examples/global-404-handler/src/main.rs @@ -5,7 +5,6 @@ //! ``` use axum::{ - handler::Handler, http::StatusCode, response::{Html, IntoResponse}, routing::get, @@ -27,7 +26,7 @@ async fn main() { let app = Router::new().route("/", get(handler)); // add a fallback service for handling routes to unknown paths - let app = app.fallback(handler_404.into_service()); + let app = app.fallback(handler_404); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/examples/jwt/src/main.rs b/examples/jwt/src/main.rs index 0ac4053e6c..8725581da7 100644 --- a/examples/jwt/src/main.rs +++ b/examples/jwt/src/main.rs @@ -122,13 +122,14 @@ impl AuthBody { } #[async_trait] -impl FromRequest for Claims +impl FromRequest for Claims where + S: Send, B: Send, { type Rejection = AuthError; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { // Extract the token from the authorization header let TypedHeader(Authorization(bearer)) = TypedHeader::>::from_request(req) diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index 0ad5b7f670..c65ee75a0f 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -9,7 +9,7 @@ use axum::{ body::Bytes, error_handling::HandleErrorLayer, - extract::{ContentLengthLimit, Extension, Path}, + extract::{ContentLengthLimit, Path, State}, handler::Handler, http::StatusCode, response::IntoResponse, @@ -39,8 +39,10 @@ async fn main() { .with(tracing_subscriber::fmt::layer()) .init(); + let shared_state = SharedState::default(); + // Build our application by composing routes - let app = Router::new() + let app = Router::with_state(Arc::clone(&shared_state)) .route( "/:key", // Add compression to `kv_get` @@ -50,7 +52,7 @@ async fn main() { ) .route("/keys", get(list_keys)) // Nest our admin routes under `/admin` - .nest("/admin", admin_routes()) + .nest("/admin", admin_routes(shared_state)) // Add middleware to all routes .layer( ServiceBuilder::new() @@ -60,7 +62,6 @@ async fn main() { .concurrency_limit(1024) .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) - .layer(Extension(SharedState::default())) .into_inner(), ); @@ -73,16 +74,16 @@ async fn main() { .unwrap(); } -type SharedState = Arc>; +type SharedState = Arc>; #[derive(Default)] -struct State { +struct AppState { db: HashMap, } async fn kv_get( Path(key): Path, - Extension(state): Extension, + State(state): State, ) -> Result { let db = &state.read().unwrap().db; @@ -96,12 +97,12 @@ async fn kv_get( async fn kv_set( Path(key): Path, ContentLengthLimit(bytes): ContentLengthLimit, // ~5mb - Extension(state): Extension, + State(state): State, ) { state.write().unwrap().db.insert(key, bytes); } -async fn list_keys(Extension(state): Extension) -> String { +async fn list_keys(State(state): State) -> String { let db = &state.read().unwrap().db; db.keys() @@ -110,16 +111,16 @@ async fn list_keys(Extension(state): Extension) -> String { .join("\n") } -fn admin_routes() -> Router { - async fn delete_all_keys(Extension(state): Extension) { +fn admin_routes(state: SharedState) -> Router { + async fn delete_all_keys(State(state): State) { state.write().unwrap().db.clear(); } - async fn remove_key(Path(key): Path, Extension(state): Extension) { + async fn remove_key(Path(key): Path, State(state): State) { state.write().unwrap().db.remove(&key); } - Router::new() + Router::with_state(state) .route("/keys", delete(delete_all_keys)) .route("/key/:key", delete(remove_key)) // Require bearer auth for all admin routes diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index 6357a7fc08..a61113b97d 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -12,7 +12,7 @@ use async_session::{MemoryStore, Session, SessionStore}; use axum::{ async_trait, extract::{ - rejection::TypedHeaderRejectionReason, Extension, FromRequest, Query, RequestParts, + rejection::TypedHeaderRejectionReason, FromRef, FromRequest, Query, RequestParts, State, TypedHeader, }, http::{header::SET_COOKIE, HeaderMap}, @@ -42,17 +42,18 @@ async fn main() { // `MemoryStore` is just used as an example. Don't use this in production. let store = MemoryStore::new(); - let oauth_client = oauth_client(); + let app_state = AppState { + store, + oauth_client, + }; - let app = Router::new() + let app = Router::with_state(app_state) .route("/", get(index)) .route("/auth/discord", get(discord_auth)) .route("/auth/authorized", get(login_authorized)) .route("/protected", get(protected)) - .route("/logout", get(logout)) - .layer(Extension(store)) - .layer(Extension(oauth_client)); + .route("/logout", get(logout)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); @@ -63,6 +64,24 @@ async fn main() { .unwrap(); } +#[derive(Clone)] +struct AppState { + store: MemoryStore, + oauth_client: BasicClient, +} + +impl FromRef for MemoryStore { + fn from_ref(state: &AppState) -> Self { + state.store.clone() + } +} + +impl FromRef for BasicClient { + fn from_ref(state: &AppState) -> Self { + state.oauth_client.clone() + } +} + fn oauth_client() -> BasicClient { // Environment variables (* = required): // *"CLIENT_ID" "REPLACE_ME"; @@ -113,7 +132,7 @@ async fn index(user: Option) -> impl IntoResponse { } } -async fn discord_auth(Extension(client): Extension) -> impl IntoResponse { +async fn discord_auth(State(client): State) -> impl IntoResponse { let (auth_url, _csrf_token) = client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("identify".to_string())) @@ -132,7 +151,7 @@ async fn protected(user: User) -> impl IntoResponse { } async fn logout( - Extension(store): Extension, + State(store): State, TypedHeader(cookies): TypedHeader, ) -> impl IntoResponse { let cookie = cookies.get(COOKIE_NAME).unwrap(); @@ -156,8 +175,8 @@ struct AuthRequest { async fn login_authorized( Query(query): Query, - Extension(store): Extension, - Extension(oauth_client): Extension, + State(store): State, + State(oauth_client): State, ) -> impl IntoResponse { // Get an auth token let token = oauth_client @@ -205,17 +224,15 @@ impl IntoResponse for AuthRedirect { } #[async_trait] -impl FromRequest for User +impl FromRequest for User where B: Send, { // If anything goes wrong or no session is found, redirect to the auth page type Rejection = AuthRedirect; - async fn from_request(req: &mut RequestParts) -> Result { - let Extension(store) = Extension::::from_request(req) - .await - .expect("`MemoryStore` extension is missing"); + async fn from_request(req: &mut RequestParts) -> Result { + let store = req.state().clone().store; let cookies = TypedHeader::::from_request(req) .await diff --git a/examples/reverse-proxy/src/main.rs b/examples/reverse-proxy/src/main.rs index a9d2a5c750..af74ea1295 100644 --- a/examples/reverse-proxy/src/main.rs +++ b/examples/reverse-proxy/src/main.rs @@ -8,7 +8,7 @@ //! ``` use axum::{ - extract::Extension, + extract::State, http::{uri::Uri, Request, Response}, routing::get, Router, @@ -24,9 +24,7 @@ async fn main() { let client = Client::new(); - let app = Router::new() - .route("/", get(handler)) - .layer(Extension(client)); + let app = Router::with_state(client).route("/", get(handler)); let addr = SocketAddr::from(([127, 0, 0, 1], 4000)); println!("reverse proxy listening on {}", addr); @@ -36,12 +34,7 @@ async fn main() { .unwrap(); } -async fn handler( - Extension(client): Extension, - // NOTE: Make sure to put the request extractor last because once the request - // is extracted, extensions can't be extracted anymore. - mut req: Request, -) -> Response { +async fn handler(State(client): State, mut req: Request) -> Response { let path = req.uri().path(); let path_query = req .uri() diff --git a/examples/routes-and-handlers-close-together/src/main.rs b/examples/routes-and-handlers-close-together/src/main.rs index 5e52ad7b55..41aaa49db5 100644 --- a/examples/routes-and-handlers-close-together/src/main.rs +++ b/examples/routes-and-handlers-close-together/src/main.rs @@ -49,6 +49,6 @@ fn post_foo() -> Router { route("/foo", post(handler)) } -fn route(path: &str, method_router: MethodRouter) -> Router { +fn route(path: &str, method_router: MethodRouter<()>) -> Router { Router::new().route(path, method_router) } diff --git a/examples/sessions/src/main.rs b/examples/sessions/src/main.rs index 3251122c9a..cd0d41a1f6 100644 --- a/examples/sessions/src/main.rs +++ b/examples/sessions/src/main.rs @@ -7,7 +7,7 @@ use async_session::{MemoryStore, Session, SessionStore as _}; use axum::{ async_trait, - extract::{Extension, FromRequest, RequestParts, TypedHeader}, + extract::{FromRequest, RequestParts, TypedHeader}, headers::Cookie, http::{ self, @@ -38,9 +38,7 @@ async fn main() { // `MemoryStore` just used as an example. Don't use this in production. let store = MemoryStore::new(); - let app = Router::new() - .route("/", get(handler)) - .layer(Extension(store)); + let app = Router::with_state(store).route("/", get(handler)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); @@ -82,20 +80,16 @@ enum UserIdFromSession { } #[async_trait] -impl FromRequest for UserIdFromSession +impl FromRequest for UserIdFromSession where B: Send, { type Rejection = (StatusCode, &'static str); - async fn from_request(req: &mut RequestParts) -> Result { - let Extension(store) = Extension::::from_request(req) - .await - .expect("`MemoryStore` extension missing"); + async fn from_request(req: &mut RequestParts) -> Result { + let store = req.state().clone(); - let cookie = Option::>::from_request(req) - .await - .unwrap(); + let cookie = req.extract::>>().await.unwrap(); let session_cookie = cookie .as_ref() diff --git a/examples/sqlx-postgres/src/main.rs b/examples/sqlx-postgres/src/main.rs index 9d101618db..6548cdeb97 100644 --- a/examples/sqlx-postgres/src/main.rs +++ b/examples/sqlx-postgres/src/main.rs @@ -15,7 +15,7 @@ use axum::{ async_trait, - extract::{Extension, FromRequest, RequestParts}, + extract::{FromRequest, RequestParts, State}, http::StatusCode, routing::get, Router, @@ -46,12 +46,10 @@ async fn main() { .expect("can connect to database"); // build our application with some routes - let app = Router::new() - .route( - "/", - get(using_connection_pool_extractor).post(using_connection_extractor), - ) - .layer(Extension(pool)); + let app = Router::with_state(pool).route( + "/", + get(using_connection_pool_extractor).post(using_connection_extractor), + ); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -62,9 +60,9 @@ async fn main() { .unwrap(); } -// we can extract the connection pool with `Extension` +// we can extract the connection pool with `State` async fn using_connection_pool_extractor( - Extension(pool): Extension, + State(pool): State, ) -> Result { sqlx::query_scalar("select 'hello world from pg'") .fetch_one(&pool) @@ -77,16 +75,14 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(sqlx::pool::PoolConnection); #[async_trait] -impl FromRequest for DatabaseConnection +impl FromRequest for DatabaseConnection where B: Send, { type Rejection = (StatusCode, String); - async fn from_request(req: &mut RequestParts) -> Result { - let Extension(pool) = Extension::::from_request(req) - .await - .map_err(internal_error)?; + async fn from_request(req: &mut RequestParts) -> Result { + let pool = req.state().clone(); let conn = pool.acquire().await.map_err(internal_error)?; diff --git a/examples/sse/src/main.rs b/examples/sse/src/main.rs index 4dbfcb4673..6679971152 100644 --- a/examples/sse/src/main.rs +++ b/examples/sse/src/main.rs @@ -41,7 +41,7 @@ async fn main() { // build our application with a route let app = Router::new() - .fallback(static_files_service) + .fallback_service(static_files_service) .route("/sse", get(sse_handler)) .layer(TraceLayer::new_for_http()); diff --git a/examples/static-file-server/src/main.rs b/examples/static-file-server/src/main.rs index 1862ecab5a..f7ac2bb9a8 100644 --- a/examples/static-file-server/src/main.rs +++ b/examples/static-file-server/src/main.rs @@ -34,7 +34,7 @@ async fn main() { // as the fallback to a `Router` let app: _ = Router::new() .route("/foo", get(|| async { "Hi from /foo" })) - .fallback(get_service(ServeDir::new(".")).handle_error(handle_error)) + .fallback_service(get_service(ServeDir::new(".")).handle_error(handle_error)) .layer(TraceLayer::new_for_http()); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/examples/tls-rustls/src/main.rs b/examples/tls-rustls/src/main.rs index 40008a9631..6eec0e9c2a 100644 --- a/examples/tls-rustls/src/main.rs +++ b/examples/tls-rustls/src/main.rs @@ -6,7 +6,7 @@ use axum::{ extract::Host, - handler::Handler, + handler::HandlerWithoutStateExt, http::{StatusCode, Uri}, response::Redirect, routing::get, diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs index 9a33416be3..b82a308db4 100644 --- a/examples/todos/src/main.rs +++ b/examples/todos/src/main.rs @@ -15,7 +15,7 @@ use axum::{ error_handling::HandleErrorLayer, - extract::{Extension, Path, Query}, + extract::{Path, Query, State}, http::StatusCode, response::IntoResponse, routing::{get, patch}, @@ -46,7 +46,7 @@ async fn main() { let db = Db::default(); // Compose the routes - let app = Router::new() + let app = Router::with_state(db) .route("/todos", get(todos_index).post(todos_create)) .route("/todos/:id", patch(todos_update).delete(todos_delete)) // Add middleware to all routes @@ -64,7 +64,6 @@ async fn main() { })) .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) - .layer(Extension(db)) .into_inner(), ); @@ -85,7 +84,7 @@ pub struct Pagination { async fn todos_index( pagination: Option>, - Extension(db): Extension, + State(db): State, ) -> impl IntoResponse { let todos = db.read().unwrap(); @@ -106,10 +105,7 @@ struct CreateTodo { text: String, } -async fn todos_create( - Json(input): Json, - Extension(db): Extension, -) -> impl IntoResponse { +async fn todos_create(Json(input): Json, State(db): State) -> impl IntoResponse { let todo = Todo { id: Uuid::new_v4(), text: input.text, @@ -130,7 +126,7 @@ struct UpdateTodo { async fn todos_update( Path(id): Path, Json(input): Json, - Extension(db): Extension, + State(db): State, ) -> Result { let mut todo = db .read() @@ -152,7 +148,7 @@ async fn todos_update( Ok(Json(todo)) } -async fn todos_delete(Path(id): Path, Extension(db): Extension) -> impl IntoResponse { +async fn todos_delete(Path(id): Path, State(db): State) -> impl IntoResponse { if db.write().unwrap().remove(&id).is_some() { StatusCode::NO_CONTENT } else { diff --git a/examples/tokio-postgres/src/main.rs b/examples/tokio-postgres/src/main.rs index 66b03a8fa9..e0c60453e3 100644 --- a/examples/tokio-postgres/src/main.rs +++ b/examples/tokio-postgres/src/main.rs @@ -6,7 +6,7 @@ use axum::{ async_trait, - extract::{Extension, FromRequest, RequestParts}, + extract::{FromRequest, RequestParts, State}, http::StatusCode, routing::get, Router, @@ -33,12 +33,10 @@ async fn main() { let pool = Pool::builder().build(manager).await.unwrap(); // build our application with some routes - let app = Router::new() - .route( - "/", - get(using_connection_pool_extractor).post(using_connection_extractor), - ) - .layer(Extension(pool)); + let app = Router::with_state(pool).route( + "/", + get(using_connection_pool_extractor).post(using_connection_extractor), + ); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -51,9 +49,8 @@ async fn main() { type ConnectionPool = Pool>; -// we can extract the connection pool with `Extension` async fn using_connection_pool_extractor( - Extension(pool): Extension, + State(pool): State, ) -> Result { let conn = pool.get().await.map_err(internal_error)?; @@ -71,16 +68,16 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager>); #[async_trait] -impl FromRequest for DatabaseConnection +impl FromRequest for DatabaseConnection where B: Send, { type Rejection = (StatusCode, String); - async fn from_request(req: &mut RequestParts) -> Result { - let Extension(pool) = Extension::::from_request(req) - .await - .map_err(internal_error)?; + async fn from_request( + req: &mut RequestParts, + ) -> Result { + let pool = req.state().clone(); let conn = pool.get_owned().await.map_err(internal_error)?; diff --git a/examples/validator/src/main.rs b/examples/validator/src/main.rs index c8ce8c0884..8682eb85e5 100644 --- a/examples/validator/src/main.rs +++ b/examples/validator/src/main.rs @@ -60,16 +60,17 @@ async fn handler(ValidatedForm(input): ValidatedForm) -> Html pub struct ValidatedForm(pub T); #[async_trait] -impl FromRequest for ValidatedForm +impl FromRequest for ValidatedForm where T: DeserializeOwned + Validate, + S: Send, B: http_body::Body + Send, B::Data: Send, B::Error: Into, { type Rejection = ServerError; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let Form(value) = Form::::from_request(req).await?; value.validate()?; Ok(ValidatedForm(value)) diff --git a/examples/versioning/src/main.rs b/examples/versioning/src/main.rs index 48ade3c90c..cf8e15f280 100644 --- a/examples/versioning/src/main.rs +++ b/examples/versioning/src/main.rs @@ -48,13 +48,14 @@ enum Version { } #[async_trait] -impl FromRequest for Version +impl FromRequest for Version where B: Send, + S: Send, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let params = Path::>::from_request(req) .await .map_err(IntoResponse::into_response)?; diff --git a/examples/websockets/src/main.rs b/examples/websockets/src/main.rs index bbdcaa0821..a317bfa470 100644 --- a/examples/websockets/src/main.rs +++ b/examples/websockets/src/main.rs @@ -37,7 +37,7 @@ async fn main() { // build our application with some routes let app = Router::new() - .fallback( + .fallback_service( get_service(ServeDir::new(assets_dir).append_index_html_on_directories(true)) .handle_error(|error: std::io::Error| async move { (