From 2eb4c14c5c990e8e33e1a2f468863a034e4affeb Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sat, 9 Jul 2022 23:12:20 +0200 Subject: [PATCH 01/45] begin threading the state through --- Cargo.toml | 4 +- axum-extra/src/extract/cookie/mod.rs | 4 +- axum-extra/src/routing/mod.rs | 80 ++--- axum-extra/src/routing/resource.rs | 39 ++- axum-extra/src/routing/spa.rs | 13 +- axum/src/extract/matched_path.rs | 2 +- axum/src/handler/into_service.rs | 22 +- .../into_service_state_in_extension.rs | 88 ++++++ axum/src/handler/mod.rs | 61 ++-- axum/src/response/mod.rs | 4 +- axum/src/routing/method_routing.rs | 283 +++++++++++------- axum/src/routing/mod.rs | 195 +++++++----- axum/src/routing/tests/fallback.rs | 8 +- axum/src/routing/tests/mod.rs | 23 +- axum/src/routing/tests/nest.rs | 25 +- 15 files changed, 536 insertions(+), 315 deletions(-) create mode 100644 axum/src/handler/into_service_state_in_extension.rs diff --git a/Cargo.toml b/Cargo.toml index a221fbc5a9..a84f6fddbb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,8 +2,8 @@ members = [ "axum", "axum-core", - "axum-extra", - "axum-macros", + # "axum-extra", + # "axum-macros", # internal crate used to bump the minimum versions we # get for some dependencies which otherwise wouldn't build diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index 7bda1ea34e..25d6d1fa3b 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -226,7 +226,7 @@ mod tests { jar.remove(Cookie::named("key")) } - let app = Router::::new() + let app = Router::<_, Body>::new() .route("/set", get(set_cookie)) .route("/get", get(get_cookie)) .route("/remove", get(remove_cookie)) @@ -294,7 +294,7 @@ mod tests { format!("{:?}", jar.get("key")) } - let app = Router::::new() + let app = Router::<_, Body>::new() .route("/get", get(get_cookie)) .layer(Extension(Key::generate())); diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 6298a36e5c..9c4b512a43 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -4,6 +4,7 @@ use axum::{ handler::Handler, http::Request, response::{Redirect, Response}, + routing::MethodRouter, Router, }; use std::{convert::Infallible, future::ready}; @@ -29,7 +30,7 @@ pub use self::typed::{FirstElementIs, TypedPath}; pub use self::spa::SpaRouter; /// Extension trait that adds additional methods to [`Router`]. -pub trait RouterExt: sealed::Sealed { +pub trait RouterExt: sealed::Sealed { /// Add a typed `GET` route to the router. /// /// The path will be inferred from the first argument to the handler function which must @@ -39,7 +40,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_get(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -52,7 +53,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_delete(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -65,7 +66,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_head(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -78,7 +79,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_options(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -91,7 +92,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_patch(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -104,7 +105,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_post(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -117,7 +118,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_put(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -130,7 +131,7 @@ pub trait RouterExt: sealed::Sealed { #[cfg(feature = "typed-routing")] fn typed_trace(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath; @@ -159,21 +160,20 @@ pub trait RouterExt: sealed::Sealed { /// .route_with_tsr("/bar/", get(|| async {})); /// # let _: Router = app; /// ``` - fn route_with_tsr(self, path: &str, service: T) -> Self + fn route_with_tsr(self, path: &str, method_router: MethodRouter) -> Self where - T: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, - T::Future: Send + 'static, Self: Sized; } -impl RouterExt for Router +impl RouterExt for Router where B: axum::body::HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, { #[cfg(feature = "typed-routing")] fn typed_get(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -183,7 +183,7 @@ where #[cfg(feature = "typed-routing")] fn typed_delete(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -193,7 +193,7 @@ where #[cfg(feature = "typed-routing")] fn typed_head(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -203,7 +203,7 @@ where #[cfg(feature = "typed-routing")] fn typed_options(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -213,7 +213,7 @@ where #[cfg(feature = "typed-routing")] fn typed_patch(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -223,7 +223,7 @@ where #[cfg(feature = "typed-routing")] fn typed_post(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -233,7 +233,7 @@ where #[cfg(feature = "typed-routing")] fn typed_put(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { @@ -243,40 +243,40 @@ where #[cfg(feature = "typed-routing")] fn typed_trace(self, handler: H) -> Self where - H: Handler, + H: Handler, T: FirstElementIs

+ 'static, P: TypedPath, { self.route(P::PATH, axum::routing::trace(handler)) } - fn route_with_tsr(mut self, path: &str, service: T) -> Self + fn route_with_tsr(self, path: &str, method_router: MethodRouter) -> Self where - T: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, - T::Future: Send + 'static, Self: Sized, { - self = self.route(path, service); - - let redirect = Redirect::permanent(path); - - if let Some(path_without_trailing_slash) = path.strip_suffix('/') { - self.route( - path_without_trailing_slash, - (move || ready(redirect.clone())).into_service(), - ) - } else { - self.route( - &format!("{}/", path), - (move || ready(redirect.clone())).into_service(), - ) - } + todo!() + + // self = self.route(path, service); + + // let redirect = Redirect::permanent(path); + + // if let Some(path_without_trailing_slash) = path.strip_suffix('/') { + // self.route( + // path_without_trailing_slash, + // (move || ready(redirect.clone())).into_service(), + // ) + // } else { + // self.route( + // &format!("{}/", path), + // (move || ready(redirect.clone())).into_service(), + // ) + // } } } mod sealed { pub trait Sealed {} - impl Sealed for axum::Router {} + impl Sealed for axum::Router {} } #[cfg(test)] diff --git a/axum-extra/src/routing/resource.rs b/axum-extra/src/routing/resource.rs index 98105facd9..b8dc68d47d 100644 --- a/axum-extra/src/routing/resource.rs +++ b/axum-extra/src/routing/resource.rs @@ -3,7 +3,7 @@ use axum::{ handler::Handler, http::Request, response::Response, - routing::{delete, get, on, post, MethodFilter}, + routing::{delete, get, on, post, MethodFilter, MethodRouter}, Router, }; use std::convert::Infallible; @@ -48,29 +48,30 @@ use tower_service::Service; /// # let _: Router = app; /// ``` #[derive(Debug)] -pub struct Resource { +pub struct Resource { pub(crate) name: String, - pub(crate) router: Router, + pub(crate) router: Router, } -impl Resource +impl Resource where B: axum::body::HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, { /// Create a `Resource` with the given name. /// /// All routes will be nested at `/{resource_name}`. - pub fn named(resource_name: &str) -> Self { + pub fn named(state: S, resource_name: &str) -> Self { Self { name: resource_name.to_owned(), - router: Default::default(), + router: Router::with_state(state), } } /// Add a handler at `GET /{resource_name}`. pub fn index(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = self.index_create_path(); @@ -80,7 +81,7 @@ where /// Add a handler at `POST /{resource_name}`. pub fn create(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = self.index_create_path(); @@ -90,7 +91,7 @@ where /// Add a handler at `GET /{resource_name}/new`. pub fn new(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = format!("/{}/new", self.name); @@ -100,7 +101,7 @@ where /// Add a handler at `GET /{resource_name}/:{resource_name}_id`. pub fn show(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = self.show_update_destroy_path(); @@ -110,7 +111,7 @@ where /// Add a handler at `GET /{resource_name}/:{resource_name}_id/edit`. pub fn edit(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = format!("/{0}/:{0}_id/edit", self.name); @@ -120,7 +121,7 @@ where /// Add a handler at `PUT or PATCH /resource_name/:{resource_name}_id`. pub fn update(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = self.show_update_destroy_path(); @@ -130,7 +131,7 @@ where /// Add a handler at `DELETE /{resource_name}/:{resource_name}_id`. pub fn destroy(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, { let path = self.show_update_destroy_path(); @@ -171,12 +172,8 @@ where format!("/{0}/:{0}_id", self.name) } - fn route(mut self, path: &str, svc: T) -> Self - where - T: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, - T::Future: Send + 'static, - { - self.router = self.router.route(path, svc); + fn route(mut self, path: &str, method_router: MethodRouter) -> Self { + self.router = self.router.route(path, method_router); self } } @@ -196,7 +193,7 @@ mod tests { #[tokio::test] async fn works() { - let users = Resource::named("users") + let users = Resource::named((), "users") .index(|| async { "users#index" }) .create(|| async { "users#create" }) .new(|| async { "users#new" }) @@ -265,7 +262,7 @@ mod tests { ); } - async fn call_route(app: &mut Router, method: Method, uri: &str) -> String { + async fn call_route(app: &mut Router<()>, method: Method, uri: &str) -> String { let res = app .ready() .await diff --git a/axum-extra/src/routing/spa.rs b/axum-extra/src/routing/spa.rs index 594b15c237..fac7e21297 100644 --- a/axum-extra/src/routing/spa.rs +++ b/axum-extra/src/routing/spa.rs @@ -147,7 +147,7 @@ impl SpaRouter { } } -impl From> for Router +impl From> for Router<(), B> where F: Clone + Send + 'static, HandleError, F, T>: @@ -158,12 +158,15 @@ where { fn from(spa: SpaRouter) -> Self { let assets_service = get_service(ServeDir::new(&spa.paths.assets_dir)) - .handle_error(spa.handle_error.clone()); + .handle_error(spa.handle_error.clone()) + .with_state(()); Router::new() .nest(&spa.paths.assets_path, assets_service) - .fallback( - get_service(ServeFile::new(&spa.paths.index_file)).handle_error(spa.handle_error), + .fallback_service( + get_service(ServeFile::new(&spa.paths.index_file)) + .handle_error(spa.handle_error) + .with_state(()), ) } } @@ -264,6 +267,6 @@ mod tests { let spa = SpaRouter::new("/assets", "test_files").handle_error(handle_error); - Router::::new().merge(spa); + Router::<_, Body>::new().merge(spa); } } diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index 8965bd30bf..125422f203 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -149,7 +149,7 @@ mod tests { "/public", Router::new().route("/assets/*path", get(handler)), ) - .nest("/foo", handler.into_service()) + .nest("/foo", handler.into_service(())) .layer(tower::layer::layer_fn(SetMatchedPathExtension)); let client = TestClient::new(app); diff --git a/axum/src/handler/into_service.rs b/axum/src/handler/into_service.rs index 34f36b2d21..73775133cc 100644 --- a/axum/src/handler/into_service.rs +++ b/axum/src/handler/into_service.rs @@ -12,28 +12,30 @@ use tower_service::Service; /// An adapter that makes a [`Handler`] into a [`Service`]. /// /// Created with [`Handler::into_service`]. -pub struct IntoService { +pub struct IntoService { handler: H, + state: S, _marker: PhantomData (T, B)>, } #[test] fn traits() { use crate::test_helpers::*; - assert_send::>(); - assert_sync::>(); + assert_send::>(); + assert_sync::>(); } -impl IntoService { - pub(super) fn new(handler: H) -> Self { +impl IntoService { + pub(super) fn new(handler: H, state: S) -> Self { Self { handler, + state, _marker: PhantomData, } } } -impl fmt::Debug for IntoService { +impl fmt::Debug for IntoService { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("IntoService") .field(&format_args!("...")) @@ -41,21 +43,23 @@ impl fmt::Debug for IntoService { } } -impl Clone for IntoService +impl Clone for IntoService where H: Clone, + S: Clone, { fn clone(&self) -> Self { Self { handler: self.handler.clone(), + state: self.state.clone(), _marker: PhantomData, } } } -impl Service> for IntoService +impl Service> for IntoService where - H: Handler + Clone + Send + 'static, + H: Handler + Clone + Send + 'static, B: Send + 'static, { type Response = Response; diff --git a/axum/src/handler/into_service_state_in_extension.rs b/axum/src/handler/into_service_state_in_extension.rs new file mode 100644 index 0000000000..feb9c193a9 --- /dev/null +++ b/axum/src/handler/into_service_state_in_extension.rs @@ -0,0 +1,88 @@ +use super::Handler; +use crate::response::Response; +use http::Request; +use std::{ + convert::Infallible, + fmt, + marker::PhantomData, + task::{Context, Poll}, +}; +use tower_service::Service; + +pub(crate) struct IntoServiceStateInExtension { + handler: H, + _marker: PhantomData (T, S, B)>, +} + +#[test] +fn traits() { + use crate::test_helpers::*; + assert_send::>(); + assert_sync::>(); +} + +impl IntoServiceStateInExtension { + pub(crate) fn new(handler: H) -> Self { + Self { + handler, + _marker: PhantomData, + } + } +} + +impl fmt::Debug for IntoServiceStateInExtension { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("IntoServiceStateInExtension") + .field(&format_args!("...")) + .finish() + } +} + +impl Clone for IntoServiceStateInExtension +where + H: Clone, +{ + fn clone(&self) -> Self { + Self { + handler: self.handler.clone(), + _marker: PhantomData, + } + } +} + +impl Service> for IntoServiceStateInExtension +where + H: Handler + Clone + Send + 'static, + B: Send + 'static, + S: Clone + Send + Sync + 'static, +{ + type Response = Response; + type Error = Infallible; + type Future = super::future::IntoServiceFuture; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + // `IntoServiceStateInExtension` can only be constructed from async functions which are always ready, or + // from `Layered` which bufferes in `::call` and is therefore + // also always ready. + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + use futures_util::future::FutureExt; + + let state = req + .extensions() + .get::() + .expect("state extension missing. This is a bug in axum, please file an issue") + .clone(); + + todo!() + + // let handler = self.handler.clone(); + // let future = Handler::call(handler, req); + // let future = future.map(Ok as _); + + // super::future::IntoServiceFuture::new(future) + } +} diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 69482fecef..8c819744a8 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -50,8 +50,10 @@ use tower_service::Service; pub mod future; mod into_service; +mod into_service_state_in_extension; pub use self::into_service::IntoService; +pub(crate) use self::into_service_state_in_extension::IntoServiceStateInExtension; /// Trait for async functions that can be used to handle requests. /// @@ -61,7 +63,7 @@ pub use self::into_service::IntoService; /// See the [module docs](crate::handler) for more details. /// #[doc = include_str!("../docs/debugging_handler_type_errors.md")] -pub trait Handler: Clone + Send + Sized + 'static { +pub trait Handler: Clone + Send + Sized + 'static { /// The type of future calling this handler returns. type Future: Future + Send + 'static; @@ -104,11 +106,12 @@ pub trait Handler: Clone + Send + Sized + 'static { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - fn layer(self, layer: L) -> Layered + fn layer(self, layer: L) -> Layered where - L: Layer>, + L: Layer>, { - Layered::new(layer.layer(self.into_service())) + todo!() + // Layered::new(layer.layer(self.into_service())) } /// Convert the handler into a [`Service`]. @@ -143,8 +146,8 @@ pub trait Handler: Clone + Send + Sized + 'static { /// ``` /// /// [`Router::fallback`]: crate::routing::Router::fallback - fn into_service(self) -> IntoService { - IntoService::new(self) + fn into_service(self, state: S) -> IntoService { + IntoService::new(self, state) } /// Convert the handler into a [`MakeService`]. @@ -170,8 +173,8 @@ pub trait Handler: Clone + Send + Sized + 'static { /// ``` /// /// [`MakeService`]: tower::make::MakeService - fn into_make_service(self) -> IntoMakeService> { - IntoMakeService::new(self.into_service()) + fn into_make_service(self, state: S) -> IntoMakeService> { + IntoMakeService::new(self.into_service(state)) } /// Convert the handler into a [`MakeService`] which stores information @@ -204,12 +207,13 @@ pub trait Handler: Clone + Send + Sized + 'static { /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info fn into_make_service_with_connect_info( self, - ) -> IntoMakeServiceWithConnectInfo, C> { - IntoMakeServiceWithConnectInfo::new(self.into_service()) + state: S, + ) -> IntoMakeServiceWithConnectInfo, C> { + IntoMakeServiceWithConnectInfo::new(self.into_service(state)) } } -impl Handler<(), B> for F +impl Handler<(), S, B> for F where F: FnOnce() -> Fut + Clone + Send + 'static, Fut: Future + Send, @@ -226,7 +230,7 @@ where macro_rules! impl_handler { ( $($ty:ident),* $(,)? ) => { #[allow(non_snake_case)] - impl Handler<($($ty,)*), B> for F + impl Handler<($($ty,)*), S, B> for F where F: FnOnce($($ty,)*) -> Fut + Clone + Send + 'static, Fut: Future + Send, @@ -261,45 +265,46 @@ all_the_tuples!(impl_handler); /// A [`Service`] created from a [`Handler`] by applying a Tower middleware. /// /// Created with [`Handler::layer`]. See that method for more details. -pub struct Layered { - svc: S, - _input: PhantomData T>, +pub struct Layered { + svc: Svc, + _input: PhantomData (T, S)>, } -impl fmt::Debug for Layered +impl fmt::Debug for Layered where - S: fmt::Debug, + Svc: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Layered").field("svc", &self.svc).finish() } } -impl Clone for Layered +impl Clone for Layered where - S: Clone, + Svc: Clone, { fn clone(&self) -> Self { Self::new(self.svc.clone()) } } -impl Handler for Layered +impl Handler for Layered where - S: Service, Response = Response> + Clone + Send + 'static, - S::Error: IntoResponse, - S::Future: Send, + Svc: Service, Response = Response> + Clone + Send + 'static, + Svc::Error: IntoResponse, + Svc::Future: Send, T: 'static, + S: 'static, ReqBody: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { - type Future = future::LayeredFuture; + type Future = future::LayeredFuture; fn call(self, req: Request) -> Self::Future { use futures_util::future::{FutureExt, Map}; - let future: Map<_, fn(Result) -> _> = + let future: Map<_, fn(Result) -> _> = self.svc.oneshot(req).map(|result| match result { Ok(res) => res.map(boxed), Err(res) => res.into_response(), @@ -309,8 +314,8 @@ where } } -impl Layered { - pub(crate) fn new(svc: S) -> Self { +impl Layered { + pub(crate) fn new(svc: Svc) -> Self { Self { svc, _input: PhantomData, @@ -330,7 +335,7 @@ mod tests { format!("you said: {}", body) } - let client = TestClient::new(handle.into_service()); + let client = TestClient::new(handle.into_service(())); let res = client.post("/").body("hi there!").send().await; assert_eq!(res.status(), StatusCode::OK); diff --git a/axum/src/response/mod.rs b/axum/src/response/mod.rs index 695b931c22..5cf19c04ee 100644 --- a/axum/src/response/mod.rs +++ b/axum/src/response/mod.rs @@ -93,7 +93,7 @@ mod tests { } } - Router::::new() + Router::<_, Body>::new() .route("/", get(impl_trait_ok)) .route("/", get(impl_trait_err)) .route("/", get(impl_trait_both)) @@ -203,7 +203,7 @@ mod tests { ) } - Router::::new() + Router::<_, Body>::new() .route("/", get(status)) .route("/", get(status_headermap)) .route("/", get(status_header_array)) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 4edb8bdb66..1353df1cd9 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -3,7 +3,7 @@ use crate::{ body::{boxed, Body, Bytes, Empty, HttpBody}, error_handling::{HandleError, HandleErrorLayer}, extract::connect_info::IntoMakeServiceWithConnectInfo, - handler::Handler, + handler::{Handler, IntoServiceStateInExtension}, http::{Method, Request, StatusCode}, response::Response, routing::{future::RouteFuture, Fallback, MethodFilter, Route}, @@ -76,10 +76,10 @@ macro_rules! top_level_service_fn { $name:ident, $method:ident ) => { $(#[$m])+ - pub fn $name(svc: S) -> MethodRouter + pub fn $name(svc: T) -> MethodRouter where - S: Service, Response = Response> + Clone + Send + 'static, - S::Future: Send + 'static, + T: Service, Response = Response> + Clone + Send + 'static, + T::Future: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { @@ -137,11 +137,12 @@ macro_rules! top_level_handler_fn { $name:ident, $method:ident ) => { $(#[$m])+ - pub fn $name(handler: H) -> MethodRouter + pub fn $name(handler: H) -> MethodRouter where - H: Handler, + H: Handler, B: Send + 'static, T: 'static, + S: Clone + Send + Sync + 'static, { on(MethodFilter::$method, handler) } @@ -208,13 +209,13 @@ macro_rules! chained_service_fn { $name:ident, $method:ident ) => { $(#[$m])+ - pub fn $name(self, svc: S) -> Self + pub fn $name(self, svc: T) -> Self where - S: Service, Response = Response, Error = E> + T: Service, Response = Response, Error = E> + Clone + Send + 'static, - S::Future: Send + 'static, + T::Future: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { @@ -274,8 +275,9 @@ macro_rules! chained_handler_fn { $(#[$m])+ pub fn $name(self, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, + S: Clone + Send + Sync + 'static, { self.on(MethodFilter::$method, handler) } @@ -316,13 +318,13 @@ top_level_service_fn!(trace_service, TRACE); /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn on_service( +pub fn on_service( filter: MethodFilter, - svc: S, -) -> MethodRouter + svc: T, +) -> MethodRouter where - S: Service, Response = Response> + Clone + Send + 'static, - S::Future: Send + 'static, + T: Service, Response = Response> + Clone + Send + 'static, + T::Future: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { @@ -382,10 +384,10 @@ where /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn any_service(svc: S) -> MethodRouter +pub fn any_service(svc: T) -> MethodRouter where - S: Service, Response = Response> + Clone + Send + 'static, - S::Future: Send + 'static, + T: Service, Response = Response> + Clone + Send + 'static, + T::Future: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { @@ -420,11 +422,12 @@ top_level_handler_fn!(trace, TRACE); /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn on(filter: MethodFilter, handler: H) -> MethodRouter +pub fn on(filter: MethodFilter, handler: H) -> MethodRouter where - H: Handler, + H: Handler, B: Send + 'static, T: 'static, + S: Clone + Send + Sync + 'static, { MethodRouter::new().on(filter, handler) } @@ -466,20 +469,21 @@ where /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn any(handler: H) -> MethodRouter +pub fn any(handler: H) -> MethodRouter where - H: Handler, + H: Handler, B: Send + 'static, T: 'static, + S: Clone + Send + Sync + 'static, { MethodRouter::new() - .fallback_boxed_response_body(handler.into_service()) + .fallback_boxed_response_body(IntoServiceStateInExtension::new(handler)) .skip_allow_header() } /// A [`Service`] that accepts requests based on a [`MethodFilter`] and /// allows chaining additional handlers and services. -pub struct MethodRouter { +pub struct MethodRouter { get: Option>, head: Option>, delete: Option>, @@ -490,7 +494,7 @@ pub struct MethodRouter { trace: Option>, fallback: Fallback, allow_header: AllowHeader, - _request_body: PhantomData (B, E)>, + _request_body: PhantomData (B, S, E)>, } #[derive(Clone)] @@ -503,7 +507,7 @@ enum AllowHeader { Bytes(BytesMut), } -impl fmt::Debug for MethodRouter { +impl fmt::Debug for MethodRouter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MethodRouter") .field("get", &self.get) @@ -519,7 +523,7 @@ impl fmt::Debug for MethodRouter { } } -impl MethodRouter { +impl MethodRouter { /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all /// requests. pub fn new() -> Self { @@ -543,9 +547,32 @@ impl MethodRouter { _request_body: PhantomData, } } + + pub fn with_state(self, state: S) -> MethodRouterWithState { + MethodRouterWithState { + method_router: self, + state, + } + } + + pub(crate) fn downcast_state(self) -> MethodRouter { + MethodRouter { + get: self.get, + head: self.head, + delete: self.delete, + options: self.options, + patch: self.patch, + post: self.post, + put: self.put, + trace: self.trace, + fallback: self.fallback, + allow_header: self.allow_header, + _request_body: PhantomData, + } + } } -impl MethodRouter +impl MethodRouter where B: Send + 'static, { @@ -574,10 +601,11 @@ where /// ``` pub fn on(self, filter: MethodFilter, handler: H) -> Self where - H: Handler, + H: Handler, T: 'static, + S: Clone + Send + Sync + 'static, { - self.on_service_boxed_response_body(filter, handler.into_service()) + self.on_service_boxed_response_body(filter, IntoServiceStateInExtension::new(handler)) } chained_handler_fn!(delete, DELETE); @@ -658,7 +686,7 @@ where } } -impl MethodRouter { +impl MethodRouter { /// Chain an additional service that will accept requests matching the given /// `MethodFilter`. /// @@ -684,13 +712,13 @@ impl MethodRouter { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - pub fn on_service(self, filter: MethodFilter, svc: S) -> Self + pub fn on_service(self, filter: MethodFilter, svc: T) -> Self where - S: Service, Response = Response, Error = E> + T: Service, Response = Response, Error = E> + Clone + Send + 'static, - S::Future: Send + 'static, + T::Future: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { @@ -707,13 +735,13 @@ impl MethodRouter { chained_service_fn!(trace_service, TRACE); #[doc = include_str!("../docs/method_routing/fallback.md")] - pub fn fallback(mut self, svc: S) -> Self + pub fn fallback(mut self, svc: T) -> Self where - S: Service, Response = Response, Error = E> + T: Service, Response = Response, Error = E> + Clone + Send + 'static, - S::Future: Send + 'static, + T::Future: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { @@ -721,10 +749,10 @@ impl MethodRouter { self } - fn fallback_boxed_response_body(mut self, svc: S) -> Self + fn fallback_boxed_response_body(mut self, svc: T) -> Self where - S: Service, Response = Response, Error = E> + Clone + Send + 'static, - S::Future: Send + 'static, + T: Service, Response = Response, Error = E> + Clone + Send + 'static, + T::Future: Send + 'static, { self.fallback = Fallback::Custom(Route::new(svc)); self @@ -734,7 +762,7 @@ impl MethodRouter { pub fn layer( self, layer: L, - ) -> MethodRouter + ) -> MethodRouter where L: Layer>, L::Service: Service, Response = Response, Error = NewError> @@ -768,7 +796,7 @@ impl MethodRouter { } #[doc = include_str!("../docs/method_routing/route_layer.md")] - pub fn route_layer(self, layer: L) -> MethodRouter + pub fn route_layer(self, layer: L) -> MethodRouter where L: Layer>, L::Service: Service, Response = Response, Error = E> @@ -802,7 +830,7 @@ impl MethodRouter { } #[doc = include_str!("../docs/method_routing/merge.md")] - pub fn merge(self, other: MethodRouter) -> Self { + pub fn merge(self, other: MethodRouter) -> Self { macro_rules! merge { ( $first:ident, $second:ident ) => { match ($first, $second) { @@ -894,7 +922,7 @@ impl MethodRouter { /// Apply a [`HandleErrorLayer`]. /// /// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`. - pub fn handle_error(self, f: F) -> MethodRouter + pub fn handle_error(self, f: F) -> MethodRouter where F: Clone + Send + 'static, HandleError, F, T>: @@ -907,10 +935,10 @@ impl MethodRouter { self.layer(HandleErrorLayer::new(f)) } - fn on_service_boxed_response_body(self, filter: MethodFilter, svc: S) -> Self + fn on_service_boxed_response_body(self, filter: MethodFilter, svc: T) -> Self where - S: Service, Response = Response, Error = E> + Clone + Send + 'static, - S::Future: Send + 'static, + T: Service, Response = Response, Error = E> + Clone + Send + 'static, + T::Future: Send + 'static, { macro_rules! set_service { ( @@ -1009,7 +1037,7 @@ fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) { } } -impl Clone for MethodRouter { +impl Clone for MethodRouter { fn clone(&self) -> Self { Self { get: self.get.clone(), @@ -1027,7 +1055,7 @@ impl Clone for MethodRouter { } } -impl Default for MethodRouter +impl Default for MethodRouter where B: Send + 'static, { @@ -1036,7 +1064,36 @@ where } } -impl Service> for MethodRouter +pub struct MethodRouterWithState { + method_router: MethodRouter, + state: S, +} + +impl Clone for MethodRouterWithState +where + S: Clone, +{ + fn clone(&self) -> Self { + Self { + method_router: self.method_router.clone(), + state: self.state.clone(), + } + } +} + +impl fmt::Debug for MethodRouterWithState +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MethodRouterWithState") + .field("method_router", &self.method_router) + .field("state", &self.state) + .finish() + } +} + +impl Service> for MethodRouterWithState where B: HttpBody, { @@ -1068,8 +1125,11 @@ where let method = req.method().clone(); + // set state in request extensions + todo!(); + // written with a pattern match like this to ensure we call all routes - let Self { + let MethodRouter { get, head, delete, @@ -1081,7 +1141,7 @@ where fallback, allow_header, _request_body: _, - } = self; + } = self.method_router; call!(req, method, HEAD, head); call!(req, method, HEAD, get); @@ -1120,7 +1180,7 @@ mod tests { #[tokio::test] async fn method_not_allowed_by_default() { - let mut svc = MethodRouter::new(); + let mut svc = MethodRouter::new().with_state(()); let (status, _, body) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert!(body.is_empty()); @@ -1128,7 +1188,7 @@ mod tests { #[tokio::test] async fn get_handler() { - let mut svc = MethodRouter::new().get(ok); + let mut svc = MethodRouter::new().get(ok).with_state(()); let (status, _, body) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); assert_eq!(body, "ok"); @@ -1136,7 +1196,7 @@ mod tests { #[tokio::test] async fn get_accepts_head() { - let mut svc = MethodRouter::new().get(ok); + let mut svc = MethodRouter::new().get(ok).with_state(()); let (status, _, body) = call(Method::HEAD, &mut svc).await; assert_eq!(status, StatusCode::OK); assert!(body.is_empty()); @@ -1144,7 +1204,7 @@ mod tests { #[tokio::test] async fn head_takes_precedence_over_get() { - let mut svc = MethodRouter::new().head(created).get(ok); + let mut svc = MethodRouter::new().head(created).get(ok).with_state(()); let (status, _, body) = call(Method::HEAD, &mut svc).await; assert_eq!(status, StatusCode::CREATED); assert!(body.is_empty()); @@ -1152,7 +1212,7 @@ mod tests { #[tokio::test] async fn merge() { - let mut svc = get(ok).merge(post(ok)); + let mut svc = get(ok).merge(post(ok)).with_state(()); let (status, _, _) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); @@ -1165,7 +1225,8 @@ mod tests { async fn layer() { let mut svc = MethodRouter::new() .get(|| async { std::future::pending::<()>().await }) - .layer(RequireAuthorizationLayer::bearer("password")); + .layer(RequireAuthorizationLayer::bearer("password")) + .with_state(()); // method with route let (status, _, _) = call(Method::GET, &mut svc).await; @@ -1180,7 +1241,8 @@ mod tests { async fn route_layer() { let mut svc = MethodRouter::new() .get(|| async { std::future::pending::<()>().await }) - .route_layer(RequireAuthorizationLayer::bearer("password")); + .route_layer(RequireAuthorizationLayer::bearer("password")) + .with_state(()); // method with route let (status, _, _) = call(Method::GET, &mut svc).await; @@ -1203,7 +1265,8 @@ mod tests { delete_service(ServeDir::new(".")) .handle_error(|_| async { StatusCode::NOT_FOUND }), ) - .fallback((|| async { StatusCode::NOT_FOUND }).into_service()) + // TODO(david): add `fallback` and `fallback_service` + // .fallback((|| async { StatusCode::NOT_FOUND }).into_service()) .put(ok) .layer( ServiceBuilder::new() @@ -1219,7 +1282,7 @@ mod tests { #[tokio::test] async fn sets_allow_header() { - let mut svc = MethodRouter::new().put(ok).patch(ok); + let mut svc = MethodRouter::new().put(ok).patch(ok).with_state(()); let (status, headers, _) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert_eq!(headers[ALLOW], "PUT,PATCH"); @@ -1227,7 +1290,7 @@ mod tests { #[tokio::test] async fn sets_allow_header_get_head() { - let mut svc = MethodRouter::new().get(ok).head(ok); + let mut svc = MethodRouter::new().get(ok).head(ok).with_state(()); let (status, headers, _) = call(Method::PUT, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert_eq!(headers[ALLOW], "GET,HEAD"); @@ -1235,7 +1298,7 @@ mod tests { #[tokio::test] async fn empty_allow_header_by_default() { - let mut svc = MethodRouter::new(); + let mut svc = MethodRouter::new().with_state(()); let (status, headers, _) = call(Method::PATCH, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert_eq!(headers[ALLOW], ""); @@ -1245,7 +1308,7 @@ mod tests { async fn allow_header_when_merging() { let a = put(ok).patch(ok); let b = get(ok).head(ok); - let mut svc = a.merge(b); + let mut svc = a.merge(b).with_state(()); let (status, headers, _) = call(Method::DELETE, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); @@ -1254,60 +1317,61 @@ mod tests { #[tokio::test] async fn allow_header_any() { - let mut svc = any(ok); + let mut svc = any(ok).with_state(()); let (status, headers, _) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); assert!(!headers.contains_key(ALLOW)); } - #[tokio::test] - async fn allow_header_with_fallback() { - let mut svc = MethodRouter::new().get(ok).fallback( - (|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") }).into_service(), - ); - - let (status, headers, _) = call(Method::DELETE, &mut svc).await; - assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); - assert_eq!(headers[ALLOW], "GET,HEAD"); - } - - #[tokio::test] - async fn allow_header_with_fallback_that_sets_allow() { - async fn fallback(method: Method) -> Response { - if method == Method::POST { - "OK".into_response() - } else { - ( - StatusCode::METHOD_NOT_ALLOWED, - [(ALLOW, "GET,POST")], - "Method not allowed", - ) - .into_response() - } - } - - let mut svc = MethodRouter::new() - .get(ok) - .fallback(fallback.into_service()); - - let (status, _, _) = call(Method::GET, &mut svc).await; - assert_eq!(status, StatusCode::OK); - - let (status, _, _) = call(Method::POST, &mut svc).await; - assert_eq!(status, StatusCode::OK); - - let (status, headers, _) = call(Method::DELETE, &mut svc).await; - assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); - assert_eq!(headers[ALLOW], "GET,POST"); - } + // TODO(david): add `fallback` and `fallback_service` + // #[tokio::test] + // async fn allow_header_with_fallback() { + // let mut svc = MethodRouter::new().get(ok).fallback( + // (|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") }).into_service(), + // ); + + // let (status, headers, _) = call(Method::DELETE, &mut svc).await; + // assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); + // assert_eq!(headers[ALLOW], "GET,HEAD"); + // } + + // #[tokio::test] + // async fn allow_header_with_fallback_that_sets_allow() { + // async fn fallback(method: Method) -> Response { + // if method == Method::POST { + // "OK".into_response() + // } else { + // ( + // StatusCode::METHOD_NOT_ALLOWED, + // [(ALLOW, "GET,POST")], + // "Method not allowed", + // ) + // .into_response() + // } + // } + + // let mut svc = MethodRouter::new() + // .get(ok) + // .fallback(fallback.into_service()); + + // let (status, _, _) = call(Method::GET, &mut svc).await; + // assert_eq!(status, StatusCode::OK); + + // let (status, _, _) = call(Method::POST, &mut svc).await; + // assert_eq!(status, StatusCode::OK); + + // let (status, headers, _) = call(Method::DELETE, &mut svc).await; + // assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); + // assert_eq!(headers[ALLOW], "GET,POST"); + // } #[tokio::test] #[should_panic( expected = "Overlapping method route. Cannot add two method routes that both handle `GET`" )] async fn handler_overlaps() { - let _: MethodRouter = get(ok).get(ok); + let _: MethodRouter<()> = get(ok).get(ok); } #[tokio::test] @@ -1315,17 +1379,18 @@ mod tests { expected = "Overlapping method route. Cannot add two method routes that both handle `POST`" )] async fn service_overlaps() { - let _: MethodRouter = post_service(ok.into_service()).post_service(ok.into_service()); + let _: MethodRouter<()> = post_service(IntoServiceStateInExtension::<_, _, (), _>::new(ok)) + .post_service(IntoServiceStateInExtension::<_, _, (), _>::new(ok)); } #[tokio::test] async fn get_head_does_not_overlap() { - let _: MethodRouter = get(ok).head(ok); + let _: MethodRouter<()> = get(ok).head(ok); } #[tokio::test] async fn head_get_does_not_overlap() { - let _: MethodRouter = head(ok).get(ok); + let _: MethodRouter<()> = head(ok).get(ok); } async fn call(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String) diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 4cf7abbeb0..9f61fa0c22 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -4,6 +4,7 @@ use self::{future::RouteFuture, not_found::NotFound}; use crate::{ body::{boxed, Body, Bytes, HttpBody}, extract::connect_info::IntoMakeServiceWithConnectInfo, + handler::Handler, response::Response, routing::strip_prefix::StripPrefix, util::try_downcast, @@ -42,7 +43,7 @@ pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, pub use self::method_routing::{ any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service, options, options_service, patch, patch_service, post, post_service, put, put_service, trace, - trace_service, MethodRouter, + trace_service, MethodRouter, MethodRouterWithState, }; #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -62,16 +63,21 @@ impl RouteId { } /// The router type for composing handlers and services. -pub struct Router { - routes: HashMap>, +pub struct Router { + state: S, + routes: HashMap>, node: Arc, fallback: Fallback, nested_at_root: bool, } -impl Clone for Router { +impl Clone for Router +where + S: Clone, +{ fn clone(&self) -> Self { Self { + state: self.state.clone(), routes: self.routes.clone(), node: Arc::clone(&self.node), fallback: self.fallback.clone(), @@ -80,18 +86,23 @@ impl Clone for Router { } } -impl Default for Router +impl Default for Router where B: HttpBody + Send + 'static, + S: Default + Clone + Send + Sync + 'static, { fn default() -> Self { - Self::new() + Self::with_state(S::default()) } } -impl fmt::Debug for Router { +impl fmt::Debug for Router +where + S: fmt::Debug, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Router") + .field("state", &self.state) .field("routes", &self.routes) .field("node", &self.node) .field("fallback", &self.fallback) @@ -103,7 +114,7 @@ impl fmt::Debug for Router { pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param"; const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param"; -impl Router +impl Router<(), B> where B: HttpBody + Send + 'static, { @@ -112,7 +123,18 @@ where /// Unless you add additional routes this will respond with `404 Not Found` to /// all requests. pub fn new() -> Self { + Self::with_state(()) + } +} + +impl Router +where + B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, +{ + pub fn with_state(state: S) -> Self { Self { + state, routes: Default::default(), node: Default::default(), fallback: Fallback::Default(Route::new(NotFound)), @@ -120,58 +142,65 @@ where } } + pub fn route(mut self, path: &str, method_router: MethodRouter) -> Self { + // self.route_service(path, method_router.with_state(self.state.clone())) + todo!() + } + #[doc = include_str!("../docs/routing/route.md")] - pub fn route(mut self, path: &str, service: T) -> Self + pub fn route_service(mut self, path: &str, service: T) -> Self where T: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, T::Future: Send + 'static, { - if path.is_empty() { - panic!("Paths must start with a `/`. Use \"/\" for root routes"); - } else if !path.starts_with('/') { - panic!("Paths must start with a `/`"); - } - - let service = match try_downcast::, _>(service) { - Ok(_) => { - panic!("Invalid route: `Router::route` cannot be used with `Router`s. Use `Router::nest` instead") - } - Err(svc) => svc, - }; - - let id = RouteId::next(); - - let service = match try_downcast::, _>(service) { - Ok(method_router) => { - if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self - .node - .path_to_route_id - .get(path) - .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc))) - { - // if we're adding a new `MethodRouter` to a route that already has one just - // merge them. This makes `.route("/", get(_)).route("/", post(_))` work - let service = - Endpoint::MethodRouter(prev_method_router.clone().merge(method_router)); - self.routes.insert(route_id, service); - return self; - } else { - Endpoint::MethodRouter(method_router) - } - } - Err(service) => Endpoint::Route(Route::new(service)), - }; - - let mut node = - Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone()); - if let Err(err) = node.insert(path, id) { - self.panic_on_matchit_error(err); - } - self.node = Arc::new(node); - - self.routes.insert(id, service); - - self + todo!() + + // if path.is_empty() { + // panic!("Paths must start with a `/`. Use \"/\" for root routes"); + // } else if !path.starts_with('/') { + // panic!("Paths must start with a `/`"); + // } + + // let service = match try_downcast::, _>(service) { + // Ok(_) => { + // panic!("Invalid route: `Router::route` cannot be used with `Router`s. Use `Router::nest` instead") + // } + // Err(svc) => svc, + // }; + + // let id = RouteId::next(); + + // let service = match try_downcast::, _>(service) { + // Ok(method_router) => { + // if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self + // .node + // .path_to_route_id + // .get(path) + // .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc))) + // { + // // if we're adding a new `MethodRouter` to a route that already has one just + // // merge them. This makes `.route("/", get(_)).route("/", post(_))` work + // let service = + // Endpoint::MethodRouter(prev_method_router.clone().merge(method_router)); + // self.routes.insert(route_id, service); + // return self; + // } else { + // Endpoint::MethodRouter(method_router) + // } + // } + // Err(service) => Endpoint::Route(Route::new(service)), + // }; + + // let mut node = + // Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone()); + // if let Err(err) = node.insert(path, id) { + // self.panic_on_matchit_error(err); + // } + // self.node = Arc::new(node); + + // self.routes.insert(id, service); + + // self } #[doc = include_str!("../docs/routing/nest.md")] @@ -195,12 +224,13 @@ where self.nested_at_root = true; } - match try_downcast::, _>(svc) { + match try_downcast::, _>(svc) { // if the user is nesting a `Router` we can implement nesting // by simplying copying all the routes and adding the prefix in // front Ok(router) => { let Router { + state, mut routes, node, fallback, @@ -231,7 +261,7 @@ where method_router.layer(layer_fn(|s| StripPrefix::new(s, prefix))), ), Endpoint::Route(route) => { - self.route(&full_path, StripPrefix::new(route, prefix)) + self.route_service(&full_path, StripPrefix::new(route, prefix)) } }; } @@ -246,7 +276,7 @@ where format!("{}/*{}", path, NEST_TAIL_PARAM) }; - self = self.route(&path, strip_prefix::StripPrefix::new(svc, prefix)); + self = self.route_service(&path, strip_prefix::StripPrefix::new(svc, prefix)); } } @@ -254,11 +284,14 @@ where } #[doc = include_str!("../docs/routing/merge.md")] - pub fn merge(mut self, other: R) -> Self + pub fn merge(mut self, other: R) -> Self where - R: Into>, + // TODO(david): can we use a different state type here? Since the state cannot be changed + // and has already been provided? + R: Into>, { let Router { + state: _, routes, node, fallback, @@ -271,8 +304,10 @@ where .get(&id) .expect("no path for route id. This is a bug in axum. Please file an issue"); self = match route { - Endpoint::MethodRouter(route) => self.route(path, route), - Endpoint::Route(route) => self.route(path, route), + Endpoint::MethodRouter(method_router) => { + self.route(path, method_router.downcast_state()) + } + Endpoint::Route(route) => self.route_service(path, route), }; } @@ -291,7 +326,7 @@ where } #[doc = include_str!("../docs/routing/layer.md")] - pub fn layer(self, layer: L) -> Router + pub fn layer(self, layer: L) -> Router where L: Layer>, L::Service: @@ -324,6 +359,7 @@ where let fallback = self.fallback.map(|svc| Route::new(layer.layer(svc))); Router { + state: self.state, routes, node: self.node, fallback, @@ -362,6 +398,7 @@ where .collect(); Router { + state: self.state, routes, node: self.node, fallback: self.fallback, @@ -370,7 +407,16 @@ where } #[doc = include_str!("../docs/routing/fallback.md")] - pub fn fallback(mut self, svc: T) -> Self + pub fn fallback(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + let state = self.state.clone(); + self.fallback_service(handler.into_service(state)) + } + + pub fn fallback_service(mut self, svc: T) -> Self where T: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, T::Future: Send + 'static, @@ -453,7 +499,7 @@ where .clone(); match &mut route { - Endpoint::MethodRouter(inner) => inner.call(req), + Endpoint::MethodRouter(inner) => inner.clone().with_state(self.state.clone()).call(req), Endpoint::Route(inner) => inner.call(req), } } @@ -470,9 +516,10 @@ where } } -impl Service> for Router +impl Service> for Router where B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, { type Response = Response; type Error = Infallible; @@ -587,12 +634,15 @@ impl Fallback { } } -enum Endpoint { - MethodRouter(MethodRouter), +enum Endpoint { + MethodRouter(MethodRouter), Route(Route), } -impl Clone for Endpoint { +impl Clone for Endpoint +where + S: Clone, +{ fn clone(&self) -> Self { match self { Endpoint::MethodRouter(inner) => Endpoint::MethodRouter(inner.clone()), @@ -601,7 +651,10 @@ impl Clone for Endpoint { } } -impl fmt::Debug for Endpoint { +impl fmt::Debug for Endpoint +where + S: fmt::Debug, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::MethodRouter(inner) => inner.fmt(f), @@ -614,5 +667,5 @@ impl fmt::Debug for Endpoint { #[allow(warnings)] fn traits() { use crate::test_helpers::*; - assert_send::>(); + assert_send::>(); } diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index 498740772c..3433264a1c 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -5,7 +5,7 @@ use crate::handler::Handler; async fn basic() { let app = Router::new() .route("/foo", get(|| async {})) - .fallback((|| async { "fallback" }).into_service()); + .fallback(|| async { "fallback" }); let client = TestClient::new(app); @@ -20,7 +20,7 @@ async fn basic() { async fn nest() { let app = Router::new() .nest("/foo", Router::new().route("/bar", get(|| async {}))) - .fallback((|| async { "fallback" }).into_service()); + .fallback(|| async { "fallback" }); let client = TestClient::new(app); @@ -36,9 +36,7 @@ async fn or() { let one = Router::new().route("/one", get(|| async {})); let two = Router::new().route("/two", get(|| async {})); - let app = one - .merge(two) - .fallback((|| async { "fallback" }).into_service()); + let app = one.merge(two).fallback(|| async { "fallback" }); let client = TestClient::new(app); diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 5976f8a0fc..204bd4bd61 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -100,7 +100,7 @@ async fn routing() { #[tokio::test] async fn router_type_doesnt_change() { - let app: Router = Router::new() + let app: Router<()> = Router::new() .route( "/", on(MethodFilter::GET, |_: Request| async { @@ -148,7 +148,10 @@ async fn routing_between_services() { }), ), ) - .route("/two", on_service(MethodFilter::GET, handle.into_service())); + .route( + "/two", + on_service(MethodFilter::GET, handle.into_service(())), + ); let client = TestClient::new(app); @@ -365,7 +368,7 @@ async fn wildcard_with_trailing_slash() { path: String, } - let app: Router = Router::new().route( + let app: Router<()> = Router::new().route( "/:user/:repo/tree/*path", get(|Path(tree): Path| async move { Json(tree) }), ); @@ -467,7 +470,7 @@ async fn middleware_still_run_for_unmatched_requests() { expected = "Invalid route: `Router::route` cannot be used with `Router`s. Use `Router::nest` instead" )] async fn routing_to_router_panics() { - TestClient::new(Router::new().route("/", Router::new())); + TestClient::new(Router::new().route_service("/", Router::new())); } #[tokio::test] @@ -507,7 +510,7 @@ async fn route_layer() { )] async fn good_error_message_if_using_nest_root() { let app = Router::new() - .nest("/", get(|| async {})) + .nest("/", get(|| async {}).with_state(())) .route("/", get(|| async {})); TestClient::new(app); } @@ -520,7 +523,7 @@ async fn good_error_message_if_using_nest_root() { Use `Router::fallback` instead" )] async fn good_error_message_if_using_nest_root_when_merging() { - let one = Router::new().nest("/", get(|| async {})); + let one = Router::new().nest("/", get(|| async {}).with_state(())); let two = Router::new().route("/", get(|| async {})); let app = one.merge(two); TestClient::new(app); @@ -570,8 +573,8 @@ async fn different_methods_added_in_different_routes_deeply_nested() { #[should_panic(expected = "Cannot merge two `Router`s that both have a fallback")] async fn merging_routers_with_fallbacks_panics() { async fn fallback() {} - let one = Router::new().fallback(fallback.into_service()); - let two = Router::new().fallback(fallback.into_service()); + let one = Router::new().fallback(fallback); + let two = Router::new().fallback(fallback); TestClient::new(one.merge(two)); } @@ -579,7 +582,7 @@ async fn merging_routers_with_fallbacks_panics() { #[should_panic(expected = "Cannot nest `Router`s that has a fallback")] async fn nesting_router_with_fallbacks_panics() { async fn fallback() {} - let one = Router::new().fallback(fallback.into_service()); + let one = Router::new().fallback(fallback); let app = Router::new().nest("/", one); TestClient::new(app); } @@ -619,7 +622,7 @@ async fn head_content_length_through_hyper_server() { #[tokio::test] async fn head_content_length_through_hyper_server_that_hits_fallback() { - let app = Router::new().fallback((|| async { "foo" }).into_service()); + let app = Router::new().fallback(|| async { "foo" }); let client = TestClient::new(app); diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index d4cb466c56..76e2b9a70e 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -117,7 +117,10 @@ async fn nesting_router_at_empty_path() { #[tokio::test] async fn nesting_handler_at_root() { - let app = Router::new().nest("/", get(|uri: Uri| async move { uri.to_string() })); + let app = Router::new().nest( + "/", + get(|uri: Uri| async move { uri.to_string() }).with_state(()), + ); let client = TestClient::new(app); @@ -186,7 +189,7 @@ async fn nested_service_sees_stripped_uri() { "/foo", Router::new().nest( "/bar", - Router::new().route( + Router::new().route_service( "/baz", service_fn(|req: Request| async move { let body = boxed(Body::from(req.uri().to_string())); @@ -207,12 +210,14 @@ async fn nested_service_sees_stripped_uri() { async fn nest_static_file_server() { let app = Router::new().nest( "/static", - get_service(ServeDir::new(".")).handle_error(|error| async move { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Unhandled internal error: {}", error), - ) - }), + get_service(ServeDir::new(".")) + .handle_error(|error| async move { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Unhandled internal error: {}", error), + ) + }) + .with_state(()), ); let client = TestClient::new(app); @@ -294,7 +299,7 @@ async fn multiple_top_level_nests() { #[tokio::test] #[should_panic(expected = "Invalid route: nested routes cannot contain wildcards (*)")] async fn nest_cannot_contain_wildcards() { - Router::::new().nest("/one/*rest", Router::new()); + Router::<_, Body>::new().nest("/one/*rest", Router::new()); } #[tokio::test] @@ -333,7 +338,7 @@ async fn outer_middleware_still_see_whole_url() { .route("/foo", get(handler)) .route("/foo/bar", get(handler)) .nest("/one", Router::new().route("/two", get(handler))) - .fallback(handler.into_service()) + .fallback(handler) .layer(tower::layer::layer_fn(SetUriExtension)); let client = TestClient::new(app); From d9cc822a29f0edfa8f5c2ff654359b9afea20d9f Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sat, 9 Jul 2022 23:51:32 +0200 Subject: [PATCH 02/45] Pass state to extractors --- axum-core/src/extract/mod.rs | 36 +++++++++++++------- axum-core/src/extract/request_parts.rs | 43 ++++++++++++++---------- axum-core/src/extract/tuple.rs | 12 ++++--- axum-extra/src/extract/cached.rs | 14 ++++---- axum-extra/src/extract/cookie/mod.rs | 9 ++--- axum-extra/src/extract/cookie/private.rs | 7 ++-- axum-extra/src/extract/cookie/signed.rs | 7 ++-- axum-extra/src/extract/form.rs | 7 ++-- axum-extra/src/extract/query.rs | 5 +-- axum-extra/src/json_lines.rs | 5 +-- axum/src/error_handling/mod.rs | 4 +-- axum/src/extension.rs | 5 +-- axum/src/extract/connect_info.rs | 7 ++-- axum/src/extract/content_length_limit.rs | 7 ++-- axum/src/extract/host.rs | 5 +-- axum/src/extract/matched_path.rs | 5 +-- axum/src/extract/mod.rs | 8 +++-- axum/src/extract/multipart.rs | 5 +-- axum/src/extract/path/mod.rs | 5 +-- axum/src/extract/query.rs | 8 +++-- axum/src/extract/raw_query.rs | 5 +-- axum/src/extract/request_parts.rs | 15 +++++---- axum/src/extract/state.rs | 21 ++++++++++++ axum/src/extract/ws.rs | 9 ++--- axum/src/form.rs | 8 +++-- axum/src/handler/into_service.rs | 3 +- axum/src/handler/mod.rs | 29 +++++++++------- axum/src/json.rs | 7 ++-- axum/src/middleware/from_extractor.rs | 19 ++++++----- axum/src/middleware/from_fn.rs | 4 +-- axum/src/routing/tests/mod.rs | 26 +++++++++++++- axum/src/typed_header.rs | 5 +-- 32 files changed, 230 insertions(+), 125 deletions(-) create mode 100644 axum/src/extract/state.rs diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index 2316633be5..de2320828c 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -60,20 +60,21 @@ mod tuple; /// [`http::Request`]: http::Request /// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html #[async_trait] -pub trait FromRequest: Sized { +pub trait FromRequest: Sized { /// If the extractor fails it'll use this "rejection" type. A rejection is /// a kind of error that can be converted into a response. type Rejection: IntoResponse; /// Perform the extraction. - async fn from_request(req: &mut RequestParts) -> Result; + async fn from_request(req: &mut RequestParts) -> Result; } /// The type used with [`FromRequest`] to extract data from requests. /// /// Has several convenience methods for getting owned parts of the request. #[derive(Debug)] -pub struct RequestParts { +pub struct RequestParts { + state: S, method: Method, uri: Uri, version: Version, @@ -82,7 +83,7 @@ pub struct RequestParts { body: Option, } -impl RequestParts { +impl RequestParts { /// Create a new `RequestParts`. /// /// You generally shouldn't need to construct this type yourself, unless @@ -90,7 +91,7 @@ impl RequestParts { /// [`tower::Service`]. /// /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html - pub fn new(req: Request) -> Self { + pub fn new(state: S, req: Request) -> Self { let ( http::request::Parts { method, @@ -104,6 +105,7 @@ impl RequestParts { ) = req.into_parts(); RequestParts { + state, method, uri, version, @@ -141,7 +143,10 @@ impl RequestParts { /// } /// } /// ``` - pub async fn extract>(&mut self) -> Result { + pub async fn extract(&mut self) -> Result + where + E: FromRequest, + { E::from_request(self).await } @@ -153,6 +158,7 @@ impl RequestParts { /// [`take_body`]: RequestParts::take_body pub fn try_into_request(self) -> Result, BodyAlreadyExtracted> { let Self { + state: _, method, uri, version, @@ -245,30 +251,36 @@ impl RequestParts { pub fn take_body(&mut self) -> Option { self.body.take() } + + pub fn state(&self) -> &S { + &self.state + } } #[async_trait] -impl FromRequest for Option +impl FromRequest for Option where - T: FromRequest, + T: FromRequest, B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result, Self::Rejection> { + async fn from_request(req: &mut RequestParts) -> Result, Self::Rejection> { Ok(T::from_request(req).await.ok()) } } #[async_trait] -impl FromRequest for Result +impl FromRequest for Result where - T: FromRequest, + T: FromRequest, B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(T::from_request(req).await) } } diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index 33383a7d8d..4faaf2d355 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -6,16 +6,18 @@ use http::{Extensions, HeaderMap, Method, Request, Uri, Version}; use std::convert::Infallible; #[async_trait] -impl FromRequest for Request +impl FromRequest for Request where B: Send, + S: Clone + Send, { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let req = std::mem::replace( req, RequestParts { + state: req.state().clone(), method: req.method.clone(), version: req.version, uri: req.uri.clone(), @@ -30,37 +32,40 @@ where } #[async_trait] -impl FromRequest for Method +impl FromRequest for Method where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.method().clone()) } } #[async_trait] -impl FromRequest for Uri +impl FromRequest for Uri where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.uri().clone()) } } #[async_trait] -impl FromRequest for Version +impl FromRequest for Version where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.version()) } } @@ -71,27 +76,29 @@ where /// /// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html #[async_trait] -impl FromRequest for HeaderMap +impl FromRequest for HeaderMap where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.headers().clone()) } } #[async_trait] -impl FromRequest for Bytes +impl FromRequest for Bytes where B: http_body::Body + Send, B::Data: Send, B::Error: Into, + S: Send, { type Rejection = BytesRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; let bytes = crate::body::to_bytes(body) @@ -103,15 +110,16 @@ where } #[async_trait] -impl FromRequest for String +impl FromRequest for String where B: http_body::Body + Send, B::Data: Send, B::Error: Into, + S: Send, { type Rejection = StringRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; let bytes = crate::body::to_bytes(body) @@ -126,13 +134,14 @@ where } #[async_trait] -impl FromRequest for http::request::Parts +impl FromRequest for http::request::Parts where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let method = unwrap_infallible(Method::from_request(req).await); let uri = unwrap_infallible(Uri::from_request(req).await); let version = unwrap_infallible(Version::from_request(req).await); @@ -159,6 +168,6 @@ fn unwrap_infallible(result: Result) -> T { } } -pub(crate) fn take_body(req: &mut RequestParts) -> Result { +pub(crate) fn take_body(req: &mut RequestParts) -> Result { req.take_body().ok_or(BodyAlreadyExtracted) } diff --git a/axum-core/src/extract/tuple.rs b/axum-core/src/extract/tuple.rs index 8c781a8d3e..05e38bf004 100644 --- a/axum-core/src/extract/tuple.rs +++ b/axum-core/src/extract/tuple.rs @@ -4,13 +4,14 @@ use async_trait::async_trait; use std::convert::Infallible; #[async_trait] -impl FromRequest for () +impl FromRequest for () where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(_: &mut RequestParts) -> Result<(), Self::Rejection> { + async fn from_request(_: &mut RequestParts) -> Result<(), Self::Rejection> { Ok(()) } } @@ -21,14 +22,15 @@ macro_rules! impl_from_request { ( $($ty:ident),* $(,)? ) => { #[async_trait] #[allow(non_snake_case)] - impl FromRequest for ($($ty,)*) + impl FromRequest for ($($ty,)*) where - $( $ty: FromRequest + Send, )* + $( $ty: FromRequest + Send, )* B: Send, + S: Send, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { $( let $ty = $ty::from_request(req).await.map_err(|err| err.into_response())?; )* Ok(($($ty,)*)) } diff --git a/axum-extra/src/extract/cached.rs b/axum-extra/src/extract/cached.rs index 0ada78a888..9ced87e450 100644 --- a/axum-extra/src/extract/cached.rs +++ b/axum-extra/src/extract/cached.rs @@ -88,14 +88,15 @@ pub struct Cached(pub T); struct CachedEntry(T); #[async_trait] -impl FromRequest for Cached +impl FromRequest for Cached where B: Send, - T: FromRequest + Clone + Send + Sync + 'static, + S: Send, + T: FromRequest + Clone + Send + Sync + 'static, { type Rejection = T::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match Extension::>::from_request(req).await { Ok(Extension(CachedEntry(value))) => Ok(Self(value)), Err(_) => { @@ -139,19 +140,20 @@ mod tests { struct Extractor(Instant); #[async_trait] - impl FromRequest for Extractor + impl FromRequest for Extractor where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { COUNTER.fetch_add(1, Ordering::SeqCst); Ok(Self(Instant::now())) } } - let mut req = RequestParts::new(Request::new(())); + let mut req = RequestParts::new((), Request::new(())); let first = Cached::::from_request(&mut req).await.unwrap().0; assert_eq!(COUNTER.load(Ordering::SeqCst), 1); diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index 25d6d1fa3b..2c0b2f3e28 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -88,13 +88,14 @@ pub struct CookieJar { } #[async_trait] -impl FromRequest for CookieJar +impl FromRequest for CookieJar where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let mut jar = cookie_lib::CookieJar::new(); for cookie in cookies_from_request(req) { jar.add_original(cookie); @@ -103,8 +104,8 @@ where } } -fn cookies_from_request( - req: &mut RequestParts, +fn cookies_from_request( + req: &mut RequestParts, ) -> impl Iterator> + '_ { req.headers() .get_all(COOKIE) diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index a583a3c9a9..57adb74e29 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -74,14 +74,15 @@ impl fmt::Debug for PrivateCookieJar { } #[async_trait] -impl FromRequest for PrivateCookieJar +impl FromRequest for PrivateCookieJar where B: Send, + S: Send, K: Into + Clone + Send + Sync + 'static, { - type Rejection = as FromRequest>::Rejection; + type Rejection = as FromRequest>::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let key = Extension::::from_request(req).await?.0.into(); let mut jar = cookie_lib::CookieJar::new(); diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index 51fb865c6e..cab8d6af9f 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -92,14 +92,15 @@ impl fmt::Debug for SignedCookieJar { } #[async_trait] -impl FromRequest for SignedCookieJar +impl FromRequest for SignedCookieJar where B: Send, + S: Send, K: Into + Clone + Send + Sync + 'static, { - type Rejection = as FromRequest>::Rejection; + type Rejection = as FromRequest>::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let key = Extension::::from_request(req).await?.0.into(); let mut jar = cookie_lib::CookieJar::new(); diff --git a/axum-extra/src/extract/form.rs b/axum-extra/src/extract/form.rs index bcd8c1809f..ec8091e3aa 100644 --- a/axum-extra/src/extract/form.rs +++ b/axum-extra/src/extract/form.rs @@ -54,16 +54,17 @@ impl Deref for Form { } #[async_trait] -impl FromRequest for Form +impl FromRequest for Form where T: DeserializeOwned, B: HttpBody + Send, B::Data: Send, B::Error: Into, + S: Send, { type Rejection = FormRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if req.method() == Method::GET { let query = req.uri().query().unwrap_or_default(); let value = serde_html_form::from_str(query) @@ -84,7 +85,7 @@ where } // this is duplicated in `axum/src/extract/mod.rs` -fn has_content_type(req: &RequestParts, expected_content_type: &mime::Mime) -> bool { +fn has_content_type(req: &RequestParts, expected_content_type: &mime::Mime) -> bool { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { content_type } else { diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index dcbcdb31fd..f53a8a1a5f 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -58,14 +58,15 @@ use std::ops::Deref; pub struct Query(pub T); #[async_trait] -impl FromRequest for Query +impl FromRequest for Query where T: DeserializeOwned, B: Send, + S: Send, { type Rejection = QueryRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let query = req.uri().query().unwrap_or_default(); let value = serde_html_form::from_str(query) .map_err(FailedToDeserializeQueryString::__private_new::)?; diff --git a/axum-extra/src/json_lines.rs b/axum-extra/src/json_lines.rs index 1bc8d1a0bd..242b43e70f 100644 --- a/axum-extra/src/json_lines.rs +++ b/axum-extra/src/json_lines.rs @@ -98,16 +98,17 @@ impl JsonLines { } #[async_trait] -impl FromRequest for JsonLines +impl FromRequest for JsonLines where B: HttpBody + Send + 'static, B::Data: Into, B::Error: Into, T: DeserializeOwned, + S: Send, { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { // `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead` // so we can call `AsyncRead::lines` and then convert it back to a `Stream` diff --git a/axum/src/error_handling/mod.rs b/axum/src/error_handling/mod.rs index e0027eed33..3c5779a2d5 100644 --- a/axum/src/error_handling/mod.rs +++ b/axum/src/error_handling/mod.rs @@ -163,7 +163,7 @@ macro_rules! impl_service { F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static, Fut: Future + Send, Res: IntoResponse, - $( $ty: FromRequest + Send,)* + $( $ty: FromRequest<(), ReqBody> + Send,)* ReqBody: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, @@ -185,7 +185,7 @@ macro_rules! impl_service { let inner = std::mem::replace(&mut self.inner, clone); let future = Box::pin(async move { - let mut req = RequestParts::new(req); + let mut req = RequestParts::new((), req); $( let $ty = match $ty::from_request(&mut req).await { diff --git a/axum/src/extension.rs b/axum/src/extension.rs index 390f29e031..d040b9e4e5 100644 --- a/axum/src/extension.rs +++ b/axum/src/extension.rs @@ -73,14 +73,15 @@ use tower_service::Service; pub struct Extension(pub T); #[async_trait] -impl FromRequest for Extension +impl FromRequest for Extension where T: Clone + Send + Sync + 'static, B: Send, + S: Send, { type Rejection = ExtensionRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let value = req .extensions() .get::() diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index 8363a25ebf..3aa7684c81 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -128,14 +128,15 @@ opaque_future! { pub struct ConnectInfo(pub T); #[async_trait] -impl FromRequest for ConnectInfo +impl FromRequest for ConnectInfo where B: Send, + S: Send, T: Clone + Send + Sync + 'static, { - type Rejection = as FromRequest>::Rejection; + type Rejection = as FromRequest>::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let Extension(connect_info) = Extension::::from_request(req).await?; Ok(connect_info) } diff --git a/axum/src/extract/content_length_limit.rs b/axum/src/extract/content_length_limit.rs index 62148a8ea7..f4c475437f 100644 --- a/axum/src/extract/content_length_limit.rs +++ b/axum/src/extract/content_length_limit.rs @@ -36,15 +36,16 @@ use std::ops::Deref; pub struct ContentLengthLimit(pub T); #[async_trait] -impl FromRequest for ContentLengthLimit +impl FromRequest for ContentLengthLimit where - T: FromRequest, + T: FromRequest, T::Rejection: IntoResponse, B: Send, + S: Send, { type Rejection = ContentLengthLimitRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let content_length = req .headers() .get(http::header::CONTENT_LENGTH) diff --git a/axum/src/extract/host.rs b/axum/src/extract/host.rs index 85ab8a9dcd..79ae13fc28 100644 --- a/axum/src/extract/host.rs +++ b/axum/src/extract/host.rs @@ -21,13 +21,14 @@ const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; pub struct Host(pub String); #[async_trait] -impl FromRequest for Host +impl FromRequest for Host where B: Send, + S: Send, { type Rejection = HostRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if let Some(host) = parse_forwarded(req.headers()) { return Ok(Host(host.to_owned())); } diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index 125422f203..bce7bf67b2 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -64,13 +64,14 @@ impl MatchedPath { } #[async_trait] -impl FromRequest for MatchedPath +impl FromRequest for MatchedPath where B: Send, + S: Send, { type Rejection = MatchedPathRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let matched_path = req .extensions() .get::() diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 6f0435f6e9..c4aeb3ad70 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -14,6 +14,7 @@ mod content_length_limit; mod host; mod raw_query; mod request_parts; +mod state; #[doc(inline)] pub use axum_core::extract::{FromRequest, RequestParts}; @@ -27,6 +28,7 @@ pub use self::{ path::Path, raw_query::RawQuery, request_parts::{BodyStream, RawBody}, + state::State, }; #[doc(no_inline)] @@ -73,13 +75,13 @@ pub use self::ws::WebSocketUpgrade; #[doc(no_inline)] pub use crate::TypedHeader; -pub(crate) fn take_body(req: &mut RequestParts) -> Result { +pub(crate) fn take_body(req: &mut RequestParts) -> Result { req.take_body().ok_or_else(BodyAlreadyExtracted::default) } // this is duplicated in `axum-extra/src/extract/form.rs` -pub(super) fn has_content_type( - req: &RequestParts, +pub(super) fn has_content_type( + req: &RequestParts, expected_content_type: &mime::Mime, ) -> bool { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index 8f58455fb9..6b5b46f779 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -50,14 +50,15 @@ pub struct Multipart { } #[async_trait] -impl FromRequest for Multipart +impl FromRequest for Multipart where B: HttpBody + Default + Unpin + Send + 'static, B::Error: Into, + S: Send, { type Rejection = MultipartRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let stream = BodyStream::from_request(req).await?; let headers = req.headers(); let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?; diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index f559ae1fb0..4b04cce65f 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -163,14 +163,15 @@ impl DerefMut for Path { } #[async_trait] -impl FromRequest for Path +impl FromRequest for Path where T: DeserializeOwned + Send, B: Send, + S: Send, { type Rejection = PathRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let params = match req.extensions_mut().get::() { Some(UrlParams::Params(params)) => params, Some(UrlParams::InvalidUtf8InPathParam { key }) => { diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index c267ce05fc..bc462eeb1b 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -49,14 +49,15 @@ use std::ops::Deref; pub struct Query(pub T); #[async_trait] -impl FromRequest for Query +impl FromRequest for Query where T: DeserializeOwned, B: Send, + S: Send, { type Rejection = QueryRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let query = req.uri().query().unwrap_or_default(); let value = serde_urlencoded::from_str(query) .map_err(FailedToDeserializeQueryString::__private_new::)?; @@ -81,7 +82,8 @@ mod tests { use std::fmt::Debug; async fn check(uri: impl AsRef, value: T) { - let mut req = RequestParts::new(Request::builder().uri(uri.as_ref()).body(()).unwrap()); + let req = Request::builder().uri(uri.as_ref()).body(()).unwrap(); + let mut req = RequestParts::new((), req); assert_eq!(Query::::from_request(&mut req).await.unwrap().0, value); } diff --git a/axum/src/extract/raw_query.rs b/axum/src/extract/raw_query.rs index 463c31d88a..faf8df6e4c 100644 --- a/axum/src/extract/raw_query.rs +++ b/axum/src/extract/raw_query.rs @@ -27,13 +27,14 @@ use std::convert::Infallible; pub struct RawQuery(pub Option); #[async_trait] -impl FromRequest for RawQuery +impl FromRequest for RawQuery where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let query = req.uri().query().map(|query| query.to_owned()); Ok(Self(query)) } diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs index 02b044c50c..c04ff5349b 100644 --- a/axum/src/extract/request_parts.rs +++ b/axum/src/extract/request_parts.rs @@ -86,13 +86,14 @@ pub struct OriginalUri(pub Uri); #[cfg(feature = "original-uri")] #[async_trait] -impl FromRequest for OriginalUri +impl FromRequest for OriginalUri where B: Send, + S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let uri = Extension::::from_request(req) .await .unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone()))) @@ -140,15 +141,16 @@ impl Stream for BodyStream { } #[async_trait] -impl FromRequest for BodyStream +impl FromRequest for BodyStream where B: HttpBody + Send + 'static, B::Data: Into, B::Error: Into, + S: Send, { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)? .map_data(Into::into) .map_err(|err| Error::new(err.into())); @@ -196,13 +198,14 @@ fn body_stream_traits() { pub struct RawBody(pub B); #[async_trait] -impl FromRequest for RawBody +impl FromRequest for RawBody where B: Send, + S: Send, { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; Ok(Self(body)) } diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs new file mode 100644 index 0000000000..fb9aa26e76 --- /dev/null +++ b/axum/src/extract/state.rs @@ -0,0 +1,21 @@ +use async_trait::async_trait; +use axum_core::extract::{FromRequest, RequestParts}; +use std::convert::Infallible; + +#[derive(Debug, Clone)] +pub struct State(pub S); + +#[async_trait] +impl FromRequest for State +where + B: Send, + OuterState: Clone + Into + Send, +{ + type Rejection = Infallible; + + async fn from_request(req: &mut RequestParts) -> Result { + let outer_state = req.state().clone(); + let inner_state = outer_state.into(); + Ok(Self(inner_state)) + } +} diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 7ac75ab697..2840ad635c 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -244,13 +244,14 @@ impl WebSocketUpgrade { } #[async_trait] -impl FromRequest for WebSocketUpgrade +impl FromRequest for WebSocketUpgrade where B: Send, + S: Send, { type Rejection = WebSocketUpgradeRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if req.method() != Method::GET { return Err(MethodNotGet.into()); } @@ -288,7 +289,7 @@ where } } -fn header_eq(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { +fn header_eq(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { if let Some(header) = req.headers().get(&key) { header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) } else { @@ -296,7 +297,7 @@ fn header_eq(req: &RequestParts, key: HeaderName, value: &'static str) -> } } -fn header_contains(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { +fn header_contains(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { let header = if let Some(header) = req.headers().get(&key) { header } else { diff --git a/axum/src/form.rs b/axum/src/form.rs index 9974a46036..7542841620 100644 --- a/axum/src/form.rs +++ b/axum/src/form.rs @@ -56,16 +56,17 @@ use std::ops::Deref; pub struct Form(pub T); #[async_trait] -impl FromRequest for Form +impl FromRequest for Form where T: DeserializeOwned, B: HttpBody + Send, B::Data: Send, B::Error: Into, + S: Send, { type Rejection = FormRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if req.method() == Method::GET { let query = req.uri().query().unwrap_or_default(); let value = serde_urlencoded::from_str(query) @@ -126,6 +127,7 @@ mod tests { async fn check_query(uri: impl AsRef, value: T) { let mut req = RequestParts::new( + (), Request::builder() .uri(uri.as_ref()) .body(Empty::::new()) @@ -136,6 +138,7 @@ mod tests { async fn check_body(value: T) { let mut req = RequestParts::new( + (), Request::builder() .uri("http://example.com/test") .method(Method::POST) @@ -205,6 +208,7 @@ mod tests { #[tokio::test] async fn test_incorrect_content_type() { let mut req = RequestParts::new( + (), Request::builder() .uri("http://example.com/test") .method(Method::POST) diff --git a/axum/src/handler/into_service.rs b/axum/src/handler/into_service.rs index 73775133cc..590af8dd4e 100644 --- a/axum/src/handler/into_service.rs +++ b/axum/src/handler/into_service.rs @@ -61,6 +61,7 @@ impl Service> for IntoService where H: Handler + Clone + Send + 'static, B: Send + 'static, + S: Clone, { type Response = Response; type Error = Infallible; @@ -78,7 +79,7 @@ where use futures_util::future::FutureExt; let handler = self.handler.clone(); - let future = Handler::call(handler, req); + let future = Handler::call(handler, self.state.clone(), req); let future = future.map(Ok as _); super::future::IntoServiceFuture::new(future) diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 8c819744a8..dcb04973db 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -68,7 +68,7 @@ pub trait Handler: Clone + Send + Sized + 'static { type Future: Future + Send + 'static; /// Call the handler with the given request. - fn call(self, req: Request) -> Self::Future; + fn call(self, state: S, req: Request) -> Self::Future; /// Apply a [`tower::Layer`] to the handler. /// @@ -222,7 +222,7 @@ where { type Future = Pin + Send>>; - fn call(self, _req: Request) -> Self::Future { + fn call(self, _state: S, _req: Request) -> Self::Future { Box::pin(async move { self().await.into_response() }) } } @@ -235,14 +235,15 @@ macro_rules! impl_handler { F: FnOnce($($ty,)*) -> Fut + Clone + Send + 'static, Fut: Future + Send, B: Send + 'static, + S: Send + 'static, Res: IntoResponse, - $( $ty: FromRequest + Send,)* + $( $ty: FromRequest + Send,)* { type Future = Pin + Send>>; - fn call(self, req: Request) -> Self::Future { + fn call(self, state: S, req: Request) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::new(req); + let mut req = RequestParts::new(state, req); $( let $ty = match $ty::from_request(&mut req).await { @@ -301,16 +302,18 @@ where { type Future = future::LayeredFuture; - fn call(self, req: Request) -> Self::Future { - use futures_util::future::{FutureExt, Map}; + fn call(self, state: S, req: Request) -> Self::Future { + todo!() + + // use futures_util::future::{FutureExt, Map}; - let future: Map<_, fn(Result) -> _> = - self.svc.oneshot(req).map(|result| match result { - Ok(res) => res.map(boxed), - Err(res) => res.into_response(), - }); + // let future: Map<_, fn(Result) -> _> = + // self.svc.oneshot(req).map(|result| match result { + // Ok(res) => res.map(boxed), + // Err(res) => res.into_response(), + // }); - future::LayeredFuture::new(future) + // future::LayeredFuture::new(future) } } diff --git a/axum/src/json.rs b/axum/src/json.rs index e1d068fa1b..a7bfedcb9d 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -93,16 +93,17 @@ use std::ops::{Deref, DerefMut}; pub struct Json(pub T); #[async_trait] -impl FromRequest for Json +impl FromRequest for Json where T: DeserializeOwned, B: HttpBody + Send, B::Data: Send, B::Error: Into, + S: Send, { type Rejection = JsonRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if json_content_type(req) { let bytes = Bytes::from_request(req).await?; @@ -135,7 +136,7 @@ where } } -fn json_content_type(req: &RequestParts) -> bool { +fn json_content_type(req: &RequestParts) -> bool { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { content_type } else { diff --git a/axum/src/middleware/from_extractor.rs b/axum/src/middleware/from_extractor.rs index 45bb951380..2b2884c7a0 100644 --- a/axum/src/middleware/from_extractor.rs +++ b/axum/src/middleware/from_extractor.rs @@ -168,7 +168,7 @@ where impl Service> for FromExtractor where - E: FromRequest + 'static, + E: FromRequest<(), ReqBody> + 'static, ReqBody: Default + Send + 'static, S: Service, Response = Response> + Clone, ResBody: HttpBody + Send + 'static, @@ -185,7 +185,7 @@ where fn call(&mut self, req: Request) -> Self::Future { let extract_future = Box::pin(async move { - let mut req = RequestParts::new(req); + let mut req = RequestParts::new((), req); let extracted = E::from_request(&mut req).await; (req, extracted) }); @@ -204,7 +204,7 @@ pin_project! { #[allow(missing_debug_implementations)] pub struct ResponseFuture where - E: FromRequest, + E: FromRequest<(), ReqBody>, S: Service>, { #[pin] @@ -217,17 +217,19 @@ pin_project! { #[project = StateProj] enum State where - E: FromRequest, + E: FromRequest<(), ReqBody>, S: Service>, { - Extracting { future: BoxFuture<'static, (RequestParts, Result)> }, + Extracting { + future: BoxFuture<'static, (RequestParts<(), ReqBody>, Result)>, + }, Call { #[pin] future: S::Future }, } } impl Future for ResponseFuture where - E: FromRequest, + E: FromRequest<(), ReqBody>, S: Service, Response = Response>, ReqBody: Default, ResBody: HttpBody + Send + 'static, @@ -279,13 +281,14 @@ mod tests { struct RequireAuth; #[async_trait::async_trait] - impl FromRequest for RequireAuth + impl FromRequest for RequireAuth where B: Send, + S: Send, { type Rejection = StatusCode; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if let Some(auth) = req .headers() .get(header::AUTHORIZATION) diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index 8e46fa690f..c0cda708aa 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -259,7 +259,7 @@ macro_rules! impl_service { impl Service> for FromFn where F: FnMut($($ty),*, Next) -> Fut + Clone + Send + 'static, - $( $ty: FromRequest + Send, )* + $( $ty: FromRequest<(), ReqBody> + Send, )* Fut: Future + Send + 'static, Out: IntoResponse + 'static, S: Service, Response = Response, Error = Infallible> @@ -286,7 +286,7 @@ macro_rules! impl_service { let mut f = self.f.clone(); let future = Box::pin(async move { - let mut parts = RequestParts::new(req); + let mut parts = RequestParts::new((), req); $( let $ty = match $ty::from_request(&mut parts).await { Ok(value) => value, diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 204bd4bd61..613fc41dcb 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -1,7 +1,7 @@ use crate::{ body::{Bytes, Empty}, error_handling::HandleErrorLayer, - extract::{self, Path}, + extract::{self, Path, State}, handler::Handler, response::IntoResponse, routing::{delete, get, get_service, on, on_service, patch, patch_service, post, MethodFilter}, @@ -723,3 +723,27 @@ async fn limited_body_with_streaming_body() { .await; assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } + +#[tokio::test] +async fn extracting_state() { + #[derive(Clone)] + struct AppState { + inner: InnerState, + } + + #[derive(Clone)] + struct InnerState {} + + impl From for InnerState { + fn from(state: AppState) -> Self { + state.inner + } + } + + async fn handler(State(_): State, State(_): State) {} + + let state = AppState { + inner: InnerState {}, + }; + let _: Router = Router::with_state(state).route("/", get(handler)); +} diff --git a/axum/src/typed_header.rs b/axum/src/typed_header.rs index 1ce72d04f8..c28a24a81a 100644 --- a/axum/src/typed_header.rs +++ b/axum/src/typed_header.rs @@ -52,14 +52,15 @@ use std::{convert::Infallible, ops::Deref}; pub struct TypedHeader(pub T); #[async_trait] -impl FromRequest for TypedHeader +impl FromRequest for TypedHeader where T: headers::Header, B: Send, + S: Send, { type Rejection = TypedHeaderRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match req.headers().typed_try_get::() { Ok(Some(value)) => Ok(Self(value)), Ok(None) => Err(TypedHeaderRejection { From 1f95766c3548651956983f7c04196686e5905960 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sun, 10 Jul 2022 14:25:55 +0200 Subject: [PATCH 03/45] make state extractor work --- Cargo.toml | 4 +- axum-core/src/extract/mod.rs | 3 +- axum-extra/src/routing/mod.rs | 39 ++--- .../into_service_state_in_extension.rs | 19 +-- axum/src/handler/mod.rs | 84 +++++---- axum/src/lib.rs | 2 +- axum/src/routing/method_routing.rs | 159 +++++++++++------- axum/src/routing/mod.rs | 136 ++++++++------- axum/src/routing/tests/fallback.rs | 13 +- axum/src/routing/tests/merge.rs | 44 +++++ axum/src/routing/tests/mod.rs | 32 +++- 11 files changed, 335 insertions(+), 200 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a84f6fddbb..a221fbc5a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,8 +2,8 @@ members = [ "axum", "axum-core", - # "axum-extra", - # "axum-macros", + "axum-extra", + "axum-macros", # internal crate used to bump the minimum versions we # get for some dependencies which otherwise wouldn't build diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index de2320828c..11236279ee 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -84,7 +84,7 @@ pub struct RequestParts { } impl RequestParts { - /// Create a new `RequestParts`. + /// Create a new `RequestParts` with the given state. /// /// You generally shouldn't need to construct this type yourself, unless /// using extractors outside of axum for example to implement a @@ -252,6 +252,7 @@ impl RequestParts { self.body.take() } + /// Get a reference to the state. pub fn state(&self) -> &S { &self.state } diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 9c4b512a43..67da57f106 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -2,13 +2,11 @@ use axum::{ handler::Handler, - http::Request, - response::{Redirect, Response}, - routing::MethodRouter, + response::Redirect, + routing::{any, MethodRouter}, Router, }; -use std::{convert::Infallible, future::ready}; -use tower_service::Service; +use std::future::ready; mod resource; @@ -250,27 +248,22 @@ where self.route(P::PATH, axum::routing::trace(handler)) } - fn route_with_tsr(self, path: &str, method_router: MethodRouter) -> Self + fn route_with_tsr(mut self, path: &str, method_router: MethodRouter) -> Self where Self: Sized, { - todo!() - - // self = self.route(path, service); - - // let redirect = Redirect::permanent(path); - - // if let Some(path_without_trailing_slash) = path.strip_suffix('/') { - // self.route( - // path_without_trailing_slash, - // (move || ready(redirect.clone())).into_service(), - // ) - // } else { - // self.route( - // &format!("{}/", path), - // (move || ready(redirect.clone())).into_service(), - // ) - // } + self = self.route(path, method_router); + + let redirect = Redirect::permanent(path); + + if let Some(path_without_trailing_slash) = path.strip_suffix('/') { + self.route( + path_without_trailing_slash, + any(move || ready(redirect.clone())), + ) + } else { + self.route(&format!("{}/", path), any(move || ready(redirect.clone()))) + } } } diff --git a/axum/src/handler/into_service_state_in_extension.rs b/axum/src/handler/into_service_state_in_extension.rs index feb9c193a9..011161d93a 100644 --- a/axum/src/handler/into_service_state_in_extension.rs +++ b/axum/src/handler/into_service_state_in_extension.rs @@ -68,21 +68,18 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { use futures_util::future::FutureExt; let state = req - .extensions() - .get::() - .expect("state extension missing. This is a bug in axum, please file an issue") - .clone(); + .extensions_mut() + .remove::() + .expect("state extension missing. This is a bug in axum, please file an issue"); - todo!() + let handler = self.handler.clone(); + let future = Handler::call(handler, state, req); + let future = future.map(Ok as _); - // let handler = self.handler.clone(); - // let future = Handler::call(handler, req); - // let future = future.map(Ok as _); - - // super::future::IntoServiceFuture::new(future) + super::future::IntoServiceFuture::new(future) } } diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index dcb04973db..b9dd781433 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -106,12 +106,15 @@ pub trait Handler: Clone + Send + Sized + 'static { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - fn layer(self, layer: L) -> Layered + fn layer(self, layer: L) -> Layered where L: Layer>, { - todo!() - // Layered::new(layer.layer(self.into_service())) + Layered { + layer, + handler: self, + _marker: PhantomData, + } } /// Convert the handler into a [`Service`]. @@ -266,63 +269,72 @@ all_the_tuples!(impl_handler); /// A [`Service`] created from a [`Handler`] by applying a Tower middleware. /// /// Created with [`Handler::layer`]. See that method for more details. -pub struct Layered { - svc: Svc, - _input: PhantomData (T, S)>, +pub struct Layered { + layer: L, + handler: H, + _marker: PhantomData (T, S, B)>, } -impl fmt::Debug for Layered +impl fmt::Debug for Layered where - Svc: fmt::Debug, + L: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Layered").field("svc", &self.svc).finish() + f.debug_struct("Layered") + .field("layer", &self.layer) + .finish() } } -impl Clone for Layered +impl Clone for Layered where - Svc: Clone, + L: Clone, + H: Clone, { fn clone(&self) -> Self { - Self::new(self.svc.clone()) + Self { + layer: self.layer.clone(), + handler: self.handler.clone(), + _marker: PhantomData, + } } } -impl Handler for Layered +impl Handler for Layered where - Svc: Service, Response = Response> + Clone + Send + 'static, - Svc::Error: IntoResponse, - Svc::Future: Send, + L: Layer> + Clone + Send + 'static, + H: Handler, + L::Service: Service, Response = Response> + Clone + Send + 'static, + >>::Error: IntoResponse, + >>::Future: Send, T: 'static, S: 'static, - ReqBody: Send + 'static, + B: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { - type Future = future::LayeredFuture; + type Future = future::LayeredFuture; - fn call(self, state: S, req: Request) -> Self::Future { - todo!() + fn call(self, state: S, req: Request) -> Self::Future { + use futures_util::future::{FutureExt, Map}; - // use futures_util::future::{FutureExt, Map}; + let svc = self.handler.into_service(state); + let svc = self.layer.layer(svc); - // let future: Map<_, fn(Result) -> _> = - // self.svc.oneshot(req).map(|result| match result { - // Ok(res) => res.map(boxed), - // Err(res) => res.into_response(), - // }); + let future: Map< + _, + fn( + Result< + >>::Response, + >>::Error, + >, + ) -> _, + > = svc.oneshot(req).map(|result| match result { + Ok(res) => res.map(boxed), + Err(res) => res.into_response(), + }); - // future::LayeredFuture::new(future) - } -} - -impl Layered { - pub(crate) fn new(svc: Svc) -> Self { - Self { - svc, - _input: PhantomData, - } + future::LayeredFuture::new(future) } } diff --git a/axum/src/lib.rs b/axum/src/lib.rs index be095caaaa..fabddf499a 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -384,7 +384,7 @@ future_incompatible, nonstandard_style, missing_debug_implementations, - missing_docs + // missing_docs )] #![deny(unreachable_pub, private_in_public)] #![allow(elided_lifetimes_in_paths, clippy::type_complexity)] diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 1353df1cd9..5a7236accc 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -391,7 +391,9 @@ where ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { - MethodRouter::new().fallback(svc).skip_allow_header() + MethodRouter::new() + .fallback_service(svc) + .skip_allow_header() } top_level_handler_fn!(delete, DELETE); @@ -684,6 +686,15 @@ where pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { IntoMakeServiceWithConnectInfo::new(self) } + + pub fn fallback(self, handler: H) -> Self + where + H: Handler, + T: 'static, + S: Clone + Send + Sync + 'static, + { + self.fallback_service(IntoServiceStateInExtension::new(handler)) + } } impl MethodRouter { @@ -735,7 +746,7 @@ impl MethodRouter { chained_service_fn!(trace_service, TRACE); #[doc = include_str!("../docs/method_routing/fallback.md")] - pub fn fallback(mut self, svc: T) -> Self + pub fn fallback_service(mut self, svc: T) -> Self where T: Service, Response = Response, Error = E> + Clone @@ -1069,6 +1080,12 @@ pub struct MethodRouterWithState { state: S, } +impl MethodRouterWithState { + pub fn state(&self) -> &S { + &self.state + } +} + impl Clone for MethodRouterWithState where S: Clone, @@ -1096,6 +1113,7 @@ where impl Service> for MethodRouterWithState where B: HttpBody, + S: Clone + Send + Sync + 'static, { type Response = Response; type Error = E; @@ -1106,7 +1124,7 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { macro_rules! call { ( $req:expr, @@ -1125,23 +1143,26 @@ where let method = req.method().clone(); - // set state in request extensions - todo!(); - // written with a pattern match like this to ensure we call all routes - let MethodRouter { - get, - head, - delete, - options, - patch, - post, - put, - trace, - fallback, - allow_header, - _request_body: _, - } = self.method_router; + let Self { + state, + method_router: + MethodRouter { + get, + head, + delete, + options, + patch, + post, + put, + trace, + fallback, + allow_header, + _request_body: _, + }, + } = self; + + req.extensions_mut().insert(state.clone()); call!(req, method, HEAD, head); call!(req, method, HEAD, get); @@ -1171,7 +1192,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{body::Body, error_handling::HandleErrorLayer}; + use crate::{body::Body, error_handling::HandleErrorLayer, extract::State}; use axum_core::response::IntoResponse; use http::{header::ALLOW, HeaderMap}; use std::time::Duration; @@ -1265,8 +1286,7 @@ mod tests { delete_service(ServeDir::new(".")) .handle_error(|_| async { StatusCode::NOT_FOUND }), ) - // TODO(david): add `fallback` and `fallback_service` - // .fallback((|| async { StatusCode::NOT_FOUND }).into_service()) + .fallback(|| async { StatusCode::NOT_FOUND }) .put(ok) .layer( ServiceBuilder::new() @@ -1324,47 +1344,48 @@ mod tests { assert!(!headers.contains_key(ALLOW)); } - // TODO(david): add `fallback` and `fallback_service` - // #[tokio::test] - // async fn allow_header_with_fallback() { - // let mut svc = MethodRouter::new().get(ok).fallback( - // (|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") }).into_service(), - // ); - - // let (status, headers, _) = call(Method::DELETE, &mut svc).await; - // assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); - // assert_eq!(headers[ALLOW], "GET,HEAD"); - // } - - // #[tokio::test] - // async fn allow_header_with_fallback_that_sets_allow() { - // async fn fallback(method: Method) -> Response { - // if method == Method::POST { - // "OK".into_response() - // } else { - // ( - // StatusCode::METHOD_NOT_ALLOWED, - // [(ALLOW, "GET,POST")], - // "Method not allowed", - // ) - // .into_response() - // } - // } - - // let mut svc = MethodRouter::new() - // .get(ok) - // .fallback(fallback.into_service()); - - // let (status, _, _) = call(Method::GET, &mut svc).await; - // assert_eq!(status, StatusCode::OK); - - // let (status, _, _) = call(Method::POST, &mut svc).await; - // assert_eq!(status, StatusCode::OK); - - // let (status, headers, _) = call(Method::DELETE, &mut svc).await; - // assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); - // assert_eq!(headers[ALLOW], "GET,POST"); - // } + #[tokio::test] + async fn allow_header_with_fallback() { + let mut svc = MethodRouter::new() + .get(ok) + .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") }) + .with_state(()); + + let (status, headers, _) = call(Method::DELETE, &mut svc).await; + assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); + assert_eq!(headers[ALLOW], "GET,HEAD"); + } + + #[tokio::test] + async fn allow_header_with_fallback_that_sets_allow() { + async fn fallback(method: Method) -> Response { + if method == Method::POST { + "OK".into_response() + } else { + ( + StatusCode::METHOD_NOT_ALLOWED, + [(ALLOW, "GET,POST")], + "Method not allowed", + ) + .into_response() + } + } + + let mut svc = MethodRouter::new() + .get(ok) + .fallback(fallback) + .with_state(()); + + let (status, _, _) = call(Method::GET, &mut svc).await; + assert_eq!(status, StatusCode::OK); + + let (status, _, _) = call(Method::POST, &mut svc).await; + assert_eq!(status, StatusCode::OK); + + let (status, headers, _) = call(Method::DELETE, &mut svc).await; + assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); + assert_eq!(headers[ALLOW], "GET,POST"); + } #[tokio::test] #[should_panic( @@ -1393,6 +1414,18 @@ mod tests { let _: MethodRouter<()> = head(ok).get(ok); } + #[tokio::test] + async fn fallback_accessing_state() { + let mut svc = MethodRouter::new() + .fallback(|State(state): State<&'static str>| async move { state }) + .with_state("state"); + + let (status, _, text) = call(Method::GET, &mut svc).await; + + assert_eq!(status, StatusCode::OK); + assert_eq!(text, "state"); + } + async fn call(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String) where S: Service, Response = Response, Error = Infallible>, diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 9f61fa0c22..68fcd043a0 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -3,7 +3,7 @@ use self::{future::RouteFuture, not_found::NotFound}; use crate::{ body::{boxed, Body, Bytes, HttpBody}, - extract::connect_info::IntoMakeServiceWithConnectInfo, + extract::{connect_info::IntoMakeServiceWithConnectInfo, Extension}, handler::Handler, response::Response, routing::strip_prefix::StripPrefix, @@ -143,8 +143,39 @@ where } pub fn route(mut self, path: &str, method_router: MethodRouter) -> Self { - // self.route_service(path, method_router.with_state(self.state.clone())) - todo!() + if path.is_empty() { + panic!("Paths must start with a `/`. Use \"/\" for root routes"); + } else if !path.starts_with('/') { + panic!("Paths must start with a `/`"); + } + + let id = RouteId::next(); + + let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self + .node + .path_to_route_id + .get(path) + .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc))) + { + // if we're adding a new `MethodRouter` to a route that already has one just + // merge them. This makes `.route("/", get(_)).route("/", post(_))` work + let service = Endpoint::MethodRouter(prev_method_router.clone().merge(method_router)); + self.routes.insert(route_id, service); + return self; + } else { + Endpoint::MethodRouter(method_router) + }; + + let mut node = + Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone()); + if let Err(err) = node.insert(path, id) { + self.panic_on_matchit_error(err); + } + self.node = Arc::new(node); + + self.routes.insert(id, endpoint); + + self } #[doc = include_str!("../docs/routing/route.md")] @@ -153,54 +184,33 @@ where T: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, T::Future: Send + 'static, { - todo!() - - // if path.is_empty() { - // panic!("Paths must start with a `/`. Use \"/\" for root routes"); - // } else if !path.starts_with('/') { - // panic!("Paths must start with a `/`"); - // } - - // let service = match try_downcast::, _>(service) { - // Ok(_) => { - // panic!("Invalid route: `Router::route` cannot be used with `Router`s. Use `Router::nest` instead") - // } - // Err(svc) => svc, - // }; - - // let id = RouteId::next(); - - // let service = match try_downcast::, _>(service) { - // Ok(method_router) => { - // if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self - // .node - // .path_to_route_id - // .get(path) - // .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc))) - // { - // // if we're adding a new `MethodRouter` to a route that already has one just - // // merge them. This makes `.route("/", get(_)).route("/", post(_))` work - // let service = - // Endpoint::MethodRouter(prev_method_router.clone().merge(method_router)); - // self.routes.insert(route_id, service); - // return self; - // } else { - // Endpoint::MethodRouter(method_router) - // } - // } - // Err(service) => Endpoint::Route(Route::new(service)), - // }; - - // let mut node = - // Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone()); - // if let Err(err) = node.insert(path, id) { - // self.panic_on_matchit_error(err); - // } - // self.node = Arc::new(node); - - // self.routes.insert(id, service); - - // self + if path.is_empty() { + panic!("Paths must start with a `/`. Use \"/\" for root routes"); + } else if !path.starts_with('/') { + panic!("Paths must start with a `/`"); + } + + let service = match try_downcast::, _>(service) { + Ok(_) => { + panic!("Invalid route: `Router::route` cannot be used with `Router`s. Use `Router::nest` instead") + } + Err(svc) => svc, + }; + + let id = RouteId::next(); + + let endpoint = Endpoint::Route(Route::new(service)); + + let mut node = + Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone()); + if let Err(err) = node.insert(path, id) { + self.panic_on_matchit_error(err); + } + self.node = Arc::new(node); + + self.routes.insert(id, endpoint); + + self } #[doc = include_str!("../docs/routing/nest.md")] @@ -230,7 +240,9 @@ where // front Ok(router) => { let Router { - state, + // nesting has changed in https://github.com/tokio-rs/axum/pull/1086 + // so once that is merged we can make sure states work currectly with nesting + state: _, mut routes, node, fallback, @@ -286,12 +298,11 @@ where #[doc = include_str!("../docs/routing/merge.md")] pub fn merge(mut self, other: R) -> Self where - // TODO(david): can we use a different state type here? Since the state cannot be changed - // and has already been provided? R: Into>, + S2: Clone + Send + Sync + 'static, { let Router { - state: _, + state, routes, node, fallback, @@ -304,9 +315,14 @@ where .get(&id) .expect("no path for route id. This is a bug in axum. Please file an issue"); self = match route { - Endpoint::MethodRouter(method_router) => { - self.route(path, method_router.downcast_state()) - } + Endpoint::MethodRouter(method_router) => self.route( + path, + method_router + // this will set the state for each route + // such we don't override the inner state later in `MethodRouterWithState` + .layer(Extension(state.clone())) + .downcast_state(), + ), Endpoint::Route(route) => self.route_service(path, route), }; } @@ -514,6 +530,10 @@ where panic!("Invalid route: {}", err); } } + + pub fn state(&self) -> &S { + &self.state + } } impl Service> for Router diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index 3433264a1c..4da166baea 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -1,5 +1,4 @@ use super::*; -use crate::handler::Handler; #[tokio::test] async fn basic() { @@ -47,3 +46,15 @@ async fn or() { assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "fallback"); } + +#[tokio::test] +async fn fallback_accessing_state() { + let app = Router::with_state("state") + .fallback(|State(state): State<&'static str>| async move { state }); + + let client = TestClient::new(app); + + let res = client.get("/does-not-exist").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "state"); +} diff --git a/axum/src/routing/tests/merge.rs b/axum/src/routing/tests/merge.rs index 58af01511b..eab034d8d8 100644 --- a/axum/src/routing/tests/merge.rs +++ b/axum/src/routing/tests/merge.rs @@ -408,3 +408,47 @@ async fn middleware_that_return_early() { ); assert_eq!(client.get("/public").send().await.status(), StatusCode::OK); } + +#[tokio::test] +async fn merge_with_different_state_type() { + let inner = Router::with_state("inner".to_owned()).route( + "/foo", + get(|State(state): State| async move { state }), + ); + + let app = Router::with_state("outer").merge(inner).route( + "/bar", + get(|State(state): State<&'static str>| async move { state }), + ); + + let client = TestClient::new(app); + + let res = client.get("/foo").send().await; + assert_eq!(res.text().await, "inner"); + + let res = client.get("/bar").send().await; + assert_eq!(res.text().await, "outer"); +} + +#[tokio::test] +async fn merging_routes_different_method_different_states() { + let get = Router::with_state("get state").route( + "/", + get(|State(state): State<&'static str>| async move { state }), + ); + + let post = Router::with_state("post state").route( + "/", + post(|State(state): State<&'static str>| async move { state }), + ); + + let app = Router::new().merge(get).merge(post); + + let client = TestClient::new(app); + + let res = client.get("/").send().await; + assert_eq!(res.text().await, "get state"); + + let res = client.post("/").send().await; + assert_eq!(res.text().await, "post state"); +} diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 613fc41dcb..94904422bb 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -728,11 +728,14 @@ async fn limited_body_with_streaming_body() { async fn extracting_state() { #[derive(Clone)] struct AppState { + value: i32, inner: InnerState, } #[derive(Clone)] - struct InnerState {} + struct InnerState { + value: i32, + } impl From for InnerState { fn from(state: AppState) -> Self { @@ -740,10 +743,31 @@ async fn extracting_state() { } } - async fn handler(State(_): State, State(_): State) {} + async fn handler(State(outer): State, State(inner): State) { + assert_eq!(outer.value, 1); + assert_eq!(inner.value, 2); + } let state = AppState { - inner: InnerState {}, + value: 1, + inner: InnerState { value: 2 }, }; - let _: Router = Router::with_state(state).route("/", get(handler)); + + let app = Router::with_state(state).route("/", get(handler)); + let client = TestClient::new(app); + + let res = client.get("/").send().await; + assert_eq!(res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn explicitly_setting_state() { + let app = Router::with_state("...").route_service( + "/", + get(|State(state): State<&'static str>| async move { state }).with_state("foo"), + ); + + let client = TestClient::new(app); + let res = client.get("/").send().await; + assert_eq!(res.text().await, "foo"); } From 18bb70ba37738fb9bbead9b861d15b1d4297ed92 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 11 Jul 2022 20:49:19 +0200 Subject: [PATCH 04/45] make sure nesting with different states work --- axum/src/routing/tests/nest.rs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index 76e2b9a70e..3083141995 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -460,3 +460,33 @@ nested_route_test!(nest_9, nest = "/a", route = "/a/", expected = "/a/a/"); nested_route_test!(nest_11, nest = "/a/", route = "/", expected = "/a/"); nested_route_test!(nest_12, nest = "/a/", route = "/a", expected = "/a/a"); nested_route_test!(nest_13, nest = "/a/", route = "/a/", expected = "/a/a/"); + +#[tokio::test] +async fn nesting_with_different_state() { + let inner = Router::with_state("inner".to_owned()).route( + "/foo", + get(|State(state): State| async move { state }), + ); + + let outer = Router::with_state("outer") + .route( + "/foo", + get(|State(state): State<&'static str>| async move { state }), + ) + .nest("/nested", inner) + .route( + "/bar", + get(|State(state): State<&'static str>| async move { state }), + ); + + let client = TestClient::new(outer); + + let res = client.get("/foo").send().await; + assert_eq!(res.text().await, "outer"); + + let res = client.get("/nested/foo").send().await; + assert_eq!(res.text().await, "inner"); + + let res = client.get("/bar").send().await; + assert_eq!(res.text().await, "outer"); +} From 61bb2146d8d37cce26ab90fd366ec777ad18f16b Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 11 Jul 2022 20:54:56 +0200 Subject: [PATCH 05/45] impl Service for MethodRouter<()> --- axum-extra/src/routing/spa.rs | 7 +-- axum/src/routing/method_routing.rs | 77 ++++++++++++++++++++++-------- axum/src/routing/tests/mod.rs | 4 +- axum/src/routing/tests/nest.rs | 19 +++----- 4 files changed, 68 insertions(+), 39 deletions(-) diff --git a/axum-extra/src/routing/spa.rs b/axum-extra/src/routing/spa.rs index fac7e21297..69f1185cea 100644 --- a/axum-extra/src/routing/spa.rs +++ b/axum-extra/src/routing/spa.rs @@ -158,15 +158,12 @@ where { fn from(spa: SpaRouter) -> Self { let assets_service = get_service(ServeDir::new(&spa.paths.assets_dir)) - .handle_error(spa.handle_error.clone()) - .with_state(()); + .handle_error(spa.handle_error.clone()); Router::new() .nest(&spa.paths.assets_path, assets_service) .fallback_service( - get_service(ServeFile::new(&spa.paths.index_file)) - .handle_error(spa.handle_error) - .with_state(()), + get_service(ServeFile::new(&spa.paths.index_file)).handle_error(spa.handle_error), ) } } diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 5a7236accc..0ba802775e 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1048,6 +1048,24 @@ fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) { } } +impl Service> for MethodRouter<(), B, E> +where + B: HttpBody, +{ + type Response = Response; + type Error = E; + type Future = RouteFuture; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + self.clone().with_state(()).call(req) + } +} + impl Clone for MethodRouter { fn clone(&self) -> Self { Self { @@ -1201,15 +1219,28 @@ mod tests { #[tokio::test] async fn method_not_allowed_by_default() { - let mut svc = MethodRouter::new().with_state(()); + let mut svc = MethodRouter::new(); let (status, _, body) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert!(body.is_empty()); } + #[tokio::test] + async fn get_service_fn() { + async fn handle(_req: Request) -> Result, Infallible> { + Ok(Response::new(Body::from("ok"))) + } + + let mut svc = get_service(service_fn(handle)); + + let (status, _, body) = call(Method::GET, &mut svc).await; + assert_eq!(status, StatusCode::OK); + assert_eq!(body, "ok"); + } + #[tokio::test] async fn get_handler() { - let mut svc = MethodRouter::new().get(ok).with_state(()); + let mut svc = MethodRouter::new().get(ok); let (status, _, body) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); assert_eq!(body, "ok"); @@ -1217,7 +1248,7 @@ mod tests { #[tokio::test] async fn get_accepts_head() { - let mut svc = MethodRouter::new().get(ok).with_state(()); + let mut svc = MethodRouter::new().get(ok); let (status, _, body) = call(Method::HEAD, &mut svc).await; assert_eq!(status, StatusCode::OK); assert!(body.is_empty()); @@ -1225,7 +1256,7 @@ mod tests { #[tokio::test] async fn head_takes_precedence_over_get() { - let mut svc = MethodRouter::new().head(created).get(ok).with_state(()); + let mut svc = MethodRouter::new().head(created).get(ok); let (status, _, body) = call(Method::HEAD, &mut svc).await; assert_eq!(status, StatusCode::CREATED); assert!(body.is_empty()); @@ -1233,7 +1264,7 @@ mod tests { #[tokio::test] async fn merge() { - let mut svc = get(ok).merge(post(ok)).with_state(()); + let mut svc = get(ok).merge(post(ok)); let (status, _, _) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); @@ -1246,8 +1277,7 @@ mod tests { async fn layer() { let mut svc = MethodRouter::new() .get(|| async { std::future::pending::<()>().await }) - .layer(RequireAuthorizationLayer::bearer("password")) - .with_state(()); + .layer(RequireAuthorizationLayer::bearer("password")); // method with route let (status, _, _) = call(Method::GET, &mut svc).await; @@ -1262,8 +1292,7 @@ mod tests { async fn route_layer() { let mut svc = MethodRouter::new() .get(|| async { std::future::pending::<()>().await }) - .route_layer(RequireAuthorizationLayer::bearer("password")) - .with_state(()); + .route_layer(RequireAuthorizationLayer::bearer("password")); // method with route let (status, _, _) = call(Method::GET, &mut svc).await; @@ -1302,7 +1331,7 @@ mod tests { #[tokio::test] async fn sets_allow_header() { - let mut svc = MethodRouter::new().put(ok).patch(ok).with_state(()); + let mut svc = MethodRouter::new().put(ok).patch(ok); let (status, headers, _) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert_eq!(headers[ALLOW], "PUT,PATCH"); @@ -1310,7 +1339,7 @@ mod tests { #[tokio::test] async fn sets_allow_header_get_head() { - let mut svc = MethodRouter::new().get(ok).head(ok).with_state(()); + let mut svc = MethodRouter::new().get(ok).head(ok); let (status, headers, _) = call(Method::PUT, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert_eq!(headers[ALLOW], "GET,HEAD"); @@ -1318,7 +1347,7 @@ mod tests { #[tokio::test] async fn empty_allow_header_by_default() { - let mut svc = MethodRouter::new().with_state(()); + let mut svc = MethodRouter::new(); let (status, headers, _) = call(Method::PATCH, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); assert_eq!(headers[ALLOW], ""); @@ -1328,7 +1357,7 @@ mod tests { async fn allow_header_when_merging() { let a = put(ok).patch(ok); let b = get(ok).head(ok); - let mut svc = a.merge(b).with_state(()); + let mut svc = a.merge(b); let (status, headers, _) = call(Method::DELETE, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); @@ -1337,7 +1366,7 @@ mod tests { #[tokio::test] async fn allow_header_any() { - let mut svc = any(ok).with_state(()); + let mut svc = any(ok); let (status, headers, _) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); @@ -1348,8 +1377,7 @@ mod tests { async fn allow_header_with_fallback() { let mut svc = MethodRouter::new() .get(ok) - .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") }) - .with_state(()); + .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") }); let (status, headers, _) = call(Method::DELETE, &mut svc).await; assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED); @@ -1371,10 +1399,7 @@ mod tests { } } - let mut svc = MethodRouter::new() - .get(ok) - .fallback(fallback) - .with_state(()); + let mut svc = MethodRouter::new().get(ok).fallback(fallback); let (status, _, _) = call(Method::GET, &mut svc).await; assert_eq!(status, StatusCode::OK); @@ -1414,6 +1439,18 @@ mod tests { let _: MethodRouter<()> = head(ok).get(ok); } + #[tokio::test] + async fn accessing_state() { + let mut svc = MethodRouter::new() + .get(|State(state): State<&'static str>| async move { state }) + .with_state("state"); + + let (status, _, text) = call(Method::GET, &mut svc).await; + + assert_eq!(status, StatusCode::OK); + assert_eq!(text, "state"); + } + #[tokio::test] async fn fallback_accessing_state() { let mut svc = MethodRouter::new() diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 94904422bb..f447d6fde7 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -510,7 +510,7 @@ async fn route_layer() { )] async fn good_error_message_if_using_nest_root() { let app = Router::new() - .nest("/", get(|| async {}).with_state(())) + .nest("/", get(|| async {})) .route("/", get(|| async {})); TestClient::new(app); } @@ -523,7 +523,7 @@ async fn good_error_message_if_using_nest_root() { Use `Router::fallback` instead" )] async fn good_error_message_if_using_nest_root_when_merging() { - let one = Router::new().nest("/", get(|| async {}).with_state(())); + let one = Router::new().nest("/", get(|| async {})); let two = Router::new().route("/", get(|| async {})); let app = one.merge(two); TestClient::new(app); diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index 3083141995..7876b05754 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -117,10 +117,7 @@ async fn nesting_router_at_empty_path() { #[tokio::test] async fn nesting_handler_at_root() { - let app = Router::new().nest( - "/", - get(|uri: Uri| async move { uri.to_string() }).with_state(()), - ); + let app = Router::new().nest("/", get(|uri: Uri| async move { uri.to_string() })); let client = TestClient::new(app); @@ -210,14 +207,12 @@ async fn nested_service_sees_stripped_uri() { async fn nest_static_file_server() { let app = Router::new().nest( "/static", - get_service(ServeDir::new(".")) - .handle_error(|error| async move { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Unhandled internal error: {}", error), - ) - }) - .with_state(()), + get_service(ServeDir::new(".")).handle_error(|error| async move { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Unhandled internal error: {}", error), + ) + }), ); let client = TestClient::new(app); From 2117c97367917027cf08e8e6090a091ff6820609 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 11 Jul 2022 21:12:15 +0200 Subject: [PATCH 06/45] Fix some of axum-macro's tests --- axum-macros/src/debug_handler.rs | 2 +- axum-macros/src/from_request.rs | 25 +++++++++++-------- axum-macros/src/typed_path.rs | 17 +++++++------ .../fail/argument_not_extractor.stderr | 24 +++++++++--------- .../debug_handler/fail/extract_self_mut.rs | 7 +++--- .../fail/extract_self_mut.stderr | 4 +-- .../debug_handler/fail/extract_self_ref.rs | 7 +++--- .../fail/extract_self_ref.stderr | 4 +-- .../pass/result_impl_into_response.rs | 7 +++--- .../tests/debug_handler/pass/self_receiver.rs | 7 +++--- .../tests/from_request/pass/container.rs | 2 +- .../tests/from_request/pass/derive_opt_out.rs | 7 +++--- .../tests/from_request/pass/empty_named.rs | 2 +- .../tests/from_request/pass/empty_tuple.rs | 2 +- axum-macros/tests/from_request/pass/named.rs | 2 +- .../tests/from_request/pass/named_via.rs | 2 +- axum-macros/tests/from_request/pass/tuple.rs | 2 +- .../pass/tuple_same_type_twice.rs | 2 +- .../pass/tuple_same_type_twice_via.rs | 2 +- .../tests/from_request/pass/tuple_via.rs | 2 +- axum-macros/tests/from_request/pass/unit.rs | 2 +- .../typed_path/fail/not_deserialize.stderr | 2 +- .../typed_path/pass/customize_rejection.rs | 2 +- .../typed_path/pass/named_fields_struct.rs | 2 +- .../tests/typed_path/pass/option_result.rs | 2 +- .../tests/typed_path/pass/tuple_struct.rs | 2 +- .../tests/typed_path/pass/unit_struct.rs | 2 +- .../tests/typed_path/pass/wildcards.rs | 2 +- 28 files changed, 78 insertions(+), 67 deletions(-) diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 7945e8bb84..6614c7dd81 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -196,7 +196,7 @@ fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStr #[allow(warnings)] fn #name() where - #ty: ::axum::extract::FromRequest<#body_ty> + Send, + #ty: ::axum::extract::FromRequest<(), #body_ty> + Send, {} } }) diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs index 76634ee9b2..c8cdf85a0a 100644 --- a/axum-macros/src/from_request.rs +++ b/axum-macros/src/from_request.rs @@ -106,16 +106,17 @@ fn impl_struct_by_extracting_each_field( Ok(quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, B::Error: ::std::convert::Into<::axum::BoxError>, + S: Send, { type Rejection = #rejection_ident; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: &mut ::axum::extract::RequestParts, ) -> ::std::result::Result { ::std::result::Result::Ok(Self { #(#extract_fields)* @@ -301,7 +302,7 @@ fn extract_each_field_rejection( Ok(quote_spanned! {ty_span=> #[allow(non_camel_case_types)] - #variant_name(<#extractor_ty as ::axum::extract::FromRequest<::axum::body::Body>>::Rejection), + #variant_name(<#extractor_ty as ::axum::extract::FromRequest>::Rejection), }) }) .collect::>>()?; @@ -485,18 +486,19 @@ fn impl_struct_by_extracting_all_at_once( Ok(quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, B::Error: ::std::convert::Into<::axum::BoxError>, + S: Send, { - type Rejection = <#path as ::axum::extract::FromRequest>::Rejection; + type Rejection = <#path as ::axum::extract::FromRequest>::Rejection; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: &mut ::axum::extract::RequestParts, ) -> ::std::result::Result { - ::axum::extract::FromRequest::::from_request(req) + ::axum::extract::FromRequest::::from_request(req) .await .map(|#path(inner)| inner) } @@ -540,18 +542,19 @@ fn impl_enum_by_extracting_all_at_once( Ok(quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, B::Error: ::std::convert::Into<::axum::BoxError>, + S: Send, { - type Rejection = <#path as ::axum::extract::FromRequest>::Rejection; + type Rejection = <#path as ::axum::extract::FromRequest>::Rejection; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: &mut ::axum::extract::RequestParts, ) -> ::std::result::Result { - ::axum::extract::FromRequest::::from_request(req) + ::axum::extract::FromRequest::::from_request(req) .await .map(|#path(inner)| inner) } diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs index a765437198..6a8f03c170 100644 --- a/axum-macros/src/typed_path.rs +++ b/axum-macros/src/typed_path.rs @@ -127,13 +127,14 @@ fn expand_named_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: Send, + S: Send, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { + async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { ::axum::extract::Path::from_request(req) .await .map(|path| path.0) @@ -229,13 +230,14 @@ fn expand_unnamed_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: Send, + S: Send, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { + async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { ::axum::extract::Path::from_request(req) .await .map(|path| path.0) @@ -310,13 +312,14 @@ fn expand_unit_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: Send, + S: Send, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { + async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { if req.uri().path() == ::PATH { Ok(Self) } else { @@ -387,7 +390,7 @@ enum Segment { fn path_rejection() -> TokenStream { quote! { - <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection + <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection } } diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr index 078b2b0371..265258419e 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr @@ -1,17 +1,17 @@ -error[E0277]: the trait bound `bool: FromRequest` is not satisfied +error[E0277]: the trait bound `bool: FromRequest<(), Body>` is not satisfied --> tests/debug_handler/fail/argument_not_extractor.rs:4:23 | 4 | async fn handler(foo: bool) {} - | ^^^^ the trait `FromRequest` is not implemented for `bool` + | ^^^^ the trait `FromRequest<(), Body>` is not implemented for `bool` | - = help: the following other types implement trait `FromRequest`: - () - (T1, T2) - (T1, T2, T3) - (T1, T2, T3, T4) - (T1, T2, T3, T4, T5) - (T1, T2, T3, T4, T5, T6) - (T1, T2, T3, T4, T5, T6, T7) - (T1, T2, T3, T4, T5, T6, T7, T8) - and 33 others + = help: the following other types implement trait `FromRequest`: + <() as FromRequest> + <(T1, T2) as FromRequest> + <(T1, T2, T3) as FromRequest> + <(T1, T2, T3, T4) as FromRequest> + <(T1, T2, T3, T4, T5) as FromRequest> + <(T1, T2, T3, T4, T5, T6) as FromRequest> + <(T1, T2, T3, T4, T5, T6, T7) as FromRequest> + <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequest> + and 34 others = help: see issue #48214 diff --git a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs index 01eb636bc0..d38d5e0c4d 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs @@ -7,13 +7,14 @@ use axum_macros::debug_handler; struct A; #[async_trait] -impl FromRequest for A +impl FromRequest for A where - B: Send + 'static, + B: Send, + S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr b/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr index 595786bf4e..3d80dffbca 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr +++ b/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr @@ -1,5 +1,5 @@ error: Handlers must only take owned values - --> tests/debug_handler/fail/extract_self_mut.rs:23:22 + --> tests/debug_handler/fail/extract_self_mut.rs:24:22 | -23 | async fn handler(&mut self) {} +24 | async fn handler(&mut self) {} | ^^^^^^^^^ diff --git a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs index d64732cdcc..06b87f0a82 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs @@ -7,13 +7,14 @@ use axum_macros::debug_handler; struct A; #[async_trait] -impl FromRequest for A +impl FromRequest for A where - B: Send + 'static, + B: Send, + S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr b/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr index 4c0b4950c7..82d9a89ff5 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr +++ b/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr @@ -1,5 +1,5 @@ error: Handlers must only take owned values - --> tests/debug_handler/fail/extract_self_ref.rs:23:22 + --> tests/debug_handler/fail/extract_self_ref.rs:24:22 | -23 | async fn handler(&self) {} +24 | async fn handler(&self) {} | ^^^^^ diff --git a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs index 81269a9cb2..762809b62a 100644 --- a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs +++ b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs @@ -120,13 +120,14 @@ impl A { } #[async_trait] -impl FromRequest for A +impl FromRequest for A where - B: Send + 'static, + B: Send, + S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/pass/self_receiver.rs b/axum-macros/tests/debug_handler/pass/self_receiver.rs index f22ccc08cd..a88382cf18 100644 --- a/axum-macros/tests/debug_handler/pass/self_receiver.rs +++ b/axum-macros/tests/debug_handler/pass/self_receiver.rs @@ -7,13 +7,14 @@ use axum_macros::debug_handler; struct A; #[async_trait] -impl FromRequest for A +impl FromRequest for A where - B: Send + 'static, + B: Send, + S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/from_request/pass/container.rs b/axum-macros/tests/from_request/pass/container.rs index a6732527e4..e8eaa0a58a 100644 --- a/axum-macros/tests/from_request/pass/container.rs +++ b/axum-macros/tests/from_request/pass/container.rs @@ -15,7 +15,7 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest, + Extractor: FromRequest<(), Body, Rejection = JsonRejection>, { } diff --git a/axum-macros/tests/from_request/pass/derive_opt_out.rs b/axum-macros/tests/from_request/pass/derive_opt_out.rs index 0bf24c7389..9738116d86 100644 --- a/axum-macros/tests/from_request/pass/derive_opt_out.rs +++ b/axum-macros/tests/from_request/pass/derive_opt_out.rs @@ -14,13 +14,14 @@ struct Extractor { struct OtherExtractor; #[async_trait] -impl FromRequest for OtherExtractor +impl FromRequest for OtherExtractor where - B: Send + 'static, + B: Send, + S: Send, { type Rejection = OtherExtractorRejection; - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts<(), B>) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/from_request/pass/empty_named.rs b/axum-macros/tests/from_request/pass/empty_named.rs index 2cc5dda8b6..eec021d0f5 100644 --- a/axum-macros/tests/from_request/pass/empty_named.rs +++ b/axum-macros/tests/from_request/pass/empty_named.rs @@ -5,7 +5,7 @@ struct Extractor {} fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, { } diff --git a/axum-macros/tests/from_request/pass/empty_tuple.rs b/axum-macros/tests/from_request/pass/empty_tuple.rs index bbb525fa1b..3d8bcd25c0 100644 --- a/axum-macros/tests/from_request/pass/empty_tuple.rs +++ b/axum-macros/tests/from_request/pass/empty_tuple.rs @@ -5,7 +5,7 @@ struct Extractor(); fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, { } diff --git a/axum-macros/tests/from_request/pass/named.rs b/axum-macros/tests/from_request/pass/named.rs index 4ee40a0447..89fb8da004 100644 --- a/axum-macros/tests/from_request/pass/named.rs +++ b/axum-macros/tests/from_request/pass/named.rs @@ -18,7 +18,7 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest, + Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>, { } diff --git a/axum-macros/tests/from_request/pass/named_via.rs b/axum-macros/tests/from_request/pass/named_via.rs index fc03b8c0fb..8a81869d1a 100644 --- a/axum-macros/tests/from_request/pass/named_via.rs +++ b/axum-macros/tests/from_request/pass/named_via.rs @@ -25,7 +25,7 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest, + Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>, { } diff --git a/axum-macros/tests/from_request/pass/tuple.rs b/axum-macros/tests/from_request/pass/tuple.rs index 9786285223..2af407d0f9 100644 --- a/axum-macros/tests/from_request/pass/tuple.rs +++ b/axum-macros/tests/from_request/pass/tuple.rs @@ -5,7 +5,7 @@ struct Extractor(axum::http::HeaderMap, String); fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs index 0434bb29f3..00b6dd78df 100644 --- a/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs @@ -13,7 +13,7 @@ struct Payload {} fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs index df7761eab9..0b148ebc50 100644 --- a/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs @@ -27,7 +27,7 @@ struct Payload {} fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } diff --git a/axum-macros/tests/from_request/pass/tuple_via.rs b/axum-macros/tests/from_request/pass/tuple_via.rs index dde9887771..7a8723b628 100644 --- a/axum-macros/tests/from_request/pass/tuple_via.rs +++ b/axum-macros/tests/from_request/pass/tuple_via.rs @@ -9,7 +9,7 @@ struct State; fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } diff --git a/axum-macros/tests/from_request/pass/unit.rs b/axum-macros/tests/from_request/pass/unit.rs index 57f774d143..3e5d986917 100644 --- a/axum-macros/tests/from_request/pass/unit.rs +++ b/axum-macros/tests/from_request/pass/unit.rs @@ -5,7 +5,7 @@ struct Extractor; fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, { } diff --git a/axum-macros/tests/typed_path/fail/not_deserialize.stderr b/axum-macros/tests/typed_path/fail/not_deserialize.stderr index 7581b3997c..9aabf3625f 100644 --- a/axum-macros/tests/typed_path/fail/not_deserialize.stderr +++ b/axum-macros/tests/typed_path/fail/not_deserialize.stderr @@ -15,5 +15,5 @@ error[E0277]: the trait bound `for<'de> MyPath: serde::de::Deserialize<'de>` is (T0, T1, T2, T3) and 138 others = note: required because of the requirements on the impl of `serde::de::DeserializeOwned` for `MyPath` - = note: required because of the requirements on the impl of `FromRequest` for `axum::extract::Path` + = note: required because of the requirements on the impl of `FromRequest` for `axum::extract::Path` = note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/typed_path/pass/customize_rejection.rs b/axum-macros/tests/typed_path/pass/customize_rejection.rs index 41aa7e614e..40f3ec0ada 100644 --- a/axum-macros/tests/typed_path/pass/customize_rejection.rs +++ b/axum-macros/tests/typed_path/pass/customize_rejection.rs @@ -40,7 +40,7 @@ impl Default for MyRejection { } fn main() { - axum::Router::::new() + axum::Router::<(), axum::body::Body>::new() .typed_get(|_: Result| async {}) .typed_post(|_: Result| async {}) .typed_put(|_: Result| async {}); diff --git a/axum-macros/tests/typed_path/pass/named_fields_struct.rs b/axum-macros/tests/typed_path/pass/named_fields_struct.rs index 6942bd3394..6119304080 100644 --- a/axum-macros/tests/typed_path/pass/named_fields_struct.rs +++ b/axum-macros/tests/typed_path/pass/named_fields_struct.rs @@ -9,7 +9,7 @@ struct MyPath { } fn main() { - axum::Router::::new().route("/", axum::routing::get(|_: MyPath| async {})); + axum::Router::<(), axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {})); assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id"); assert_eq!( diff --git a/axum-macros/tests/typed_path/pass/option_result.rs b/axum-macros/tests/typed_path/pass/option_result.rs index d89dea2dd5..252bde137f 100644 --- a/axum-macros/tests/typed_path/pass/option_result.rs +++ b/axum-macros/tests/typed_path/pass/option_result.rs @@ -20,7 +20,7 @@ struct UsersIndex; async fn result_handler_unit_struct(_: Result) {} fn main() { - axum::Router::::new() + axum::Router::<(), axum::body::Body>::new() .typed_get(option_handler) .typed_post(result_handler) .typed_post(result_handler_unit_struct); diff --git a/axum-macros/tests/typed_path/pass/tuple_struct.rs b/axum-macros/tests/typed_path/pass/tuple_struct.rs index 5e3d27ff40..4f8fa17eeb 100644 --- a/axum-macros/tests/typed_path/pass/tuple_struct.rs +++ b/axum-macros/tests/typed_path/pass/tuple_struct.rs @@ -8,7 +8,7 @@ pub type Result = std::result::Result; struct MyPath(u32, u32); fn main() { - axum::Router::::new().route("/", axum::routing::get(|_: MyPath| async {})); + axum::Router::<(), axum::body::Body>::new().route("/", axum::routing::get(|_: MyPath| async {})); assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id"); assert_eq!(format!("{}", MyPath(1, 2)), "/users/1/teams/2"); diff --git a/axum-macros/tests/typed_path/pass/unit_struct.rs b/axum-macros/tests/typed_path/pass/unit_struct.rs index 9b6a0f6e39..0ba27f81ac 100644 --- a/axum-macros/tests/typed_path/pass/unit_struct.rs +++ b/axum-macros/tests/typed_path/pass/unit_struct.rs @@ -5,7 +5,7 @@ use axum_extra::routing::TypedPath; struct MyPath; fn main() { - axum::Router::::new() + axum::Router::<(), axum::body::Body>::new() .route("/", axum::routing::get(|_: MyPath| async {})); assert_eq!(MyPath::PATH, "/users"); diff --git a/axum-macros/tests/typed_path/pass/wildcards.rs b/axum-macros/tests/typed_path/pass/wildcards.rs index 0c9155f71d..e7794fc895 100644 --- a/axum-macros/tests/typed_path/pass/wildcards.rs +++ b/axum-macros/tests/typed_path/pass/wildcards.rs @@ -8,5 +8,5 @@ struct MyPath { } fn main() { - axum::Router::::new().typed_get(|_: MyPath| async {}); + axum::Router::<(), axum::body::Body>::new().typed_get(|_: MyPath| async {}); } From 09e189843fe876a4823c3b44db6fbadd3c2a4299 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 14 Jul 2022 10:05:49 +0200 Subject: [PATCH 07/45] Implement more traits for `State` --- axum/src/extract/state.rs | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index fb9aa26e76..fad5e67405 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -1,8 +1,11 @@ use async_trait::async_trait; use axum_core::extract::{FromRequest, RequestParts}; -use std::convert::Infallible; +use std::{ + convert::Infallible, + ops::{Deref, DerefMut}, +}; -#[derive(Debug, Clone)] +#[derive(Debug, Default, Clone, Copy)] pub struct State(pub S); #[async_trait] @@ -19,3 +22,17 @@ where Ok(Self(inner_state)) } } + +impl Deref for State { + type Target = S; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for State { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} From f5cd34f3f19e68f06dcc4971e7677157375f6024 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 14 Jul 2022 10:18:44 +0200 Subject: [PATCH 08/45] Update examples to use `State` --- examples/async-graphql/src/main.rs | 8 ++-- examples/chat/src/main.rs | 9 ++-- .../src/main.rs | 11 +++-- examples/cors/src/main.rs | 2 +- .../customize-extractor-error/src/main.rs | 5 +- examples/customize-path-rejection/src/main.rs | 5 +- .../src/main.rs | 13 ++--- examples/global-404-handler/src/main.rs | 3 +- examples/jwt/src/main.rs | 5 +- examples/key-value-store/src/main.rs | 27 ++++++----- examples/oauth/src/main.rs | 48 ++++++++++++------- .../src/main.rs | 2 +- examples/reverse-proxy/src/main.rs | 13 ++--- .../src/main.rs | 8 ++-- examples/sessions/src/main.rs | 18 +++---- examples/sqlx-postgres/src/main.rs | 24 ++++------ examples/sse/src/main.rs | 2 +- examples/static-file-server/src/main.rs | 2 +- examples/testing/src/main.rs | 2 +- examples/todos/src/main.rs | 16 +++---- examples/tokio-postgres/src/main.rs | 25 +++++----- examples/validator/src/main.rs | 5 +- examples/versioning/src/main.rs | 5 +- examples/websockets/src/main.rs | 2 +- 24 files changed, 128 insertions(+), 132 deletions(-) diff --git a/examples/async-graphql/src/main.rs b/examples/async-graphql/src/main.rs index 7d03d3b95f..a8d84cb9a1 100644 --- a/examples/async-graphql/src/main.rs +++ b/examples/async-graphql/src/main.rs @@ -13,14 +13,14 @@ use async_graphql::{ EmptyMutation, EmptySubscription, Request, Response, Schema, }; use axum::{ - extract::Extension, + extract::State, response::{Html, IntoResponse}, routing::get, Json, Router, }; use starwars::{QueryRoot, StarWars, StarWarsSchema}; -async fn graphql_handler(schema: Extension, req: Json) -> Json { +async fn graphql_handler(schema: State, req: Json) -> Json { schema.execute(req.0).await.into() } @@ -34,9 +34,7 @@ async fn main() { .data(StarWars::new()) .finish(); - let app = Router::new() - .route("/", get(graphql_playground).post(graphql_handler)) - .layer(Extension(schema)); + let app = Router::with_state(schema).route("/", get(graphql_playground).post(graphql_handler)); println!("Playground: http://localhost:3000"); diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs index 5107323ea8..092a7d96de 100644 --- a/examples/chat/src/main.rs +++ b/examples/chat/src/main.rs @@ -9,7 +9,7 @@ use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, - Extension, + State, }, response::{Html, IntoResponse}, routing::get, @@ -44,10 +44,9 @@ async fn main() { let app_state = Arc::new(AppState { user_set, tx }); - let app = Router::new() + let app = Router::with_state(app_state) .route("/", get(index)) - .route("/websocket", get(websocket_handler)) - .layer(Extension(app_state)); + .route("/websocket", get(websocket_handler)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); @@ -59,7 +58,7 @@ async fn main() { async fn websocket_handler( ws: WebSocketUpgrade, - Extension(state): Extension>, + State(state): State>, ) -> impl IntoResponse { ws.on_upgrade(|socket| websocket(socket, state)) } diff --git a/examples/consume-body-in-extractor-or-middleware/src/main.rs b/examples/consume-body-in-extractor-or-middleware/src/main.rs index 1fdd90221a..11b7a6b28f 100644 --- a/examples/consume-body-in-extractor-or-middleware/src/main.rs +++ b/examples/consume-body-in-extractor-or-middleware/src/main.rs @@ -80,17 +80,22 @@ async fn handler(_: PrintRequestBody, body: Bytes) { struct PrintRequestBody; #[async_trait] -impl FromRequest for PrintRequestBody { +impl FromRequest for PrintRequestBody +where + S: Send + Clone, +{ type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { + let state = req.state().clone(); + let request = Request::from_request(req) .await .map_err(|err| err.into_response())?; let request = buffer_request_body(request).await?; - *req = RequestParts::new(request); + *req = RequestParts::new(state, request); Ok(Self) } diff --git a/examples/cors/src/main.rs b/examples/cors/src/main.rs index a8e02cc559..d7ae530706 100644 --- a/examples/cors/src/main.rs +++ b/examples/cors/src/main.rs @@ -38,7 +38,7 @@ async fn main() { tokio::join!(frontend, backend); } -async fn serve(app: Router, port: u16) { +async fn serve(app: Router<()>, port: u16) { let addr = SocketAddr::from(([127, 0, 0, 1], port)); axum::Server::bind(&addr) .serve(app.into_make_service()) diff --git a/examples/customize-extractor-error/src/main.rs b/examples/customize-extractor-error/src/main.rs index bc0973b4c6..448d30c57d 100644 --- a/examples/customize-extractor-error/src/main.rs +++ b/examples/customize-extractor-error/src/main.rs @@ -53,8 +53,9 @@ struct User { struct Json(T); #[async_trait] -impl FromRequest for Json +impl FromRequest for Json where + S: Send, // these trait bounds are copied from `impl FromRequest for axum::Json` T: DeserializeOwned, B: axum::body::HttpBody + Send, @@ -63,7 +64,7 @@ where { type Rejection = (StatusCode, axum::Json); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match axum::Json::::from_request(req).await { Ok(value) => Ok(Self(value.0)), Err(rejection) => { diff --git a/examples/customize-path-rejection/src/main.rs b/examples/customize-path-rejection/src/main.rs index 4e26894959..8330b95a93 100644 --- a/examples/customize-path-rejection/src/main.rs +++ b/examples/customize-path-rejection/src/main.rs @@ -52,15 +52,16 @@ struct Params { struct Path(T); #[async_trait] -impl FromRequest for Path +impl FromRequest for Path where // these trait bounds are copied from `impl FromRequest for axum::extract::path::Path` T: DeserializeOwned + Send, B: Send, + S: Send, { type Rejection = (StatusCode, axum::Json); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match axum::extract::Path::::from_request(req).await { Ok(value) => Ok(Self(value.0)), Err(rejection) => { diff --git a/examples/error-handling-and-dependency-injection/src/main.rs b/examples/error-handling-and-dependency-injection/src/main.rs index d92b43bf49..914ae18155 100644 --- a/examples/error-handling-and-dependency-injection/src/main.rs +++ b/examples/error-handling-and-dependency-injection/src/main.rs @@ -9,7 +9,7 @@ use axum::{ async_trait, - extract::{Extension, Path}, + extract::{Path, State}, http::StatusCode, response::{IntoResponse, Response}, routing::{get, post}, @@ -36,12 +36,9 @@ async fn main() { let user_repo = Arc::new(ExampleUserRepo) as DynUserRepo; // Build our application with some routes - let app = Router::new() + let app = Router::with_state(user_repo) .route("/users/:id", get(users_show)) - .route("/users", post(users_create)) - // Add our `user_repo` to all request's extensions so handlers can access - // it. - .layer(Extension(user_repo)); + .route("/users", post(users_create)); // Run our application let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -59,7 +56,7 @@ async fn main() { /// so it can be returned from handlers directly. async fn users_show( Path(user_id): Path, - Extension(user_repo): Extension, + State(user_repo): State, ) -> Result, AppError> { let user = user_repo.find(user_id).await?; @@ -69,7 +66,7 @@ async fn users_show( /// Handler for `POST /users`. async fn users_create( Json(params): Json, - Extension(user_repo): Extension, + State(user_repo): State, ) -> Result, AppError> { let user = user_repo.create(params).await?; diff --git a/examples/global-404-handler/src/main.rs b/examples/global-404-handler/src/main.rs index 385a0e21e3..a3a5ea15dc 100644 --- a/examples/global-404-handler/src/main.rs +++ b/examples/global-404-handler/src/main.rs @@ -5,7 +5,6 @@ //! ``` use axum::{ - handler::Handler, http::StatusCode, response::{Html, IntoResponse}, routing::get, @@ -27,7 +26,7 @@ async fn main() { let app = Router::new().route("/", get(handler)); // add a fallback service for handling routes to unknown paths - let app = app.fallback(handler_404.into_service()); + let app = app.fallback(handler_404); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/examples/jwt/src/main.rs b/examples/jwt/src/main.rs index 0ac4053e6c..8725581da7 100644 --- a/examples/jwt/src/main.rs +++ b/examples/jwt/src/main.rs @@ -122,13 +122,14 @@ impl AuthBody { } #[async_trait] -impl FromRequest for Claims +impl FromRequest for Claims where + S: Send, B: Send, { type Rejection = AuthError; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { // Extract the token from the authorization header let TypedHeader(Authorization(bearer)) = TypedHeader::>::from_request(req) diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index 0ad5b7f670..c65ee75a0f 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -9,7 +9,7 @@ use axum::{ body::Bytes, error_handling::HandleErrorLayer, - extract::{ContentLengthLimit, Extension, Path}, + extract::{ContentLengthLimit, Path, State}, handler::Handler, http::StatusCode, response::IntoResponse, @@ -39,8 +39,10 @@ async fn main() { .with(tracing_subscriber::fmt::layer()) .init(); + let shared_state = SharedState::default(); + // Build our application by composing routes - let app = Router::new() + let app = Router::with_state(Arc::clone(&shared_state)) .route( "/:key", // Add compression to `kv_get` @@ -50,7 +52,7 @@ async fn main() { ) .route("/keys", get(list_keys)) // Nest our admin routes under `/admin` - .nest("/admin", admin_routes()) + .nest("/admin", admin_routes(shared_state)) // Add middleware to all routes .layer( ServiceBuilder::new() @@ -60,7 +62,6 @@ async fn main() { .concurrency_limit(1024) .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) - .layer(Extension(SharedState::default())) .into_inner(), ); @@ -73,16 +74,16 @@ async fn main() { .unwrap(); } -type SharedState = Arc>; +type SharedState = Arc>; #[derive(Default)] -struct State { +struct AppState { db: HashMap, } async fn kv_get( Path(key): Path, - Extension(state): Extension, + State(state): State, ) -> Result { let db = &state.read().unwrap().db; @@ -96,12 +97,12 @@ async fn kv_get( async fn kv_set( Path(key): Path, ContentLengthLimit(bytes): ContentLengthLimit, // ~5mb - Extension(state): Extension, + State(state): State, ) { state.write().unwrap().db.insert(key, bytes); } -async fn list_keys(Extension(state): Extension) -> String { +async fn list_keys(State(state): State) -> String { let db = &state.read().unwrap().db; db.keys() @@ -110,16 +111,16 @@ async fn list_keys(Extension(state): Extension) -> String { .join("\n") } -fn admin_routes() -> Router { - async fn delete_all_keys(Extension(state): Extension) { +fn admin_routes(state: SharedState) -> Router { + async fn delete_all_keys(State(state): State) { state.write().unwrap().db.clear(); } - async fn remove_key(Path(key): Path, Extension(state): Extension) { + async fn remove_key(Path(key): Path, State(state): State) { state.write().unwrap().db.remove(&key); } - Router::new() + Router::with_state(state) .route("/keys", delete(delete_all_keys)) .route("/key/:key", delete(remove_key)) // Require bearer auth for all admin routes diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index 6357a7fc08..303ca1647f 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -12,8 +12,7 @@ use async_session::{MemoryStore, Session, SessionStore}; use axum::{ async_trait, extract::{ - rejection::TypedHeaderRejectionReason, Extension, FromRequest, Query, RequestParts, - TypedHeader, + rejection::TypedHeaderRejectionReason, FromRequest, Query, RequestParts, State, TypedHeader, }, http::{header::SET_COOKIE, HeaderMap}, response::{IntoResponse, Redirect, Response}, @@ -42,17 +41,18 @@ async fn main() { // `MemoryStore` is just used as an example. Don't use this in production. let store = MemoryStore::new(); - let oauth_client = oauth_client(); + let app_state = AppState { + store, + oauth_client, + }; - let app = Router::new() + let app = Router::with_state(app_state) .route("/", get(index)) .route("/auth/discord", get(discord_auth)) .route("/auth/authorized", get(login_authorized)) .route("/protected", get(protected)) - .route("/logout", get(logout)) - .layer(Extension(store)) - .layer(Extension(oauth_client)); + .route("/logout", get(logout)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); @@ -63,6 +63,24 @@ async fn main() { .unwrap(); } +#[derive(Clone)] +struct AppState { + store: MemoryStore, + oauth_client: BasicClient, +} + +impl From for MemoryStore { + fn from(state: AppState) -> Self { + state.store + } +} + +impl From for BasicClient { + fn from(state: AppState) -> Self { + state.oauth_client + } +} + fn oauth_client() -> BasicClient { // Environment variables (* = required): // *"CLIENT_ID" "REPLACE_ME"; @@ -113,7 +131,7 @@ async fn index(user: Option) -> impl IntoResponse { } } -async fn discord_auth(Extension(client): Extension) -> impl IntoResponse { +async fn discord_auth(State(client): State) -> impl IntoResponse { let (auth_url, _csrf_token) = client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("identify".to_string())) @@ -132,7 +150,7 @@ async fn protected(user: User) -> impl IntoResponse { } async fn logout( - Extension(store): Extension, + State(store): State, TypedHeader(cookies): TypedHeader, ) -> impl IntoResponse { let cookie = cookies.get(COOKIE_NAME).unwrap(); @@ -156,8 +174,8 @@ struct AuthRequest { async fn login_authorized( Query(query): Query, - Extension(store): Extension, - Extension(oauth_client): Extension, + State(store): State, + State(oauth_client): State, ) -> impl IntoResponse { // Get an auth token let token = oauth_client @@ -205,17 +223,15 @@ impl IntoResponse for AuthRedirect { } #[async_trait] -impl FromRequest for User +impl FromRequest for User where B: Send, { // If anything goes wrong or no session is found, redirect to the auth page type Rejection = AuthRedirect; - async fn from_request(req: &mut RequestParts) -> Result { - let Extension(store) = Extension::::from_request(req) - .await - .expect("`MemoryStore` extension is missing"); + async fn from_request(req: &mut RequestParts) -> Result { + let store = req.state().clone().store; let cookies = TypedHeader::::from_request(req) .await diff --git a/examples/query-params-with-empty-strings/src/main.rs b/examples/query-params-with-empty-strings/src/main.rs index 0af20111d7..7e9a08894c 100644 --- a/examples/query-params-with-empty-strings/src/main.rs +++ b/examples/query-params-with-empty-strings/src/main.rs @@ -16,7 +16,7 @@ async fn main() { .unwrap(); } -fn app() -> Router { +fn app() -> Router<()> { Router::new().route("/", get(handler)) } diff --git a/examples/reverse-proxy/src/main.rs b/examples/reverse-proxy/src/main.rs index a9d2a5c750..af74ea1295 100644 --- a/examples/reverse-proxy/src/main.rs +++ b/examples/reverse-proxy/src/main.rs @@ -8,7 +8,7 @@ //! ``` use axum::{ - extract::Extension, + extract::State, http::{uri::Uri, Request, Response}, routing::get, Router, @@ -24,9 +24,7 @@ async fn main() { let client = Client::new(); - let app = Router::new() - .route("/", get(handler)) - .layer(Extension(client)); + let app = Router::with_state(client).route("/", get(handler)); let addr = SocketAddr::from(([127, 0, 0, 1], 4000)); println!("reverse proxy listening on {}", addr); @@ -36,12 +34,7 @@ async fn main() { .unwrap(); } -async fn handler( - Extension(client): Extension, - // NOTE: Make sure to put the request extractor last because once the request - // is extracted, extensions can't be extracted anymore. - mut req: Request, -) -> Response { +async fn handler(State(client): State, mut req: Request) -> Response { let path = req.uri().path(); let path_query = req .uri() diff --git a/examples/routes-and-handlers-close-together/src/main.rs b/examples/routes-and-handlers-close-together/src/main.rs index 5e52ad7b55..6fc75c9e41 100644 --- a/examples/routes-and-handlers-close-together/src/main.rs +++ b/examples/routes-and-handlers-close-together/src/main.rs @@ -25,7 +25,7 @@ async fn main() { .unwrap(); } -fn root() -> Router { +fn root() -> Router<()> { async fn handler() -> &'static str { "Hello, World!" } @@ -33,7 +33,7 @@ fn root() -> Router { route("/", get(handler)) } -fn get_foo() -> Router { +fn get_foo() -> Router<()> { async fn handler() -> &'static str { "Hi from `GET /foo`" } @@ -41,7 +41,7 @@ fn get_foo() -> Router { route("/foo", get(handler)) } -fn post_foo() -> Router { +fn post_foo() -> Router<()> { async fn handler() -> &'static str { "Hi from `POST /foo`" } @@ -49,6 +49,6 @@ fn post_foo() -> Router { route("/foo", post(handler)) } -fn route(path: &str, method_router: MethodRouter) -> Router { +fn route(path: &str, method_router: MethodRouter<()>) -> Router<()> { Router::new().route(path, method_router) } diff --git a/examples/sessions/src/main.rs b/examples/sessions/src/main.rs index 3251122c9a..cd0d41a1f6 100644 --- a/examples/sessions/src/main.rs +++ b/examples/sessions/src/main.rs @@ -7,7 +7,7 @@ use async_session::{MemoryStore, Session, SessionStore as _}; use axum::{ async_trait, - extract::{Extension, FromRequest, RequestParts, TypedHeader}, + extract::{FromRequest, RequestParts, TypedHeader}, headers::Cookie, http::{ self, @@ -38,9 +38,7 @@ async fn main() { // `MemoryStore` just used as an example. Don't use this in production. let store = MemoryStore::new(); - let app = Router::new() - .route("/", get(handler)) - .layer(Extension(store)); + let app = Router::with_state(store).route("/", get(handler)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); @@ -82,20 +80,16 @@ enum UserIdFromSession { } #[async_trait] -impl FromRequest for UserIdFromSession +impl FromRequest for UserIdFromSession where B: Send, { type Rejection = (StatusCode, &'static str); - async fn from_request(req: &mut RequestParts) -> Result { - let Extension(store) = Extension::::from_request(req) - .await - .expect("`MemoryStore` extension missing"); + async fn from_request(req: &mut RequestParts) -> Result { + let store = req.state().clone(); - let cookie = Option::>::from_request(req) - .await - .unwrap(); + let cookie = req.extract::>>().await.unwrap(); let session_cookie = cookie .as_ref() diff --git a/examples/sqlx-postgres/src/main.rs b/examples/sqlx-postgres/src/main.rs index 9d101618db..6548cdeb97 100644 --- a/examples/sqlx-postgres/src/main.rs +++ b/examples/sqlx-postgres/src/main.rs @@ -15,7 +15,7 @@ use axum::{ async_trait, - extract::{Extension, FromRequest, RequestParts}, + extract::{FromRequest, RequestParts, State}, http::StatusCode, routing::get, Router, @@ -46,12 +46,10 @@ async fn main() { .expect("can connect to database"); // build our application with some routes - let app = Router::new() - .route( - "/", - get(using_connection_pool_extractor).post(using_connection_extractor), - ) - .layer(Extension(pool)); + let app = Router::with_state(pool).route( + "/", + get(using_connection_pool_extractor).post(using_connection_extractor), + ); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -62,9 +60,9 @@ async fn main() { .unwrap(); } -// we can extract the connection pool with `Extension` +// we can extract the connection pool with `State` async fn using_connection_pool_extractor( - Extension(pool): Extension, + State(pool): State, ) -> Result { sqlx::query_scalar("select 'hello world from pg'") .fetch_one(&pool) @@ -77,16 +75,14 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(sqlx::pool::PoolConnection); #[async_trait] -impl FromRequest for DatabaseConnection +impl FromRequest for DatabaseConnection where B: Send, { type Rejection = (StatusCode, String); - async fn from_request(req: &mut RequestParts) -> Result { - let Extension(pool) = Extension::::from_request(req) - .await - .map_err(internal_error)?; + async fn from_request(req: &mut RequestParts) -> Result { + let pool = req.state().clone(); let conn = pool.acquire().await.map_err(internal_error)?; diff --git a/examples/sse/src/main.rs b/examples/sse/src/main.rs index 4dbfcb4673..6679971152 100644 --- a/examples/sse/src/main.rs +++ b/examples/sse/src/main.rs @@ -41,7 +41,7 @@ async fn main() { // build our application with a route let app = Router::new() - .fallback(static_files_service) + .fallback_service(static_files_service) .route("/sse", get(sse_handler)) .layer(TraceLayer::new_for_http()); diff --git a/examples/static-file-server/src/main.rs b/examples/static-file-server/src/main.rs index 1862ecab5a..f7ac2bb9a8 100644 --- a/examples/static-file-server/src/main.rs +++ b/examples/static-file-server/src/main.rs @@ -34,7 +34,7 @@ async fn main() { // as the fallback to a `Router` let app: _ = Router::new() .route("/foo", get(|| async { "Hi from /foo" })) - .fallback(get_service(ServeDir::new(".")).handle_error(handle_error)) + .fallback_service(get_service(ServeDir::new(".")).handle_error(handle_error)) .layer(TraceLayer::new_for_http()); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/examples/testing/src/main.rs b/examples/testing/src/main.rs index 2893188bf1..11c0b6b4bd 100644 --- a/examples/testing/src/main.rs +++ b/examples/testing/src/main.rs @@ -34,7 +34,7 @@ async fn main() { /// Having a function that produces our app makes it easy to call it from tests /// without having to create an HTTP server. #[allow(dead_code)] -fn app() -> Router { +fn app() -> Router<()> { Router::new() .route("/", get(|| async { "Hello, World!" })) .route( diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs index 9a33416be3..b82a308db4 100644 --- a/examples/todos/src/main.rs +++ b/examples/todos/src/main.rs @@ -15,7 +15,7 @@ use axum::{ error_handling::HandleErrorLayer, - extract::{Extension, Path, Query}, + extract::{Path, Query, State}, http::StatusCode, response::IntoResponse, routing::{get, patch}, @@ -46,7 +46,7 @@ async fn main() { let db = Db::default(); // Compose the routes - let app = Router::new() + let app = Router::with_state(db) .route("/todos", get(todos_index).post(todos_create)) .route("/todos/:id", patch(todos_update).delete(todos_delete)) // Add middleware to all routes @@ -64,7 +64,6 @@ async fn main() { })) .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) - .layer(Extension(db)) .into_inner(), ); @@ -85,7 +84,7 @@ pub struct Pagination { async fn todos_index( pagination: Option>, - Extension(db): Extension, + State(db): State, ) -> impl IntoResponse { let todos = db.read().unwrap(); @@ -106,10 +105,7 @@ struct CreateTodo { text: String, } -async fn todos_create( - Json(input): Json, - Extension(db): Extension, -) -> impl IntoResponse { +async fn todos_create(Json(input): Json, State(db): State) -> impl IntoResponse { let todo = Todo { id: Uuid::new_v4(), text: input.text, @@ -130,7 +126,7 @@ struct UpdateTodo { async fn todos_update( Path(id): Path, Json(input): Json, - Extension(db): Extension, + State(db): State, ) -> Result { let mut todo = db .read() @@ -152,7 +148,7 @@ async fn todos_update( Ok(Json(todo)) } -async fn todos_delete(Path(id): Path, Extension(db): Extension) -> impl IntoResponse { +async fn todos_delete(Path(id): Path, State(db): State) -> impl IntoResponse { if db.write().unwrap().remove(&id).is_some() { StatusCode::NO_CONTENT } else { diff --git a/examples/tokio-postgres/src/main.rs b/examples/tokio-postgres/src/main.rs index 4489f616df..48f84c6956 100644 --- a/examples/tokio-postgres/src/main.rs +++ b/examples/tokio-postgres/src/main.rs @@ -6,7 +6,7 @@ use axum::{ async_trait, - extract::{Extension, FromRequest, RequestParts}, + extract::{FromRequest, RequestParts, State}, http::StatusCode, routing::get, Router, @@ -33,12 +33,10 @@ async fn main() { let pool = Pool::builder().build(manager).await.unwrap(); // build our application with some routes - let app = Router::new() - .route( - "/", - get(using_connection_pool_extractor).post(using_connection_extractor), - ) - .layer(Extension(pool)); + let app = Router::with_state(pool).route( + "/", + get(using_connection_pool_extractor).post(using_connection_extractor), + ); // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); @@ -51,9 +49,8 @@ async fn main() { type ConnectionPool = Pool>; -// we can exact the connection pool with `Extension` async fn using_connection_pool_extractor( - Extension(pool): Extension, + State(pool): State, ) -> Result { let conn = pool.get().await.map_err(internal_error)?; @@ -71,16 +68,16 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager>); #[async_trait] -impl FromRequest for DatabaseConnection +impl FromRequest for DatabaseConnection where B: Send, { type Rejection = (StatusCode, String); - async fn from_request(req: &mut RequestParts) -> Result { - let Extension(pool) = Extension::::from_request(req) - .await - .map_err(internal_error)?; + async fn from_request( + req: &mut RequestParts, + ) -> Result { + let pool = req.state().clone(); let conn = pool.get_owned().await.map_err(internal_error)?; diff --git a/examples/validator/src/main.rs b/examples/validator/src/main.rs index c8ce8c0884..8682eb85e5 100644 --- a/examples/validator/src/main.rs +++ b/examples/validator/src/main.rs @@ -60,16 +60,17 @@ async fn handler(ValidatedForm(input): ValidatedForm) -> Html pub struct ValidatedForm(pub T); #[async_trait] -impl FromRequest for ValidatedForm +impl FromRequest for ValidatedForm where T: DeserializeOwned + Validate, + S: Send, B: http_body::Body + Send, B::Data: Send, B::Error: Into, { type Rejection = ServerError; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let Form(value) = Form::::from_request(req).await?; value.validate()?; Ok(ValidatedForm(value)) diff --git a/examples/versioning/src/main.rs b/examples/versioning/src/main.rs index 48ade3c90c..cf8e15f280 100644 --- a/examples/versioning/src/main.rs +++ b/examples/versioning/src/main.rs @@ -48,13 +48,14 @@ enum Version { } #[async_trait] -impl FromRequest for Version +impl FromRequest for Version where B: Send, + S: Send, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let params = Path::>::from_request(req) .await .map_err(IntoResponse::into_response)?; diff --git a/examples/websockets/src/main.rs b/examples/websockets/src/main.rs index bbdcaa0821..a317bfa470 100644 --- a/examples/websockets/src/main.rs +++ b/examples/websockets/src/main.rs @@ -37,7 +37,7 @@ async fn main() { // build our application with some routes let app = Router::new() - .fallback( + .fallback_service( get_service(ServeDir::new(assets_dir).append_index_html_on_directories(true)) .handle_error(|error: std::io::Error| async move { ( From 0081c0d447b49bb062ed6e644241b3649ba1912a Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Jul 2022 14:30:59 +0200 Subject: [PATCH 09/45] consistent naming of request body param --- axum/src/routing/method_routing.rs | 55 +++++++++++++----------------- 1 file changed, 23 insertions(+), 32 deletions(-) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 0ba802775e..dff193a4ae 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -76,9 +76,9 @@ macro_rules! top_level_service_fn { $name:ident, $method:ident ) => { $(#[$m])+ - pub fn $name(svc: T) -> MethodRouter + pub fn $name(svc: T) -> MethodRouter where - T: Service, Response = Response> + Clone + Send + 'static, + T: Service, Response = Response> + Clone + Send + 'static, T::Future: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, @@ -211,7 +211,7 @@ macro_rules! chained_service_fn { $(#[$m])+ pub fn $name(self, svc: T) -> Self where - T: Service, Response = Response, Error = E> + T: Service, Response = Response, Error = E> + Clone + Send + 'static, @@ -318,12 +318,9 @@ top_level_service_fn!(trace_service, TRACE); /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn on_service( - filter: MethodFilter, - svc: T, -) -> MethodRouter +pub fn on_service(filter: MethodFilter, svc: T) -> MethodRouter where - T: Service, Response = Response> + Clone + Send + 'static, + T: Service, Response = Response> + Clone + Send + 'static, T::Future: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, @@ -384,9 +381,9 @@ where /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn any_service(svc: T) -> MethodRouter +pub fn any_service(svc: T) -> MethodRouter where - T: Service, Response = Response> + Clone + Send + 'static, + T: Service, Response = Response> + Clone + Send + 'static, T::Future: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, @@ -697,7 +694,7 @@ where } } -impl MethodRouter { +impl MethodRouter { /// Chain an additional service that will accept requests matching the given /// `MethodFilter`. /// @@ -725,10 +722,7 @@ impl MethodRouter { /// ``` pub fn on_service(self, filter: MethodFilter, svc: T) -> Self where - T: Service, Response = Response, Error = E> - + Clone - + Send - + 'static, + T: Service, Response = Response, Error = E> + Clone + Send + 'static, T::Future: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, @@ -748,10 +742,7 @@ impl MethodRouter { #[doc = include_str!("../docs/method_routing/fallback.md")] pub fn fallback_service(mut self, svc: T) -> Self where - T: Service, Response = Response, Error = E> - + Clone - + Send - + 'static, + T: Service, Response = Response, Error = E> + Clone + Send + 'static, T::Future: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, @@ -762,7 +753,7 @@ impl MethodRouter { fn fallback_boxed_response_body(mut self, svc: T) -> Self where - T: Service, Response = Response, Error = E> + Clone + Send + 'static, + T: Service, Response = Response, Error = E> + Clone + Send + 'static, T::Future: Send + 'static, { self.fallback = Fallback::Custom(Route::new(svc)); @@ -775,7 +766,7 @@ impl MethodRouter { layer: L, ) -> MethodRouter where - L: Layer>, + L: Layer>, L::Service: Service, Response = Response, Error = NewError> + Clone + Send @@ -807,14 +798,14 @@ impl MethodRouter { } #[doc = include_str!("../docs/method_routing/route_layer.md")] - pub fn route_layer(self, layer: L) -> MethodRouter + pub fn route_layer(self, layer: L) -> MethodRouter where - L: Layer>, - L::Service: Service, Response = Response, Error = E> + L: Layer>, + L::Service: Service, Response = Response, Error = E> + Clone + Send + 'static, - >>::Future: Send + 'static, + >>::Future: Send + 'static, NewResBody: HttpBody + Send + 'static, NewResBody::Error: Into, { @@ -841,7 +832,7 @@ impl MethodRouter { } #[doc = include_str!("../docs/method_routing/merge.md")] - pub fn merge(self, other: MethodRouter) -> Self { + pub fn merge(self, other: MethodRouter) -> Self { macro_rules! merge { ( $first:ident, $second:ident ) => { match ($first, $second) { @@ -933,22 +924,22 @@ impl MethodRouter { /// Apply a [`HandleErrorLayer`]. /// /// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`. - pub fn handle_error(self, f: F) -> MethodRouter + pub fn handle_error(self, f: F) -> MethodRouter where F: Clone + Send + 'static, - HandleError, F, T>: - Service, Response = Response, Error = Infallible>, - , F, T> as Service>>::Future: Send, + HandleError, F, T>: + Service, Response = Response, Error = Infallible>, + , F, T> as Service>>::Future: Send, T: 'static, E: 'static, - ReqBody: 'static, + B: 'static, { self.layer(HandleErrorLayer::new(f)) } fn on_service_boxed_response_body(self, filter: MethodFilter, svc: T) -> Self where - T: Service, Response = Response, Error = E> + Clone + Send + 'static, + T: Service, Response = Response, Error = E> + Clone + Send + 'static, T::Future: Send + 'static, { macro_rules! set_service { From e9b0b6ff5eac1b57acaa51a9068bde74863f5a67 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Jul 2022 14:49:11 +0200 Subject: [PATCH 10/45] swap type params --- axum-core/src/extract/mod.rs | 38 ++++---- axum-core/src/extract/request_parts.rs | 34 +++---- axum-core/src/extract/tuple.rs | 10 +-- axum-extra/src/body/async_read_body.rs | 2 +- axum-extra/src/extract/cached.rs | 20 +++-- axum-extra/src/extract/cookie/mod.rs | 10 +-- axum-extra/src/extract/cookie/private.rs | 8 +- axum-extra/src/extract/cookie/signed.rs | 8 +- axum-extra/src/extract/form.rs | 6 +- axum-extra/src/extract/query.rs | 4 +- axum-extra/src/json_lines.rs | 4 +- axum-extra/src/routing/mod.rs | 2 +- axum-extra/src/routing/resource.rs | 20 ++++- axum-extra/src/routing/spa.rs | 6 +- axum-extra/src/routing/typed.rs | 2 +- axum-macros/src/debug_handler.rs | 2 +- axum-macros/src/from_request.rs | 22 ++--- axum-macros/src/lib.rs | 13 +-- axum-macros/src/typed_path.rs | 14 +-- .../fail/argument_not_extractor.stderr | 22 ++--- .../debug_handler/fail/extract_self_mut.rs | 4 +- .../debug_handler/fail/extract_self_ref.rs | 4 +- .../pass/result_impl_into_response.rs | 4 +- .../tests/debug_handler/pass/self_receiver.rs | 4 +- .../tests/from_request/pass/container.rs | 2 +- .../tests/from_request/pass/derive_opt_out.rs | 4 +- .../tests/from_request/pass/empty_named.rs | 2 +- .../tests/from_request/pass/empty_tuple.rs | 2 +- .../tests/from_request/pass/enum_via.rs | 2 +- axum-macros/tests/from_request/pass/named.rs | 2 +- .../tests/from_request/pass/named_via.rs | 2 +- axum-macros/tests/from_request/pass/tuple.rs | 2 +- .../pass/tuple_same_type_twice.rs | 2 +- .../pass/tuple_same_type_twice_via.rs | 2 +- .../tests/from_request/pass/tuple_via.rs | 4 +- axum-macros/tests/from_request/pass/unit.rs | 2 +- .../typed_path/fail/not_deserialize.stderr | 2 +- axum/src/docs/error_handling.md | 2 +- axum/src/docs/extract.md | 14 +-- axum/src/docs/method_routing/fallback.md | 8 +- axum/src/docs/middleware.md | 6 +- axum/src/docs/routing/fallback.md | 4 +- axum/src/docs/routing/route.md | 89 +------------------ axum/src/docs/routing/route_service.md | 81 +++++++++++++++++ axum/src/error_handling/mod.rs | 18 ++-- axum/src/extension.rs | 4 +- axum/src/extract/connect_info.rs | 6 +- axum/src/extract/content_length_limit.rs | 6 +- axum/src/extract/host.rs | 4 +- axum/src/extract/matched_path.rs | 10 ++- axum/src/extract/mod.rs | 6 +- axum/src/extract/multipart.rs | 6 +- axum/src/extract/path/mod.rs | 4 +- axum/src/extract/query.rs | 4 +- axum/src/extract/raw_query.rs | 4 +- axum/src/extract/request_parts.rs | 12 +-- axum/src/extract/state.rs | 4 +- axum/src/extract/ws.rs | 8 +- axum/src/form.rs | 4 +- axum/src/handler/future.rs | 16 ++-- axum/src/handler/mod.rs | 65 +++++++++++--- axum/src/json.rs | 6 +- axum/src/middleware/from_extractor.rs | 45 +++++----- axum/src/middleware/from_fn.rs | 30 +++---- axum/src/routing/method_routing.rs | 2 +- axum/src/routing/mod.rs | 8 +- axum/src/routing/route.rs | 4 +- axum/src/routing/strip_prefix.rs | 2 +- axum/src/routing/tests/mod.rs | 7 +- axum/src/typed_header.rs | 4 +- .../customize-extractor-error/src/main.rs | 4 +- examples/customize-path-rejection/src/main.rs | 4 +- examples/jwt/src/main.rs | 4 +- examples/validator/src/main.rs | 4 +- examples/versioning/src/main.rs | 4 +- 75 files changed, 436 insertions(+), 370 deletions(-) create mode 100644 axum/src/docs/routing/route_service.md diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index 11236279ee..b052b3867d 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -42,13 +42,15 @@ mod tuple; /// struct MyExtractor; /// /// #[async_trait] -/// impl FromRequest for MyExtractor +/// impl FromRequest for MyExtractor /// where -/// B: Send, // required by `async_trait` +/// // these bounds are required by `async_trait` +/// B: Send, +/// S: Send, /// { /// type Rejection = http::StatusCode; /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: &mut RequestParts) -> Result { /// // ... /// # unimplemented!() /// } @@ -60,20 +62,20 @@ mod tuple; /// [`http::Request`]: http::Request /// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html #[async_trait] -pub trait FromRequest: Sized { +pub trait FromRequest: Sized { /// If the extractor fails it'll use this "rejection" type. A rejection is /// a kind of error that can be converted into a response. type Rejection: IntoResponse; /// Perform the extraction. - async fn from_request(req: &mut RequestParts) -> Result; + async fn from_request(req: &mut RequestParts) -> Result; } /// The type used with [`FromRequest`] to extract data from requests. /// /// Has several convenience methods for getting owned parts of the request. #[derive(Debug)] -pub struct RequestParts { +pub struct RequestParts { state: S, method: Method, uri: Uri, @@ -83,7 +85,7 @@ pub struct RequestParts { body: Option, } -impl RequestParts { +impl RequestParts { /// Create a new `RequestParts` with the given state. /// /// You generally shouldn't need to construct this type yourself, unless @@ -132,10 +134,14 @@ impl RequestParts { /// use http::{Method, Uri}; /// /// #[async_trait] - /// impl FromRequest for MyExtractor { + /// impl FromRequest for MyExtractor + /// where + /// B: Send, + /// S: Send, + /// { /// type Rejection = Infallible; /// - /// async fn from_request(req: &mut RequestParts) -> Result { + /// async fn from_request(req: &mut RequestParts) -> Result { /// let method = req.extract::().await?; /// let path = req.extract::().await?.path().to_owned(); /// @@ -145,7 +151,7 @@ impl RequestParts { /// ``` pub async fn extract(&mut self) -> Result where - E: FromRequest, + E: FromRequest, { E::from_request(self).await } @@ -259,29 +265,29 @@ impl RequestParts { } #[async_trait] -impl FromRequest for Option +impl FromRequest for Option where - T: FromRequest, + T: FromRequest, B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result, Self::Rejection> { + async fn from_request(req: &mut RequestParts) -> Result, Self::Rejection> { Ok(T::from_request(req).await.ok()) } } #[async_trait] -impl FromRequest for Result +impl FromRequest for Result where - T: FromRequest, + T: FromRequest, B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(T::from_request(req).await) } } diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index 4faaf2d355..e9027d2611 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -6,14 +6,14 @@ use http::{Extensions, HeaderMap, Method, Request, Uri, Version}; use std::convert::Infallible; #[async_trait] -impl FromRequest for Request +impl FromRequest for Request where B: Send, S: Clone + Send, { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let req = std::mem::replace( req, RequestParts { @@ -32,40 +32,40 @@ where } #[async_trait] -impl FromRequest for Method +impl FromRequest for Method where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.method().clone()) } } #[async_trait] -impl FromRequest for Uri +impl FromRequest for Uri where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.uri().clone()) } } #[async_trait] -impl FromRequest for Version +impl FromRequest for Version where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.version()) } } @@ -76,20 +76,20 @@ where /// /// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html #[async_trait] -impl FromRequest for HeaderMap +impl FromRequest for HeaderMap where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.headers().clone()) } } #[async_trait] -impl FromRequest for Bytes +impl FromRequest for Bytes where B: http_body::Body + Send, B::Data: Send, @@ -98,7 +98,7 @@ where { type Rejection = BytesRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; let bytes = crate::body::to_bytes(body) @@ -110,7 +110,7 @@ where } #[async_trait] -impl FromRequest for String +impl FromRequest for String where B: http_body::Body + Send, B::Data: Send, @@ -119,7 +119,7 @@ where { type Rejection = StringRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; let bytes = crate::body::to_bytes(body) @@ -134,14 +134,14 @@ where } #[async_trait] -impl FromRequest for http::request::Parts +impl FromRequest for http::request::Parts where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let method = unwrap_infallible(Method::from_request(req).await); let uri = unwrap_infallible(Uri::from_request(req).await); let version = unwrap_infallible(Version::from_request(req).await); @@ -168,6 +168,6 @@ fn unwrap_infallible(result: Result) -> T { } } -pub(crate) fn take_body(req: &mut RequestParts) -> Result { +pub(crate) fn take_body(req: &mut RequestParts) -> Result { req.take_body().ok_or(BodyAlreadyExtracted) } diff --git a/axum-core/src/extract/tuple.rs b/axum-core/src/extract/tuple.rs index 05e38bf004..1ae56e032a 100644 --- a/axum-core/src/extract/tuple.rs +++ b/axum-core/src/extract/tuple.rs @@ -4,14 +4,14 @@ use async_trait::async_trait; use std::convert::Infallible; #[async_trait] -impl FromRequest for () +impl FromRequest for () where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(_: &mut RequestParts) -> Result<(), Self::Rejection> { + async fn from_request(_: &mut RequestParts) -> Result<(), Self::Rejection> { Ok(()) } } @@ -22,15 +22,15 @@ macro_rules! impl_from_request { ( $($ty:ident),* $(,)? ) => { #[async_trait] #[allow(non_snake_case)] - impl FromRequest for ($($ty,)*) + impl FromRequest for ($($ty,)*) where - $( $ty: FromRequest + Send, )* + $( $ty: FromRequest + Send, )* B: Send, S: Send, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { $( let $ty = $ty::from_request(req).await.map_err(|err| err.into_response())?; )* Ok(($($ty,)*)) } diff --git a/axum-extra/src/body/async_read_body.rs b/axum-extra/src/body/async_read_body.rs index 6e66f0dfda..0b39ff8079 100644 --- a/axum-extra/src/body/async_read_body.rs +++ b/axum-extra/src/body/async_read_body.rs @@ -42,7 +42,7 @@ pin_project! { /// } /// /// let app = Router::new().route("/Cargo.toml", get(cargo_toml)); - /// # let _: Router = app; + /// # let _: Router<()> = app; /// ``` #[cfg(feature = "async-read-body")] #[derive(Debug)] diff --git a/axum-extra/src/extract/cached.rs b/axum-extra/src/extract/cached.rs index 9ced87e450..7826ffccba 100644 --- a/axum-extra/src/extract/cached.rs +++ b/axum-extra/src/extract/cached.rs @@ -30,13 +30,14 @@ use std::ops::{Deref, DerefMut}; /// struct Session { /* ... */ } /// /// #[async_trait] -/// impl FromRequest for Session +/// impl FromRequest for Session /// where /// B: Send, +/// S: Send, /// { /// type Rejection = (StatusCode, String); /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: &mut RequestParts) -> Result { /// // load session... /// # unimplemented!() /// } @@ -45,13 +46,14 @@ use std::ops::{Deref, DerefMut}; /// struct CurrentUser { /* ... */ } /// /// #[async_trait] -/// impl FromRequest for CurrentUser +/// impl FromRequest for CurrentUser /// where /// B: Send, +/// S: Send, /// { /// type Rejection = Response; /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: &mut RequestParts) -> Result { /// // loading a `CurrentUser` requires first loading the `Session` /// // /// // by using `Cached` we avoid extracting the session more than @@ -88,15 +90,15 @@ pub struct Cached(pub T); struct CachedEntry(T); #[async_trait] -impl FromRequest for Cached +impl FromRequest for Cached where B: Send, S: Send, - T: FromRequest + Clone + Send + Sync + 'static, + T: FromRequest + Clone + Send + Sync + 'static, { type Rejection = T::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match Extension::>::from_request(req).await { Ok(Extension(CachedEntry(value))) => Ok(Self(value)), Err(_) => { @@ -140,14 +142,14 @@ mod tests { struct Extractor(Instant); #[async_trait] - impl FromRequest for Extractor + impl FromRequest for Extractor where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { COUNTER.fetch_add(1, Ordering::SeqCst); Ok(Self(Instant::now())) } diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index 2c0b2f3e28..601721c7b5 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -80,7 +80,7 @@ pub use cookie_lib::Key; /// let app = Router::new() /// .route("/sessions", post(create_session)) /// .route("/me", get(me)); -/// # let app: Router = app; +/// # let app: Router<()> = app; /// ``` #[derive(Debug)] pub struct CookieJar { @@ -88,14 +88,14 @@ pub struct CookieJar { } #[async_trait] -impl FromRequest for CookieJar +impl FromRequest for CookieJar where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let mut jar = cookie_lib::CookieJar::new(); for cookie in cookies_from_request(req) { jar.add_original(cookie); @@ -104,8 +104,8 @@ where } } -fn cookies_from_request( - req: &mut RequestParts, +fn cookies_from_request( + req: &mut RequestParts, ) -> impl Iterator> + '_ { req.headers() .get_all(COOKIE) diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index 57adb74e29..7ef300f918 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -54,7 +54,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// .route("/get", get(get_secret)) /// // add extension with the key so `PrivateCookieJar` can access it /// .layer(Extension(key)); -/// # let app: Router = app; +/// # let app: Router<()> = app; /// ``` pub struct PrivateCookieJar { jar: cookie_lib::CookieJar, @@ -74,15 +74,15 @@ impl fmt::Debug for PrivateCookieJar { } #[async_trait] -impl FromRequest for PrivateCookieJar +impl FromRequest for PrivateCookieJar where B: Send, S: Send, K: Into + Clone + Send + Sync + 'static, { - type Rejection = as FromRequest>::Rejection; + type Rejection = as FromRequest>::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let key = Extension::::from_request(req).await?.0.into(); let mut jar = cookie_lib::CookieJar::new(); diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index cab8d6af9f..2a9a4f96d7 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -72,7 +72,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// .route("/me", get(me)) /// // add extension with the key so `SignedCookieJar` can access it /// .layer(Extension(key)); -/// # let app: Router = app; +/// # let app: Router<()> = app; /// ``` pub struct SignedCookieJar { jar: cookie_lib::CookieJar, @@ -92,15 +92,15 @@ impl fmt::Debug for SignedCookieJar { } #[async_trait] -impl FromRequest for SignedCookieJar +impl FromRequest for SignedCookieJar where B: Send, S: Send, K: Into + Clone + Send + Sync + 'static, { - type Rejection = as FromRequest>::Rejection; + type Rejection = as FromRequest>::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let key = Extension::::from_request(req).await?.0.into(); let mut jar = cookie_lib::CookieJar::new(); diff --git a/axum-extra/src/extract/form.rs b/axum-extra/src/extract/form.rs index ec8091e3aa..3c2a3ea87c 100644 --- a/axum-extra/src/extract/form.rs +++ b/axum-extra/src/extract/form.rs @@ -54,7 +54,7 @@ impl Deref for Form { } #[async_trait] -impl FromRequest for Form +impl FromRequest for Form where T: DeserializeOwned, B: HttpBody + Send, @@ -64,7 +64,7 @@ where { type Rejection = FormRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if req.method() == Method::GET { let query = req.uri().query().unwrap_or_default(); let value = serde_html_form::from_str(query) @@ -85,7 +85,7 @@ where } // this is duplicated in `axum/src/extract/mod.rs` -fn has_content_type(req: &RequestParts, expected_content_type: &mime::Mime) -> bool { +fn has_content_type(req: &RequestParts, expected_content_type: &mime::Mime) -> bool { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { content_type } else { diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index f53a8a1a5f..84b0a73852 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -58,7 +58,7 @@ use std::ops::Deref; pub struct Query(pub T); #[async_trait] -impl FromRequest for Query +impl FromRequest for Query where T: DeserializeOwned, B: Send, @@ -66,7 +66,7 @@ where { type Rejection = QueryRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let query = req.uri().query().unwrap_or_default(); let value = serde_html_form::from_str(query) .map_err(FailedToDeserializeQueryString::__private_new::)?; diff --git a/axum-extra/src/json_lines.rs b/axum-extra/src/json_lines.rs index 242b43e70f..bd5b0b944a 100644 --- a/axum-extra/src/json_lines.rs +++ b/axum-extra/src/json_lines.rs @@ -98,7 +98,7 @@ impl JsonLines { } #[async_trait] -impl FromRequest for JsonLines +impl FromRequest for JsonLines where B: HttpBody + Send + 'static, B::Data: Into, @@ -108,7 +108,7 @@ where { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { // `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead` // so we can call `AsyncRead::lines` and then convert it back to a `Stream` diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 67da57f106..3ed31f7870 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -156,7 +156,7 @@ pub trait RouterExt: sealed::Sealed { /// .route_with_tsr("/foo", get(|| async {})) /// // `/bar` will redirect to `/bar/` /// .route_with_tsr("/bar/", get(|| async {})); - /// # let _: Router = app; + /// # let _: Router<()> = app; /// ``` fn route_with_tsr(self, path: &str, method_router: MethodRouter) -> Self where diff --git a/axum-extra/src/routing/resource.rs b/axum-extra/src/routing/resource.rs index b8dc68d47d..3736af92eb 100644 --- a/axum-extra/src/routing/resource.rs +++ b/axum-extra/src/routing/resource.rs @@ -45,7 +45,7 @@ use tower_service::Service; /// ); /// /// let app = Router::new().merge(users); -/// # let _: Router = app; +/// # let _: Router<()> = app; /// ``` #[derive(Debug)] pub struct Resource { @@ -53,15 +53,27 @@ pub struct Resource { pub(crate) router: Router, } +impl Resource<(), B> +where + B: axum::body::HttpBody + Send + 'static, +{ + /// Create a `Resource` with the given name. + /// + /// All routes will be nested at `/{resource_name}`. + pub fn named(resource_name: &str) -> Self { + Self::named_with((), resource_name) + } +} + impl Resource where B: axum::body::HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { - /// Create a `Resource` with the given name. + /// Create a `Resource` with the given name and state. /// /// All routes will be nested at `/{resource_name}`. - pub fn named(state: S, resource_name: &str) -> Self { + pub fn named_with(state: S, resource_name: &str) -> Self { Self { name: resource_name.to_owned(), router: Router::with_state(state), @@ -193,7 +205,7 @@ mod tests { #[tokio::test] async fn works() { - let users = Resource::named((), "users") + let users = Resource::named("users") .index(|| async { "users#index" }) .create(|| async { "users#create" }) .new(|| async { "users#new" }) diff --git a/axum-extra/src/routing/spa.rs b/axum-extra/src/routing/spa.rs index 69f1185cea..b5c883a321 100644 --- a/axum-extra/src/routing/spa.rs +++ b/axum-extra/src/routing/spa.rs @@ -36,7 +36,7 @@ use tower_service::Service; /// .merge(spa) /// // we can still add other routes /// .route("/api/foo", get(api_foo)); -/// # let _: Router = app; +/// # let _: Router<()> = app; /// /// async fn api_foo() {} /// ``` @@ -101,7 +101,7 @@ impl SpaRouter { /// .index_file("another_file.html"); /// /// let app = Router::new().merge(spa); - /// # let _: Router = app; + /// # let _: Router<()> = app; /// ``` pub fn index_file

(mut self, path: P) -> Self where @@ -136,7 +136,7 @@ impl SpaRouter { /// } /// /// let app = Router::new().merge(spa); - /// # let _: Router = app; + /// # let _: Router<()> = app; /// ``` pub fn handle_error(self, f: F2) -> SpaRouter { SpaRouter { diff --git a/axum-extra/src/routing/typed.rs b/axum-extra/src/routing/typed.rs index 683d68431c..c927a279ee 100644 --- a/axum-extra/src/routing/typed.rs +++ b/axum-extra/src/routing/typed.rs @@ -60,7 +60,7 @@ use http::Uri; /// async fn users_destroy(_: UsersCollection) { /* ... */ } /// /// # -/// # let app: Router = app; +/// # let app: Router<()> = app; /// ``` /// /// # Using `#[derive(TypedPath)]` diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 6614c7dd81..468d06cdaa 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -196,7 +196,7 @@ fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStr #[allow(warnings)] fn #name() where - #ty: ::axum::extract::FromRequest<(), #body_ty> + Send, + #ty: ::axum::extract::FromRequest<#body_ty, ()> + Send, {} } }) diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs index c8cdf85a0a..ab5a61f926 100644 --- a/axum-macros/src/from_request.rs +++ b/axum-macros/src/from_request.rs @@ -106,7 +106,7 @@ fn impl_struct_by_extracting_each_field( Ok(quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, @@ -116,7 +116,7 @@ fn impl_struct_by_extracting_each_field( type Rejection = #rejection_ident; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: &mut ::axum::extract::RequestParts, ) -> ::std::result::Result { ::std::result::Result::Ok(Self { #(#extract_fields)* @@ -302,7 +302,7 @@ fn extract_each_field_rejection( Ok(quote_spanned! {ty_span=> #[allow(non_camel_case_types)] - #variant_name(<#extractor_ty as ::axum::extract::FromRequest>::Rejection), + #variant_name(<#extractor_ty as ::axum::extract::FromRequest<::axum::body::Body, ()>>::Rejection), }) }) .collect::>>()?; @@ -486,19 +486,19 @@ fn impl_struct_by_extracting_all_at_once( Ok(quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, B::Error: ::std::convert::Into<::axum::BoxError>, S: Send, { - type Rejection = <#path as ::axum::extract::FromRequest>::Rejection; + type Rejection = <#path as ::axum::extract::FromRequest>::Rejection; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: &mut ::axum::extract::RequestParts, ) -> ::std::result::Result { - ::axum::extract::FromRequest::::from_request(req) + ::axum::extract::FromRequest::::from_request(req) .await .map(|#path(inner)| inner) } @@ -542,19 +542,19 @@ fn impl_enum_by_extracting_all_at_once( Ok(quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, B::Error: ::std::convert::Into<::axum::BoxError>, S: Send, { - type Rejection = <#path as ::axum::extract::FromRequest>::Rejection; + type Rejection = <#path as ::axum::extract::FromRequest>::Rejection; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: &mut ::axum::extract::RequestParts, ) -> ::std::result::Result { - ::axum::extract::FromRequest::::from_request(req) + ::axum::extract::FromRequest::::from_request(req) .await .map(|#path(inner)| inner) } diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index a44a14db32..935eff844e 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -125,7 +125,7 @@ mod typed_path; /// ``` /// pub struct ViaExtractor(pub T); /// -/// // impl FromRequest for ViaExtractor { ... } +/// // impl FromRequest for ViaExtractor { ... } /// ``` /// /// More complex via extractors are not supported and require writing a manual implementation. @@ -223,14 +223,15 @@ mod typed_path; /// struct OtherExtractor; /// /// #[async_trait] -/// impl FromRequest for OtherExtractor +/// impl FromRequest for OtherExtractor /// where -/// B: Send + 'static, +/// B: Send, +/// S: Send, /// { /// // this rejection doesn't implement `Display` and `Error` /// type Rejection = (StatusCode, String); /// -/// async fn from_request(_req: &mut RequestParts) -> Result { +/// async fn from_request(_req: &mut RequestParts) -> Result { /// // ... /// # unimplemented!() /// } @@ -274,7 +275,9 @@ mod typed_path; /// [`axum::extract::rejection::ExtensionRejection`]: https://docs.rs/axum/latest/axum/extract/rejection/enum.ExtensionRejection.html #[proc_macro_derive(FromRequest, attributes(from_request))] pub fn derive_from_request(item: TokenStream) -> TokenStream { - expand_with(item, from_request::expand) + let tokens = expand_with(item, from_request::expand); + // panic!("{}", tokens); + tokens } /// Generates better error messages when applied handler functions. diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs index 6a8f03c170..9e3c508d57 100644 --- a/axum-macros/src/typed_path.rs +++ b/axum-macros/src/typed_path.rs @@ -127,14 +127,14 @@ fn expand_named_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: Send, S: Send, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { + async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { ::axum::extract::Path::from_request(req) .await .map(|path| path.0) @@ -230,14 +230,14 @@ fn expand_unnamed_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: Send, S: Send, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { + async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { ::axum::extract::Path::from_request(req) .await .map(|path| path.0) @@ -312,14 +312,14 @@ fn expand_unit_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: Send, S: Send, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { + async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { if req.uri().path() == ::PATH { Ok(Self) } else { @@ -390,7 +390,7 @@ enum Segment { fn path_rejection() -> TokenStream { quote! { - <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection + <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection } } diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr index 265258419e..420005970c 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr @@ -1,17 +1,17 @@ -error[E0277]: the trait bound `bool: FromRequest<(), Body>` is not satisfied +error[E0277]: the trait bound `bool: FromRequest` is not satisfied --> tests/debug_handler/fail/argument_not_extractor.rs:4:23 | 4 | async fn handler(foo: bool) {} - | ^^^^ the trait `FromRequest<(), Body>` is not implemented for `bool` + | ^^^^ the trait `FromRequest` is not implemented for `bool` | - = help: the following other types implement trait `FromRequest`: - <() as FromRequest> - <(T1, T2) as FromRequest> - <(T1, T2, T3) as FromRequest> - <(T1, T2, T3, T4) as FromRequest> - <(T1, T2, T3, T4, T5) as FromRequest> - <(T1, T2, T3, T4, T5, T6) as FromRequest> - <(T1, T2, T3, T4, T5, T6, T7) as FromRequest> - <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequest> + = help: the following other types implement trait `FromRequest`: + <() as FromRequest> + <(T1, T2) as FromRequest> + <(T1, T2, T3) as FromRequest> + <(T1, T2, T3, T4) as FromRequest> + <(T1, T2, T3, T4, T5) as FromRequest> + <(T1, T2, T3, T4, T5, T6) as FromRequest> + <(T1, T2, T3, T4, T5, T6, T7) as FromRequest> + <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequest> and 34 others = help: see issue #48214 diff --git a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs index d38d5e0c4d..910ba78ced 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs @@ -7,14 +7,14 @@ use axum_macros::debug_handler; struct A; #[async_trait] -impl FromRequest for A +impl FromRequest for A where B: Send, S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs index 06b87f0a82..75d8f5ae18 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs @@ -7,14 +7,14 @@ use axum_macros::debug_handler; struct A; #[async_trait] -impl FromRequest for A +impl FromRequest for A where B: Send, S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs index 762809b62a..ebf02a2629 100644 --- a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs +++ b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs @@ -120,14 +120,14 @@ impl A { } #[async_trait] -impl FromRequest for A +impl FromRequest for A where B: Send, S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/pass/self_receiver.rs b/axum-macros/tests/debug_handler/pass/self_receiver.rs index a88382cf18..3939349836 100644 --- a/axum-macros/tests/debug_handler/pass/self_receiver.rs +++ b/axum-macros/tests/debug_handler/pass/self_receiver.rs @@ -7,14 +7,14 @@ use axum_macros::debug_handler; struct A; #[async_trait] -impl FromRequest for A +impl FromRequest for A where B: Send, S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/from_request/pass/container.rs b/axum-macros/tests/from_request/pass/container.rs index e8eaa0a58a..fe388e4806 100644 --- a/axum-macros/tests/from_request/pass/container.rs +++ b/axum-macros/tests/from_request/pass/container.rs @@ -15,7 +15,7 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest<(), Body, Rejection = JsonRejection>, + Extractor: FromRequest, { } diff --git a/axum-macros/tests/from_request/pass/derive_opt_out.rs b/axum-macros/tests/from_request/pass/derive_opt_out.rs index 9738116d86..f852115361 100644 --- a/axum-macros/tests/from_request/pass/derive_opt_out.rs +++ b/axum-macros/tests/from_request/pass/derive_opt_out.rs @@ -14,14 +14,14 @@ struct Extractor { struct OtherExtractor; #[async_trait] -impl FromRequest for OtherExtractor +impl FromRequest for OtherExtractor where B: Send, S: Send, { type Rejection = OtherExtractorRejection; - async fn from_request(_req: &mut RequestParts<(), B>) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/from_request/pass/empty_named.rs b/axum-macros/tests/from_request/pass/empty_named.rs index eec021d0f5..c550f77a03 100644 --- a/axum-macros/tests/from_request/pass/empty_named.rs +++ b/axum-macros/tests/from_request/pass/empty_named.rs @@ -5,7 +5,7 @@ struct Extractor {} fn assert_from_request() where - Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, + Extractor: axum::extract::FromRequest, { } diff --git a/axum-macros/tests/from_request/pass/empty_tuple.rs b/axum-macros/tests/from_request/pass/empty_tuple.rs index 3d8bcd25c0..6429b4f9f2 100644 --- a/axum-macros/tests/from_request/pass/empty_tuple.rs +++ b/axum-macros/tests/from_request/pass/empty_tuple.rs @@ -5,7 +5,7 @@ struct Extractor(); fn assert_from_request() where - Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, + Extractor: axum::extract::FromRequest, { } diff --git a/axum-macros/tests/from_request/pass/enum_via.rs b/axum-macros/tests/from_request/pass/enum_via.rs index d6ba90e277..c68b9796a7 100644 --- a/axum-macros/tests/from_request/pass/enum_via.rs +++ b/axum-macros/tests/from_request/pass/enum_via.rs @@ -8,5 +8,5 @@ enum Extractor {} async fn foo(_: Extractor) {} fn main() { - Router::::new().route("/", get(foo)); + Router::<(), Body>::new().route("/", get(foo)); } diff --git a/axum-macros/tests/from_request/pass/named.rs b/axum-macros/tests/from_request/pass/named.rs index 89fb8da004..cd4b8649ac 100644 --- a/axum-macros/tests/from_request/pass/named.rs +++ b/axum-macros/tests/from_request/pass/named.rs @@ -18,7 +18,7 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>, + Extractor: FromRequest, { } diff --git a/axum-macros/tests/from_request/pass/named_via.rs b/axum-macros/tests/from_request/pass/named_via.rs index 8a81869d1a..44c633a097 100644 --- a/axum-macros/tests/from_request/pass/named_via.rs +++ b/axum-macros/tests/from_request/pass/named_via.rs @@ -25,7 +25,7 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>, + Extractor: FromRequest, { } diff --git a/axum-macros/tests/from_request/pass/tuple.rs b/axum-macros/tests/from_request/pass/tuple.rs index 2af407d0f9..7561f998a2 100644 --- a/axum-macros/tests/from_request/pass/tuple.rs +++ b/axum-macros/tests/from_request/pass/tuple.rs @@ -5,7 +5,7 @@ struct Extractor(axum::http::HeaderMap, String); fn assert_from_request() where - Extractor: axum::extract::FromRequest<(), axum::body::Body>, + Extractor: axum::extract::FromRequest, { } diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs index 00b6dd78df..3ed6ad0853 100644 --- a/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs @@ -13,7 +13,7 @@ struct Payload {} fn assert_from_request() where - Extractor: axum::extract::FromRequest<(), axum::body::Body>, + Extractor: axum::extract::FromRequest, { } diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs index 0b148ebc50..13ee5f259c 100644 --- a/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs @@ -27,7 +27,7 @@ struct Payload {} fn assert_from_request() where - Extractor: axum::extract::FromRequest<(), axum::body::Body>, + Extractor: axum::extract::FromRequest, { } diff --git a/axum-macros/tests/from_request/pass/tuple_via.rs b/axum-macros/tests/from_request/pass/tuple_via.rs index 7a8723b628..d08c0f52ed 100644 --- a/axum-macros/tests/from_request/pass/tuple_via.rs +++ b/axum-macros/tests/from_request/pass/tuple_via.rs @@ -1,5 +1,5 @@ -use axum_macros::FromRequest; use axum::extract::Extension; +use axum_macros::FromRequest; #[derive(FromRequest)] struct Extractor(#[from_request(via(Extension))] State); @@ -9,7 +9,7 @@ struct State; fn assert_from_request() where - Extractor: axum::extract::FromRequest<(), axum::body::Body>, + Extractor: axum::extract::FromRequest, { } diff --git a/axum-macros/tests/from_request/pass/unit.rs b/axum-macros/tests/from_request/pass/unit.rs index 3e5d986917..76073d2777 100644 --- a/axum-macros/tests/from_request/pass/unit.rs +++ b/axum-macros/tests/from_request/pass/unit.rs @@ -5,7 +5,7 @@ struct Extractor; fn assert_from_request() where - Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, + Extractor: axum::extract::FromRequest, { } diff --git a/axum-macros/tests/typed_path/fail/not_deserialize.stderr b/axum-macros/tests/typed_path/fail/not_deserialize.stderr index 9aabf3625f..bc77a0d2ea 100644 --- a/axum-macros/tests/typed_path/fail/not_deserialize.stderr +++ b/axum-macros/tests/typed_path/fail/not_deserialize.stderr @@ -15,5 +15,5 @@ error[E0277]: the trait bound `for<'de> MyPath: serde::de::Deserialize<'de>` is (T0, T1, T2, T3) and 138 others = note: required because of the requirements on the impl of `serde::de::DeserializeOwned` for `MyPath` - = note: required because of the requirements on the impl of `FromRequest` for `axum::extract::Path` + = note: required because of the requirements on the impl of `FromRequest` for `axum::extract::Path` = note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum/src/docs/error_handling.md b/axum/src/docs/error_handling.md index 01a0afaa09..45c768c69a 100644 --- a/axum/src/docs/error_handling.md +++ b/axum/src/docs/error_handling.md @@ -69,7 +69,7 @@ let some_fallible_service = tower::service_fn(|_req| async { Ok::<_, anyhow::Error>(Response::new(Body::empty())) }); -let app = Router::new().route( +let app = Router::new().route_service( "/", // we cannot route to `some_fallible_service` directly since it might fail. // we have to use `handle_error` which converts its errors into responses diff --git a/axum/src/docs/extract.md b/axum/src/docs/extract.md index 92784a65e5..32240dfda9 100644 --- a/axum/src/docs/extract.md +++ b/axum/src/docs/extract.md @@ -421,13 +421,14 @@ use http::{StatusCode, header::{HeaderValue, USER_AGENT}}; struct ExtractUserAgent(HeaderValue); #[async_trait] -impl FromRequest for ExtractUserAgent +impl FromRequest for ExtractUserAgent where B: Send, + S: Send, { type Rejection = (StatusCode, &'static str); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if let Some(user_agent) = req.headers().get(USER_AGENT) { Ok(ExtractUserAgent(user_agent.clone())) } else { @@ -472,13 +473,14 @@ struct AuthenticatedUser { } #[async_trait] -impl FromRequest for AuthenticatedUser +impl FromRequest for AuthenticatedUser where B: Send, + S: Send, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let TypedHeader(Authorization(token)) = TypedHeader::>::from_request(req) .await @@ -603,7 +605,7 @@ where B: Send, { // running extractors requires a `RequestParts` - let mut request_parts = RequestParts::new(request); + let mut request_parts = RequestParts::new((), request); // `TypedHeader>` extracts the auth token but // `RequestParts::extract` works with anything that implements `FromRequest` @@ -633,7 +635,7 @@ fn token_is_valid(token: &str) -> bool { } let app = Router::new().layer(middleware::from_fn(auth_middleware)); -# let _: Router = app; +# let _: Router<()> = app; ``` [`body::Body`]: crate::body::Body diff --git a/axum/src/docs/method_routing/fallback.md b/axum/src/docs/method_routing/fallback.md index c6a0d09024..906cbb3b5d 100644 --- a/axum/src/docs/method_routing/fallback.md +++ b/axum/src/docs/method_routing/fallback.md @@ -11,7 +11,7 @@ use axum::{ http::{StatusCode, Method, Uri}, }; -let handler = get(|| async {}).fallback(fallback.into_service()); +let handler = get(|| async {}).fallback(fallback); let app = Router::new().route("/", handler); @@ -36,11 +36,9 @@ use axum::{ http::{StatusCode, Uri}, }; -let one = get(|| async {}) - .fallback(fallback_one.into_service()); +let one = get(|| async {}).fallback(fallback_one); -let two = post(|| async {}) - .fallback(fallback_two.into_service()); +let two = post(|| async {}).fallback(fallback_two); let method_route = one.merge(two); diff --git a/axum/src/docs/middleware.md b/axum/src/docs/middleware.md index 958c504b9a..5066665d63 100644 --- a/axum/src/docs/middleware.md +++ b/axum/src/docs/middleware.md @@ -94,7 +94,7 @@ let app = Router::new() .layer(layer_one) .layer(layer_two) .layer(layer_three); -# let app: Router = app; +# let _: Router<(), axum::body::Body> = app; ``` Think of the middleware as being layered like an onion where each new layer @@ -153,7 +153,7 @@ let app = Router::new() .layer(layer_two) .layer(layer_three), ); -# let app: Router = app; +# let _: Router<(), axum::body::Body> = app; ``` `ServiceBuilder` works by composing all layers into one such that they run top @@ -436,7 +436,7 @@ async fn handler( let app = Router::new() .route("/", get(handler)) .route_layer(middleware::from_fn(auth)); -# let app: Router = app; +# let _: Router<()> = app; ``` [Response extensions] can also be used but note that request extensions are not diff --git a/axum/src/docs/routing/fallback.md b/axum/src/docs/routing/fallback.md index a5c7467ed3..f2b5d3331e 100644 --- a/axum/src/docs/routing/fallback.md +++ b/axum/src/docs/routing/fallback.md @@ -1,4 +1,4 @@ -Add a fallback service to the router. +Add a fallback [`Handler`] to the router. This service will be called if no routes matches the incoming request. @@ -13,7 +13,7 @@ use axum::{ let app = Router::new() .route("/foo", get(|| async { /* ... */ })) - .fallback(fallback.into_service()); + .fallback(fallback); async fn fallback(uri: Uri) -> (StatusCode, String) { (StatusCode::NOT_FOUND, format!("No route for {}", uri)) diff --git a/axum/src/docs/routing/route.md b/axum/src/docs/routing/route.md index e0a753a005..f372e4d27b 100644 --- a/axum/src/docs/routing/route.md +++ b/axum/src/docs/routing/route.md @@ -3,10 +3,10 @@ Add another route to the router. `path` is a string of path segments separated by `/`. Each segment can be either static, a capture, or a wildcard. -`service` is the [`Service`] that should receive the request if the path matches -`path`. `service` will commonly be a handler wrapped in a method router like -[`get`](crate::routing::get). See [`handler`](crate::handler) for more details -on handlers. +`method_router` is the [`MethodRouter`] that should receive the request if the +path matches `path`. `method_router` will commonly be a handler wrapped in a method +router like [`get`](crate::routing::get). See [`handler`](crate::handler) for +more details on handlers. # Static paths @@ -103,69 +103,6 @@ async fn serve_asset(Path(path): Path) {} # }; ``` -# Routing to any [`Service`] - -axum also supports routing to general [`Service`]s: - -```rust,no_run -use axum::{ - Router, - body::Body, - routing::{any_service, get_service}, - http::{Request, StatusCode}, - error_handling::HandleErrorLayer, -}; -use tower_http::services::ServeFile; -use http::Response; -use std::{convert::Infallible, io}; -use tower::service_fn; - -let app = Router::new() - .route( - // Any request to `/` goes to a service - "/", - // Services whose response body is not `axum::body::BoxBody` - // can be wrapped in `axum::routing::any_service` (or one of the other routing filters) - // to have the response body mapped - any_service(service_fn(|_: Request| async { - let res = Response::new(Body::from("Hi from `GET /`")); - Ok::<_, Infallible>(res) - })) - ) - .route( - "/foo", - // This service's response body is `axum::body::BoxBody` so - // it can be routed to directly. - service_fn(|req: Request| async move { - let body = Body::from(format!("Hi from `{} /foo`", req.method())); - let body = axum::body::boxed(body); - let res = Response::new(body); - Ok::<_, Infallible>(res) - }) - ) - .route( - // GET `/static/Cargo.toml` goes to a service from tower-http - "/static/Cargo.toml", - get_service(ServeFile::new("Cargo.toml")) - // though we must handle any potential errors - .handle_error(|error: io::Error| async move { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Unhandled internal error: {}", error), - ) - }) - ); -# async { -# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -# }; -``` - -Routing to arbitrary services in this way has complications for backpressure -([`Service::poll_ready`]). See the [Routing to services and backpressure] module -for more details. - -[Routing to services and backpressure]: middleware/index.html#routing-to-servicesmiddleware-and-backpressure - # Panics Panics if the route overlaps with another route: @@ -201,21 +138,3 @@ let app = Router::new() ``` Also panics if `path` is empty. - -## Nesting - -`route` cannot be used to nest `Router`s. Instead use [`Router::nest`]. - -Attempting to will result in a panic: - -```rust,should_panic -use axum::{routing::get, Router}; - -let app = Router::new().route( - "/", - Router::new().route("/foo", get(|| async {})), -); -# async { -# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -# }; -``` diff --git a/axum/src/docs/routing/route_service.md b/axum/src/docs/routing/route_service.md new file mode 100644 index 0000000000..a14323a933 --- /dev/null +++ b/axum/src/docs/routing/route_service.md @@ -0,0 +1,81 @@ +Add another route to the router that calls a [`Service`]. + +# Example + +```rust,no_run +use axum::{ + Router, + body::Body, + routing::{any_service, get_service}, + http::{Request, StatusCode}, + error_handling::HandleErrorLayer, +}; +use tower_http::services::ServeFile; +use http::Response; +use std::{convert::Infallible, io}; +use tower::service_fn; + +let app = Router::new() + .route( + // Any request to `/` goes to a service + "/", + // Services whose response body is not `axum::body::BoxBody` + // can be wrapped in `axum::routing::any_service` (or one of the other routing filters) + // to have the response body mapped + any_service(service_fn(|_: Request| async { + let res = Response::new(Body::from("Hi from `GET /`")); + Ok::<_, Infallible>(res) + })) + ) + .route_service( + "/foo", + // This service's response body is `axum::body::BoxBody` so + // it can be routed to directly. + service_fn(|req: Request| async move { + let body = Body::from(format!("Hi from `{} /foo`", req.method())); + let body = axum::body::boxed(body); + let res = Response::new(body); + Ok::<_, Infallible>(res) + }) + ) + .route( + // GET `/static/Cargo.toml` goes to a service from tower-http + "/static/Cargo.toml", + get_service(ServeFile::new("Cargo.toml")) + // though we must handle any potential errors + .handle_error(|error: io::Error| async move { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Unhandled internal error: {}", error), + ) + }) + ); +# async { +# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +# }; +``` + +Routing to arbitrary services in this way has complications for backpressure +([`Service::poll_ready`]). See the [Routing to services and backpressure] module +for more details. + +# Panics + +Panics for the same reasons as [`Router::route`] or if you attempt to route to a +`Router`: + +```rust,should_panic +use axum::{routing::get, Router}; + +let app = Router::new().route_service( + "/", + Router::new().route("/foo", get(|| async {})), +); +# async { +# axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +# }; +``` + +Use [`Router::nest`] instead. + +[Routing to services and backpressure]: middleware/index.html#routing-to-servicesmiddleware-and-backpressure diff --git a/axum/src/error_handling/mod.rs b/axum/src/error_handling/mod.rs index 3c5779a2d5..d905245b17 100644 --- a/axum/src/error_handling/mod.rs +++ b/axum/src/error_handling/mod.rs @@ -114,15 +114,15 @@ where } } -impl Service> for HandleError +impl Service> for HandleError where - S: Service, Response = Response> + Clone + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Error: Send, S::Future: Send, F: FnOnce(S::Error) -> Fut + Clone + Send + 'static, Fut: Future + Send, Res: IntoResponse, - ReqBody: Send + 'static, + B: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { @@ -134,7 +134,7 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { let f = self.f.clone(); let clone = self.inner.clone(); @@ -154,17 +154,17 @@ where #[allow(unused_macros)] macro_rules! impl_service { ( $($ty:ident),* $(,)? ) => { - impl Service> + impl Service> for HandleError where - S: Service, Response = Response> + Clone + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Error: Send, S::Future: Send, F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static, Fut: Future + Send, Res: IntoResponse, - $( $ty: FromRequest<(), ReqBody> + Send,)* - ReqBody: Send + 'static, + $( $ty: FromRequest + Send,)* + B: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { @@ -178,7 +178,7 @@ macro_rules! impl_service { } #[allow(non_snake_case)] - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { let f = self.f.clone(); let clone = self.inner.clone(); diff --git a/axum/src/extension.rs b/axum/src/extension.rs index d040b9e4e5..5d8d60ecfc 100644 --- a/axum/src/extension.rs +++ b/axum/src/extension.rs @@ -73,7 +73,7 @@ use tower_service::Service; pub struct Extension(pub T); #[async_trait] -impl FromRequest for Extension +impl FromRequest for Extension where T: Clone + Send + Sync + 'static, B: Send, @@ -81,7 +81,7 @@ where { type Rejection = ExtensionRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let value = req .extensions() .get::() diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index 3aa7684c81..ceedc116ac 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -128,15 +128,15 @@ opaque_future! { pub struct ConnectInfo(pub T); #[async_trait] -impl FromRequest for ConnectInfo +impl FromRequest for ConnectInfo where B: Send, S: Send, T: Clone + Send + Sync + 'static, { - type Rejection = as FromRequest>::Rejection; + type Rejection = as FromRequest>::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let Extension(connect_info) = Extension::::from_request(req).await?; Ok(connect_info) } diff --git a/axum/src/extract/content_length_limit.rs b/axum/src/extract/content_length_limit.rs index f4c475437f..8584c737dc 100644 --- a/axum/src/extract/content_length_limit.rs +++ b/axum/src/extract/content_length_limit.rs @@ -36,16 +36,16 @@ use std::ops::Deref; pub struct ContentLengthLimit(pub T); #[async_trait] -impl FromRequest for ContentLengthLimit +impl FromRequest for ContentLengthLimit where - T: FromRequest, + T: FromRequest, T::Rejection: IntoResponse, B: Send, S: Send, { type Rejection = ContentLengthLimitRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let content_length = req .headers() .get(http::header::CONTENT_LENGTH) diff --git a/axum/src/extract/host.rs b/axum/src/extract/host.rs index 79ae13fc28..71acb938ea 100644 --- a/axum/src/extract/host.rs +++ b/axum/src/extract/host.rs @@ -21,14 +21,14 @@ const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; pub struct Host(pub String); #[async_trait] -impl FromRequest for Host +impl FromRequest for Host where B: Send, S: Send, { type Rejection = HostRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if let Some(host) = parse_forwarded(req.headers()) { return Ok(Host(host.to_owned())); } diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index bce7bf67b2..1dfb3d536d 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -64,14 +64,14 @@ impl MatchedPath { } #[async_trait] -impl FromRequest for MatchedPath +impl FromRequest for MatchedPath where B: Send, S: Send, { type Rejection = MatchedPathRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let matched_path = req .extensions() .get::() @@ -85,7 +85,9 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{extract::Extension, handler::Handler, routing::get, test_helpers::*, Router}; + use crate::{ + extract::Extension, handler::HandlerWithoutStateExt, routing::get, test_helpers::*, Router, + }; use http::Request; use std::task::{Context, Poll}; use tower_service::Service; @@ -150,7 +152,7 @@ mod tests { "/public", Router::new().route("/assets/*path", get(handler)), ) - .nest("/foo", handler.into_service(())) + .nest("/foo", handler.into_service()) .layer(tower::layer::layer_fn(SetMatchedPathExtension)); let client = TestClient::new(app); diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index c4aeb3ad70..511447d41c 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -75,13 +75,13 @@ pub use self::ws::WebSocketUpgrade; #[doc(no_inline)] pub use crate::TypedHeader; -pub(crate) fn take_body(req: &mut RequestParts) -> Result { +pub(crate) fn take_body(req: &mut RequestParts) -> Result { req.take_body().ok_or_else(BodyAlreadyExtracted::default) } // this is duplicated in `axum-extra/src/extract/form.rs` -pub(super) fn has_content_type( - req: &RequestParts, +pub(super) fn has_content_type( + req: &RequestParts, expected_content_type: &mime::Mime, ) -> bool { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index 6b5b46f779..afcbbf1aba 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -50,7 +50,7 @@ pub struct Multipart { } #[async_trait] -impl FromRequest for Multipart +impl FromRequest for Multipart where B: HttpBody + Default + Unpin + Send + 'static, B::Error: Into, @@ -58,7 +58,7 @@ where { type Rejection = MultipartRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let stream = BodyStream::from_request(req).await?; let headers = req.headers(); let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?; @@ -180,7 +180,7 @@ impl<'a> Field<'a> { /// } /// /// let app = Router::new().route("/upload", post(upload)); - /// # let _: Router = app; + /// # let _: Router<()> = app; /// ``` pub async fn chunk(&mut self) -> Result, MultipartError> { self.inner diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index 4b04cce65f..766c21ec67 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -163,7 +163,7 @@ impl DerefMut for Path { } #[async_trait] -impl FromRequest for Path +impl FromRequest for Path where T: DeserializeOwned + Send, B: Send, @@ -171,7 +171,7 @@ where { type Rejection = PathRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let params = match req.extensions_mut().get::() { Some(UrlParams::Params(params)) => params, Some(UrlParams::InvalidUtf8InPathParam { key }) => { diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index bc462eeb1b..6c18c8a914 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -49,7 +49,7 @@ use std::ops::Deref; pub struct Query(pub T); #[async_trait] -impl FromRequest for Query +impl FromRequest for Query where T: DeserializeOwned, B: Send, @@ -57,7 +57,7 @@ where { type Rejection = QueryRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let query = req.uri().query().unwrap_or_default(); let value = serde_urlencoded::from_str(query) .map_err(FailedToDeserializeQueryString::__private_new::)?; diff --git a/axum/src/extract/raw_query.rs b/axum/src/extract/raw_query.rs index faf8df6e4c..eeda0f44be 100644 --- a/axum/src/extract/raw_query.rs +++ b/axum/src/extract/raw_query.rs @@ -27,14 +27,14 @@ use std::convert::Infallible; pub struct RawQuery(pub Option); #[async_trait] -impl FromRequest for RawQuery +impl FromRequest for RawQuery where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let query = req.uri().query().map(|query| query.to_owned()); Ok(Self(query)) } diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs index c04ff5349b..5ab9ae53f0 100644 --- a/axum/src/extract/request_parts.rs +++ b/axum/src/extract/request_parts.rs @@ -86,14 +86,14 @@ pub struct OriginalUri(pub Uri); #[cfg(feature = "original-uri")] #[async_trait] -impl FromRequest for OriginalUri +impl FromRequest for OriginalUri where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let uri = Extension::::from_request(req) .await .unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone()))) @@ -141,7 +141,7 @@ impl Stream for BodyStream { } #[async_trait] -impl FromRequest for BodyStream +impl FromRequest for BodyStream where B: HttpBody + Send + 'static, B::Data: Into, @@ -150,7 +150,7 @@ where { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)? .map_data(Into::into) .map_err(|err| Error::new(err.into())); @@ -198,14 +198,14 @@ fn body_stream_traits() { pub struct RawBody(pub B); #[async_trait] -impl FromRequest for RawBody +impl FromRequest for RawBody where B: Send, S: Send, { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; Ok(Self(body)) } diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index fad5e67405..07d8be0950 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -9,14 +9,14 @@ use std::{ pub struct State(pub S); #[async_trait] -impl FromRequest for State +impl FromRequest for State where B: Send, OuterState: Clone + Into + Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let outer_state = req.state().clone(); let inner_state = outer_state.into(); Ok(Self(inner_state)) diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 2840ad635c..1bc8b043d1 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -244,14 +244,14 @@ impl WebSocketUpgrade { } #[async_trait] -impl FromRequest for WebSocketUpgrade +impl FromRequest for WebSocketUpgrade where B: Send, S: Send, { type Rejection = WebSocketUpgradeRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if req.method() != Method::GET { return Err(MethodNotGet.into()); } @@ -289,7 +289,7 @@ where } } -fn header_eq(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { +fn header_eq(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { if let Some(header) = req.headers().get(&key) { header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) } else { @@ -297,7 +297,7 @@ fn header_eq(req: &RequestParts, key: HeaderName, value: &'static st } } -fn header_contains(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { +fn header_contains(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { let header = if let Some(header) = req.headers().get(&key) { header } else { diff --git a/axum/src/form.rs b/axum/src/form.rs index 7542841620..592c81a75f 100644 --- a/axum/src/form.rs +++ b/axum/src/form.rs @@ -56,7 +56,7 @@ use std::ops::Deref; pub struct Form(pub T); #[async_trait] -impl FromRequest for Form +impl FromRequest for Form where T: DeserializeOwned, B: HttpBody + Send, @@ -66,7 +66,7 @@ where { type Rejection = FormRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if req.method() == Method::GET { let query = req.uri().query().unwrap_or_default(); let value = serde_urlencoded::from_str(query) diff --git a/axum/src/handler/future.rs b/axum/src/handler/future.rs index b8b3e4e867..59487c31b2 100644 --- a/axum/src/handler/future.rs +++ b/axum/src/handler/future.rs @@ -19,29 +19,29 @@ opaque_future! { pin_project! { /// The response future for [`Layered`](super::Layered). - pub struct LayeredFuture + pub struct LayeredFuture where - S: Service>, + S: Service>, { #[pin] - inner: Map>, fn(Result) -> Response>, + inner: Map>, fn(Result) -> Response>, } } -impl LayeredFuture +impl LayeredFuture where - S: Service>, + S: Service>, { pub(super) fn new( - inner: Map>, fn(Result) -> Response>, + inner: Map>, fn(Result) -> Response>, ) -> Self { Self { inner } } } -impl Future for LayeredFuture +impl Future for LayeredFuture where - S: Service>, + S: Service>, { type Output = Response; diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index b9dd781433..b7975bbb98 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -124,7 +124,7 @@ pub trait Handler: Clone + Send + Sized + 'static { /// ```rust /// use axum::{ /// Server, - /// handler::Handler, + /// handler::HandlerWithoutStateExt, /// http::{Uri, Method, StatusCode}, /// response::IntoResponse, /// routing::{get, Router}, @@ -138,7 +138,7 @@ pub trait Handler: Clone + Send + Sized + 'static { /// /// let app = Router::new() /// .route("/", get(|| async {})) - /// .fallback(handler.into_service()); + /// .fallback(handler); /// /// # async { /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) @@ -149,7 +149,7 @@ pub trait Handler: Clone + Send + Sized + 'static { /// ``` /// /// [`Router::fallback`]: crate::routing::Router::fallback - fn into_service(self, state: S) -> IntoService { + fn into_service_with(self, state: S) -> IntoService { IntoService::new(self, state) } @@ -169,15 +169,15 @@ pub trait Handler: Clone + Send + Sized + 'static { /// /// # async { /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) - /// .serve(handler.into_make_service()) + /// .serve(handler.into_make_service_with(())) /// .await?; /// # Ok::<_, hyper::Error>(()) /// # }; /// ``` /// /// [`MakeService`]: tower::make::MakeService - fn into_make_service(self, state: S) -> IntoMakeService> { - IntoMakeService::new(self.into_service(state)) + fn into_make_service_with(self, state: S) -> IntoMakeService> { + IntoMakeService::new(self.into_service_with(state)) } /// Convert the handler into a [`MakeService`] which stores information @@ -200,7 +200,7 @@ pub trait Handler: Clone + Send + Sized + 'static { /// /// # async { /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) - /// .serve(handler.into_make_service_with_connect_info::()) + /// .serve(handler.into_make_service_with_connect_info_and_state::(())) /// .await?; /// # Ok::<_, hyper::Error>(()) /// # }; @@ -208,11 +208,11 @@ pub trait Handler: Clone + Send + Sized + 'static { /// /// [`MakeService`]: tower::make::MakeService /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info - fn into_make_service_with_connect_info( + fn into_make_service_with_connect_info_and_state( self, state: S, ) -> IntoMakeServiceWithConnectInfo, C> { - IntoMakeServiceWithConnectInfo::new(self.into_service(state)) + IntoMakeServiceWithConnectInfo::new(self.into_service_with(state)) } } @@ -240,7 +240,7 @@ macro_rules! impl_handler { B: Send + 'static, S: Send + 'static, Res: IntoResponse, - $( $ty: FromRequest + Send,)* + $( $ty: FromRequest + Send,)* { type Future = Pin + Send>>; @@ -313,12 +313,12 @@ where ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { - type Future = future::LayeredFuture; + type Future = future::LayeredFuture; fn call(self, state: S, req: Request) -> Self::Future { use futures_util::future::{FutureExt, Map}; - let svc = self.handler.into_service(state); + let svc = self.handler.into_service_with(state); let svc = self.layer.layer(svc); let future: Map< @@ -338,6 +338,45 @@ where } } +pub trait HandlerWithoutStateExt: Handler { + /// Convert the handler into a [`Service`] and no state. + /// + /// See [`Handler::into_service`] for more details. + fn into_service(self) -> IntoService; + + /// Convert the handler into a [`MakeService`] and no state. + /// + /// See [`Handler::into_make_service`] for more details. + fn into_make_service(self) -> IntoMakeService>; + + /// Convert the handler into a [`MakeService`] which stores information + /// about the incoming connection and has no state. + /// + /// See [`Handler::into_make_service_with_connect_info_and_state`] for more details. + fn into_make_service_with_connect_info( + self, + ) -> IntoMakeServiceWithConnectInfo, C>; +} + +impl HandlerWithoutStateExt for H +where + H: Handler, +{ + fn into_service(self) -> IntoService { + self.into_service_with(()) + } + + fn into_make_service(self) -> IntoMakeService> { + self.into_make_service_with(()) + } + + fn into_make_service_with_connect_info( + self, + ) -> IntoMakeServiceWithConnectInfo, C> { + self.into_make_service_with_connect_info_and_state(()) + } +} + #[cfg(test)] mod tests { use super::*; @@ -350,7 +389,7 @@ mod tests { format!("you said: {}", body) } - let client = TestClient::new(handle.into_service(())); + let client = TestClient::new(handle.into_service()); let res = client.post("/").body("hi there!").send().await; assert_eq!(res.status(), StatusCode::OK); diff --git a/axum/src/json.rs b/axum/src/json.rs index a7bfedcb9d..58429447ff 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -93,7 +93,7 @@ use std::ops::{Deref, DerefMut}; pub struct Json(pub T); #[async_trait] -impl FromRequest for Json +impl FromRequest for Json where T: DeserializeOwned, B: HttpBody + Send, @@ -103,7 +103,7 @@ where { type Rejection = JsonRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if json_content_type(req) { let bytes = Bytes::from_request(req).await?; @@ -136,7 +136,7 @@ where } } -fn json_content_type(req: &RequestParts) -> bool { +fn json_content_type(req: &RequestParts) -> bool { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { content_type } else { diff --git a/axum/src/middleware/from_extractor.rs b/axum/src/middleware/from_extractor.rs index 2b2884c7a0..b4134570d2 100644 --- a/axum/src/middleware/from_extractor.rs +++ b/axum/src/middleware/from_extractor.rs @@ -47,13 +47,14 @@ use tower_service::Service; /// struct RequireAuth; /// /// #[async_trait] -/// impl FromRequest for RequireAuth +/// impl FromRequest for RequireAuth /// where /// B: Send, +/// S: Send, /// { /// type Rejection = StatusCode; /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: &mut RequestParts) -> Result { /// let auth_header = req /// .headers() /// .get(header::AUTHORIZATION) @@ -166,24 +167,24 @@ where } } -impl Service> for FromExtractor +impl Service> for FromExtractor where - E: FromRequest<(), ReqBody> + 'static, - ReqBody: Default + Send + 'static, - S: Service, Response = Response> + Clone, + E: FromRequest + 'static, + B: Default + Send + 'static, + S: Service, Response = Response> + Clone, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { type Response = Response; type Error = S::Error; - type Future = ResponseFuture; + type Future = ResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { let extract_future = Box::pin(async move { let mut req = RequestParts::new((), req); let extracted = E::from_request(&mut req).await; @@ -202,36 +203,36 @@ where pin_project! { /// Response future for [`FromExtractor`]. #[allow(missing_debug_implementations)] - pub struct ResponseFuture + pub struct ResponseFuture where - E: FromRequest<(), ReqBody>, - S: Service>, + E: FromRequest, + S: Service>, { #[pin] - state: State, + state: State, svc: Option, } } pin_project! { #[project = StateProj] - enum State + enum State where - E: FromRequest<(), ReqBody>, - S: Service>, + E: FromRequest, + S: Service>, { Extracting { - future: BoxFuture<'static, (RequestParts<(), ReqBody>, Result)>, + future: BoxFuture<'static, (RequestParts, Result)>, }, Call { #[pin] future: S::Future }, } } -impl Future for ResponseFuture +impl Future for ResponseFuture where - E: FromRequest<(), ReqBody>, - S: Service, Response = Response>, - ReqBody: Default, + E: FromRequest, + S: Service, Response = Response>, + B: Default, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { @@ -281,14 +282,14 @@ mod tests { struct RequireAuth; #[async_trait::async_trait] - impl FromRequest for RequireAuth + impl FromRequest for RequireAuth where B: Send, S: Send, { type Rejection = StatusCode; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if let Some(auth) = req .headers() .get(header::AUTHORIZATION) diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index c0cda708aa..9ec59ebc8f 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -61,7 +61,7 @@ use tower_service::Service; /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .route_layer(middleware::from_fn(auth)); -/// # let app: Router = app; +/// # let app: Router<()> = app; /// ``` /// /// # Running extractors @@ -92,7 +92,7 @@ use tower_service::Service; /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .route_layer(middleware::from_fn(my_middleware)); -/// # let app: Router = app; +/// # let app: Router<()> = app; /// ``` /// /// # Passing state @@ -127,7 +127,7 @@ use tower_service::Service; /// .route_layer(middleware::from_fn(move |req, next| { /// my_middleware(req, next, state.clone()) /// })); -/// # let app: Router = app; +/// # let app: Router<()> = app; /// ``` /// /// Or via extensions: @@ -164,7 +164,7 @@ use tower_service::Service; /// .layer(Extension(state)) /// .layer(middleware::from_fn(my_middleware)), /// ); -/// # let app: Router = app; +/// # let app: Router<()> = app; /// ``` /// /// [extractors]: crate::extract::FromRequest @@ -256,18 +256,18 @@ where macro_rules! impl_service { ( $($ty:ident),* $(,)? ) => { #[allow(non_snake_case)] - impl Service> for FromFn + impl Service> for FromFn where - F: FnMut($($ty),*, Next) -> Fut + Clone + Send + 'static, - $( $ty: FromRequest<(), ReqBody> + Send, )* + F: FnMut($($ty),*, Next) -> Fut + Clone + Send + 'static, + $( $ty: FromRequest + Send, )* Fut: Future + Send + 'static, Out: IntoResponse + 'static, - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, S::Future: Send + 'static, - ReqBody: Send + 'static, + B: Send + 'static, ResBody: HttpBody + Send + 'static, ResBody::Error: Into, { @@ -279,7 +279,7 @@ macro_rules! impl_service { self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { let not_ready_inner = self.inner.clone(); let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); @@ -326,13 +326,13 @@ where } /// The remainder of a middleware stack, including the handler. -pub struct Next { - inner: BoxCloneService, Response, Infallible>, +pub struct Next { + inner: BoxCloneService, Response, Infallible>, } -impl Next { +impl Next { /// Execute the remaining middleware stack. - pub async fn run(mut self, req: Request) -> Response { + pub async fn run(mut self, req: Request) -> Response { match self.inner.call(req).await { Ok(res) => res, Err(err) => match err {}, @@ -340,7 +340,7 @@ impl Next { } } -impl fmt::Debug for Next { +impl fmt::Debug for Next { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromFnLayer") .field("inner", &self.inner) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index dff193a4ae..4ed8ee4155 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -571,7 +571,7 @@ impl MethodRouter { } } -impl MethodRouter +impl MethodRouter where B: Send + 'static, { diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 68fcd043a0..2a81c92c1b 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -142,6 +142,7 @@ where } } + #[doc = include_str!("../docs/routing/route.md")] pub fn route(mut self, path: &str, method_router: MethodRouter) -> Self { if path.is_empty() { panic!("Paths must start with a `/`. Use \"/\" for root routes"); @@ -178,7 +179,7 @@ where self } - #[doc = include_str!("../docs/routing/route.md")] + #[doc = include_str!("../docs/routing/route_service.md")] pub fn route_service(mut self, path: &str, service: T) -> Self where T: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, @@ -429,9 +430,12 @@ where T: 'static, { let state = self.state.clone(); - self.fallback_service(handler.into_service(state)) + self.fallback_service(handler.into_service_with(state)) } + /// Add a fallback [`Service`] to the router. + /// + /// See [`Router::fallback`] for more details. pub fn fallback_service(mut self, svc: T) -> Self where T: Service, Response = Response, Error = Infallible> + Clone + Send + 'static, diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 6e56d77a72..3fcce9e328 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -44,13 +44,13 @@ impl Route { } } -impl Clone for Route { +impl Clone for Route { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl fmt::Debug for Route { +impl fmt::Debug for Route { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Route").finish() } diff --git a/axum/src/routing/strip_prefix.rs b/axum/src/routing/strip_prefix.rs index ec0e232553..feb5a5c4a3 100644 --- a/axum/src/routing/strip_prefix.rs +++ b/axum/src/routing/strip_prefix.rs @@ -20,7 +20,7 @@ impl StripPrefix { } } -impl Service> for StripPrefix +impl Service> for StripPrefix where S: Service>, { diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index f447d6fde7..4b92c7f94f 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -2,7 +2,7 @@ use crate::{ body::{Bytes, Empty}, error_handling::HandleErrorLayer, extract::{self, Path, State}, - handler::Handler, + handler::{Handler, HandlerWithoutStateExt}, response::IntoResponse, routing::{delete, get, get_service, on, on_service, patch, patch_service, post, MethodFilter}, test_helpers::*, @@ -148,10 +148,7 @@ async fn routing_between_services() { }), ), ) - .route( - "/two", - on_service(MethodFilter::GET, handle.into_service(())), - ); + .route("/two", on_service(MethodFilter::GET, handle.into_service())); let client = TestClient::new(app); diff --git a/axum/src/typed_header.rs b/axum/src/typed_header.rs index c28a24a81a..88f3bd55a3 100644 --- a/axum/src/typed_header.rs +++ b/axum/src/typed_header.rs @@ -52,7 +52,7 @@ use std::{convert::Infallible, ops::Deref}; pub struct TypedHeader(pub T); #[async_trait] -impl FromRequest for TypedHeader +impl FromRequest for TypedHeader where T: headers::Header, B: Send, @@ -60,7 +60,7 @@ where { type Rejection = TypedHeaderRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match req.headers().typed_try_get::() { Ok(Some(value)) => Ok(Self(value)), Ok(None) => Err(TypedHeaderRejection { diff --git a/examples/customize-extractor-error/src/main.rs b/examples/customize-extractor-error/src/main.rs index 448d30c57d..477bae2dc4 100644 --- a/examples/customize-extractor-error/src/main.rs +++ b/examples/customize-extractor-error/src/main.rs @@ -53,7 +53,7 @@ struct User { struct Json(T); #[async_trait] -impl FromRequest for Json +impl FromRequest for Json where S: Send, // these trait bounds are copied from `impl FromRequest for axum::Json` @@ -64,7 +64,7 @@ where { type Rejection = (StatusCode, axum::Json); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match axum::Json::::from_request(req).await { Ok(value) => Ok(Self(value.0)), Err(rejection) => { diff --git a/examples/customize-path-rejection/src/main.rs b/examples/customize-path-rejection/src/main.rs index 8330b95a93..7807249280 100644 --- a/examples/customize-path-rejection/src/main.rs +++ b/examples/customize-path-rejection/src/main.rs @@ -52,7 +52,7 @@ struct Params { struct Path(T); #[async_trait] -impl FromRequest for Path +impl FromRequest for Path where // these trait bounds are copied from `impl FromRequest for axum::extract::path::Path` T: DeserializeOwned + Send, @@ -61,7 +61,7 @@ where { type Rejection = (StatusCode, axum::Json); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match axum::extract::Path::::from_request(req).await { Ok(value) => Ok(Self(value.0)), Err(rejection) => { diff --git a/examples/jwt/src/main.rs b/examples/jwt/src/main.rs index 8725581da7..ffed4c36a3 100644 --- a/examples/jwt/src/main.rs +++ b/examples/jwt/src/main.rs @@ -122,14 +122,14 @@ impl AuthBody { } #[async_trait] -impl FromRequest for Claims +impl FromRequest for Claims where S: Send, B: Send, { type Rejection = AuthError; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { // Extract the token from the authorization header let TypedHeader(Authorization(bearer)) = TypedHeader::>::from_request(req) diff --git a/examples/validator/src/main.rs b/examples/validator/src/main.rs index 8682eb85e5..efc966902d 100644 --- a/examples/validator/src/main.rs +++ b/examples/validator/src/main.rs @@ -60,7 +60,7 @@ async fn handler(ValidatedForm(input): ValidatedForm) -> Html pub struct ValidatedForm(pub T); #[async_trait] -impl FromRequest for ValidatedForm +impl FromRequest for ValidatedForm where T: DeserializeOwned + Validate, S: Send, @@ -70,7 +70,7 @@ where { type Rejection = ServerError; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let Form(value) = Form::::from_request(req).await?; value.validate()?; Ok(ValidatedForm(value)) diff --git a/examples/versioning/src/main.rs b/examples/versioning/src/main.rs index cf8e15f280..4150a04385 100644 --- a/examples/versioning/src/main.rs +++ b/examples/versioning/src/main.rs @@ -48,14 +48,14 @@ enum Version { } #[async_trait] -impl FromRequest for Version +impl FromRequest for Version where B: Send, S: Send, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let params = Path::>::from_request(req) .await .map_err(IntoResponse::into_response)?; From 2c2187d750152c6ec9f56532671cd5bcd28ea333 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Jul 2022 16:29:38 +0200 Subject: [PATCH 11/45] Default the state param to () --- axum-core/src/extract/mod.rs | 4 ++-- axum-extra/src/routing/resource.rs | 2 +- .../tests/debug_handler/fail/argument_not_extractor.stderr | 4 ++-- axum/src/handler/mod.rs | 2 +- axum/src/routing/method_routing.rs | 2 +- axum/src/routing/mod.rs | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index b052b3867d..2be92cd96e 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -62,7 +62,7 @@ mod tuple; /// [`http::Request`]: http::Request /// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html #[async_trait] -pub trait FromRequest: Sized { +pub trait FromRequest: Sized { /// If the extractor fails it'll use this "rejection" type. A rejection is /// a kind of error that can be converted into a response. type Rejection: IntoResponse; @@ -75,7 +75,7 @@ pub trait FromRequest: Sized { /// /// Has several convenience methods for getting owned parts of the request. #[derive(Debug)] -pub struct RequestParts { +pub struct RequestParts { state: S, method: Method, uri: Uri, diff --git a/axum-extra/src/routing/resource.rs b/axum-extra/src/routing/resource.rs index 3736af92eb..189fa89c3d 100644 --- a/axum-extra/src/routing/resource.rs +++ b/axum-extra/src/routing/resource.rs @@ -48,7 +48,7 @@ use tower_service::Service; /// # let _: Router<()> = app; /// ``` #[derive(Debug)] -pub struct Resource { +pub struct Resource { pub(crate) name: String, pub(crate) router: Router, } diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr index 420005970c..468c4e17f0 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr @@ -1,8 +1,8 @@ -error[E0277]: the trait bound `bool: FromRequest` is not satisfied +error[E0277]: the trait bound `bool: FromRequest` is not satisfied --> tests/debug_handler/fail/argument_not_extractor.rs:4:23 | 4 | async fn handler(foo: bool) {} - | ^^^^ the trait `FromRequest` is not implemented for `bool` + | ^^^^ the trait `FromRequest` is not implemented for `bool` | = help: the following other types implement trait `FromRequest`: <() as FromRequest> diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index b7975bbb98..41eba46600 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -63,7 +63,7 @@ pub(crate) use self::into_service_state_in_extension::IntoServiceStateInExtensio /// See the [module docs](crate::handler) for more details. /// #[doc = include_str!("../docs/debugging_handler_type_errors.md")] -pub trait Handler: Clone + Send + Sized + 'static { +pub trait Handler: Clone + Send + Sized + 'static { /// The type of future calling this handler returns. type Future: Future + Send + 'static; diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 4ed8ee4155..6b3a051ca2 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -482,7 +482,7 @@ where /// A [`Service`] that accepts requests based on a [`MethodFilter`] and /// allows chaining additional handlers and services. -pub struct MethodRouter { +pub struct MethodRouter { get: Option>, head: Option>, delete: Option>, diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 2a81c92c1b..6aad74ea15 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -63,7 +63,7 @@ impl RouteId { } /// The router type for composing handlers and services. -pub struct Router { +pub struct Router { state: S, routes: HashMap>, node: Arc, From 8c2027ebd9115072b5f503f92219b4909e163175 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Jul 2022 16:32:38 +0200 Subject: [PATCH 12/45] fix docs references --- axum/src/handler/into_service.rs | 2 +- axum/src/handler/mod.rs | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/axum/src/handler/into_service.rs b/axum/src/handler/into_service.rs index 590af8dd4e..d59e3d869c 100644 --- a/axum/src/handler/into_service.rs +++ b/axum/src/handler/into_service.rs @@ -11,7 +11,7 @@ use tower_service::Service; /// An adapter that makes a [`Handler`] into a [`Service`]. /// -/// Created with [`Handler::into_service`]. +/// Created with [`Handler::into_service_with`]. pub struct IntoService { handler: H, state: S, diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 41eba46600..c5a6893799 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -341,18 +341,22 @@ where pub trait HandlerWithoutStateExt: Handler { /// Convert the handler into a [`Service`] and no state. /// - /// See [`Handler::into_service`] for more details. + /// See [`Handler::into_service_with`] for more details. fn into_service(self) -> IntoService; /// Convert the handler into a [`MakeService`] and no state. /// - /// See [`Handler::into_make_service`] for more details. + /// See [`Handler::into_make_service_with`] for more details. + /// + /// [`MakeService`]: tower::make::MakeService fn into_make_service(self) -> IntoMakeService>; /// Convert the handler into a [`MakeService`] which stores information /// about the incoming connection and has no state. /// /// See [`Handler::into_make_service_with_connect_info_and_state`] for more details. + /// + /// [`MakeService`]: tower::make::MakeService fn into_make_service_with_connect_info( self, ) -> IntoMakeServiceWithConnectInfo, C>; From c1fe325e37b8b2433dfa34cdac03c1f03c330346 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Jul 2022 17:59:32 +0200 Subject: [PATCH 13/45] Docs and handler state refactoring --- axum-macros/src/lib.rs | 4 +- axum/src/extract/state.rs | 123 +++++++++++++++++++++++++++ axum/src/handler/mod.rs | 131 +++++------------------------ axum/src/handler/with_state.rs | 118 ++++++++++++++++++++++++++ axum/src/lib.rs | 58 ++++++++++--- axum/src/routing/method_routing.rs | 50 +++++++++-- axum/src/routing/mod.rs | 9 +- 7 files changed, 361 insertions(+), 132 deletions(-) create mode 100644 axum/src/handler/with_state.rs diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index 935eff844e..44e509479f 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -275,9 +275,7 @@ mod typed_path; /// [`axum::extract::rejection::ExtensionRejection`]: https://docs.rs/axum/latest/axum/extract/rejection/enum.ExtensionRejection.html #[proc_macro_derive(FromRequest, attributes(from_request))] pub fn derive_from_request(item: TokenStream) -> TokenStream { - let tokens = expand_with(item, from_request::expand); - // panic!("{}", tokens); - tokens + expand_with(item, from_request::expand) } /// Generates better error messages when applied handler functions. diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index 07d8be0950..f4a81317bd 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -5,6 +5,129 @@ use std::{ ops::{Deref, DerefMut}, }; +/// Extractor for state. +/// +/// # Examples +/// +/// ## With [`Router`] +/// +/// ``` +/// use axum::{Router, routing::get, extract::State}; +/// +/// // the application state +/// // +/// // here you can put configuration, database connection pools, or whatever +/// // state you need +/// #[derive(Clone)] +/// struct AppState {} +/// +/// let state = AppState {}; +/// +/// // create a `Router` that holds our state +/// let app = Router::with_state(state).route("/", get(handler)); +/// +/// async fn handler( +/// // access the state via the `State` extractor +/// // extracting a state of the wrong type results in a compile error +/// State(state): State, +/// ) { +/// // use `state`... +/// } +/// # let _: Router = app; +/// ``` +/// +/// ### Substates +/// +/// [`State`] only allows a single state type but you can use [`From`] to extract "substates": +/// +/// ``` +/// use axum::{Router, routing::get, extract::State}; +/// +/// // the application state +/// #[derive(Clone)] +/// struct AppState { +/// // that holds some api specific state +/// api_state: ApiState, +/// } +/// +/// // the api specific state +/// #[derive(Clone)] +/// struct ApiState {} +/// +/// // support converting an `AppState` in an `ApiState` +/// impl From for ApiState { +/// fn from(app_state: AppState) -> ApiState { +/// app_state.api_state +/// } +/// } +/// +/// let state = AppState { +/// api_state: ApiState {}, +/// }; +/// +/// let app = Router::with_state(state) +/// .route("/", get(handler)) +/// .route("/api/users", get(api_users)); +/// +/// async fn api_users( +/// // access the api specific state +/// State(api_state): State, +/// ) { +/// } +/// +/// async fn handler( +/// // we can still access to top level state +/// State(state): State, +/// ) { +/// } +/// # let _: Router = app; +/// ``` +/// +/// ## With [`MethodRouter`] +/// +/// ``` +/// use axum::{routing::get, extract::State}; +/// +/// #[derive(Clone)] +/// struct AppState {} +/// +/// let state = AppState {}; +/// +/// let app = get(handler) +/// // provide the state so the handler can access it +/// .with_state(state); +/// +/// async fn handler(State(state): State) { +/// // use `state`... +/// } +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// ## With [`Handler`] +/// +/// ``` +/// use axum::{routing::get, handler::Handler, extract::State}; +/// +/// #[derive(Clone)] +/// struct AppState {} +/// +/// let state = AppState {}; +/// +/// let app = handler.with_state(state); +/// +/// async fn handler(State(state): State) { +/// // use `state`... +/// } +/// +/// # async { +/// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) +/// .serve(app.into_make_service()) +/// .await +/// .expect("server failed"); +/// # }; +/// ``` #[derive(Debug, Default, Clone, Copy)] pub struct State(pub S); diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index c5a6893799..dc25100db6 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -51,9 +51,10 @@ use tower_service::Service; pub mod future; mod into_service; mod into_service_state_in_extension; +mod with_state; -pub use self::into_service::IntoService; pub(crate) use self::into_service_state_in_extension::IntoServiceStateInExtension; +pub use self::{into_service::IntoService, with_state::WithState}; /// Trait for async functions that can be used to handle requests. /// @@ -108,7 +109,7 @@ pub trait Handler: Clone + Send + Sized + 'static { /// ``` fn layer(self, layer: L) -> Layered where - L: Layer>, + L: Layer>, { Layered { layer, @@ -117,102 +118,11 @@ pub trait Handler: Clone + Send + Sized + 'static { } } - /// Convert the handler into a [`Service`]. - /// - /// This is commonly used together with [`Router::fallback`]: - /// - /// ```rust - /// use axum::{ - /// Server, - /// handler::HandlerWithoutStateExt, - /// http::{Uri, Method, StatusCode}, - /// response::IntoResponse, - /// routing::{get, Router}, - /// }; - /// use tower::make::Shared; - /// use std::net::SocketAddr; - /// - /// async fn handler(method: Method, uri: Uri) -> (StatusCode, String) { - /// (StatusCode::NOT_FOUND, format!("Nothing to see at {} {}", method, uri)) - /// } - /// - /// let app = Router::new() - /// .route("/", get(|| async {})) - /// .fallback(handler); - /// - /// # async { - /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) - /// .serve(app.into_make_service()) - /// .await?; - /// # Ok::<_, hyper::Error>(()) - /// # }; - /// ``` - /// - /// [`Router::fallback`]: crate::routing::Router::fallback - fn into_service_with(self, state: S) -> IntoService { - IntoService::new(self, state) - } - - /// Convert the handler into a [`MakeService`]. - /// - /// This allows you to serve a single handler if you don't need any routing: - /// - /// ```rust - /// use axum::{ - /// Server, handler::Handler, http::{Uri, Method}, response::IntoResponse, - /// }; - /// use std::net::SocketAddr; - /// - /// async fn handler(method: Method, uri: Uri, body: String) -> String { - /// format!("received `{} {}` with body `{:?}`", method, uri, body) - /// } - /// - /// # async { - /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) - /// .serve(handler.into_make_service_with(())) - /// .await?; - /// # Ok::<_, hyper::Error>(()) - /// # }; - /// ``` - /// - /// [`MakeService`]: tower::make::MakeService - fn into_make_service_with(self, state: S) -> IntoMakeService> { - IntoMakeService::new(self.into_service_with(state)) - } - - /// Convert the handler into a [`MakeService`] which stores information - /// about the incoming connection. - /// - /// See [`Router::into_make_service_with_connect_info`] for more details. - /// - /// ```rust - /// use axum::{ - /// Server, - /// handler::Handler, - /// response::IntoResponse, - /// extract::ConnectInfo, - /// }; - /// use std::net::SocketAddr; - /// - /// async fn handler(ConnectInfo(addr): ConnectInfo) -> String { - /// format!("Hello {}", addr) - /// } - /// - /// # async { - /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) - /// .serve(handler.into_make_service_with_connect_info_and_state::(())) - /// .await?; - /// # Ok::<_, hyper::Error>(()) - /// # }; - /// ``` - /// - /// [`MakeService`]: tower::make::MakeService - /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info - fn into_make_service_with_connect_info_and_state( - self, - state: S, - ) -> IntoMakeServiceWithConnectInfo, C> { - IntoMakeServiceWithConnectInfo::new(self.into_service_with(state)) + /// Convert the handler into a [`Service`] by providing the state + fn with_state(self, state: S) -> WithState { + WithState { + service: IntoService::new(self, state), + } } } @@ -302,7 +212,7 @@ where impl Handler for Layered where - L: Layer> + Clone + Send + 'static, + L: Layer> + Clone + Send + 'static, H: Handler, L::Service: Service, Response = Response> + Clone + Send + 'static, >>::Error: IntoResponse, @@ -318,7 +228,7 @@ where fn call(self, state: S, req: Request) -> Self::Future { use futures_util::future::{FutureExt, Map}; - let svc = self.handler.into_service_with(state); + let svc = self.handler.with_state(state); let svc = self.layer.layer(svc); let future: Map< @@ -338,15 +248,20 @@ where } } +/// Extension trait for [`Handler`]s who doesn't have state. +/// +/// This provides convenience methods to convert the [`Handler`] into a [`Service`] or [`MakeService`]. +/// +/// [`MakeService`]: tower::make::MakeService pub trait HandlerWithoutStateExt: Handler { /// Convert the handler into a [`Service`] and no state. /// - /// See [`Handler::into_service_with`] for more details. - fn into_service(self) -> IntoService; + /// See [`WithState::into_service_with_state`] for more details. + fn into_service(self) -> WithState; /// Convert the handler into a [`MakeService`] and no state. /// - /// See [`Handler::into_make_service_with`] for more details. + /// See [`WithState::into_make_service_with_state`] for more details. /// /// [`MakeService`]: tower::make::MakeService fn into_make_service(self) -> IntoMakeService>; @@ -354,7 +269,7 @@ pub trait HandlerWithoutStateExt: Handler { /// Convert the handler into a [`MakeService`] which stores information /// about the incoming connection and has no state. /// - /// See [`Handler::into_make_service_with_connect_info_and_state`] for more details. + /// See [`WithState::into_make_service_with_connect_info_and_state`] for more details. /// /// [`MakeService`]: tower::make::MakeService fn into_make_service_with_connect_info( @@ -366,18 +281,18 @@ impl HandlerWithoutStateExt for H where H: Handler, { - fn into_service(self) -> IntoService { - self.into_service_with(()) + fn into_service(self) -> WithState { + self.with_state(()) } fn into_make_service(self) -> IntoMakeService> { - self.into_make_service_with(()) + self.with_state(()).into_make_service() } fn into_make_service_with_connect_info( self, ) -> IntoMakeServiceWithConnectInfo, C> { - self.into_make_service_with_connect_info_and_state(()) + self.with_state(()).into_make_service_with_connect_info() } } diff --git a/axum/src/handler/with_state.rs b/axum/src/handler/with_state.rs new file mode 100644 index 0000000000..a338642ae2 --- /dev/null +++ b/axum/src/handler/with_state.rs @@ -0,0 +1,118 @@ +use super::{Handler, IntoService}; +use crate::{extract::connect_info::IntoMakeServiceWithConnectInfo, routing::IntoMakeService}; +use http::Request; +use std::task::{Context, Poll}; +use tower_service::Service; + +/// A [`Handler`] with state provided. +/// +/// Implements [`Service`]. +/// +/// Created with [`Handler::with_state`]. +pub struct WithState { + pub(super) service: IntoService, +} + +impl WithState { + /// Convert the handler into a [`MakeService`]. + /// + /// This allows you to serve a single handler if you don't need any routing: + /// + /// ```rust + /// use axum::{ + /// Server, handler::Handler, http::{Uri, Method}, response::IntoResponse, + /// }; + /// use std::net::SocketAddr; + /// + /// async fn handler(method: Method, uri: Uri, body: String) -> String { + /// format!("received `{} {}` with body `{:?}`", method, uri, body) + /// } + /// + /// # async { + /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) + /// .serve(handler.with_state(()).into_make_service()) + /// .await?; + /// # Ok::<_, hyper::Error>(()) + /// # }; + /// ``` + /// + /// [`MakeService`]: tower::make::MakeService + pub fn into_make_service(self) -> IntoMakeService> { + IntoMakeService::new(self.service) + } + + /// Convert the handler into a [`MakeService`] which stores information + /// about the incoming connection. + /// + /// See [`Router::into_make_service_with_connect_info`] for more details. + /// + /// ```rust + /// use axum::{ + /// Server, + /// handler::Handler, + /// response::IntoResponse, + /// extract::ConnectInfo, + /// }; + /// use std::net::SocketAddr; + /// + /// async fn handler(ConnectInfo(addr): ConnectInfo) -> String { + /// format!("Hello {}", addr) + /// } + /// + /// # async { + /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) + /// .serve(handler.with_state(()).into_make_service_with_connect_info::()) + /// .await?; + /// # Ok::<_, hyper::Error>(()) + /// # }; + /// ``` + /// + /// [`MakeService`]: tower::make::MakeService + /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info + pub fn into_make_service_with_connect_info( + self, + ) -> IntoMakeServiceWithConnectInfo, C> { + IntoMakeServiceWithConnectInfo::new(self.service) + } +} + +impl Service> for WithState +where + H: Handler + Clone + Send + 'static, + B: Send + 'static, + S: Clone, +{ + type Response = as Service>>::Response; + type Error = as Service>>::Error; + type Future = as Service>>::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + #[inline] + fn call(&mut self, req: Request) -> Self::Future { + self.service.call(req) + } +} + +impl std::fmt::Debug for WithState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WithState") + .field("service", &self.service) + .finish() + } +} + +impl Clone for WithState +where + H: Clone, + S: Clone, +{ + fn clone(&self) -> Self { + Self { + service: self.service.clone(), + } + } +} diff --git a/axum/src/lib.rs b/axum/src/lib.rs index fabddf499a..bd3167e354 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -168,13 +168,48 @@ //! pool of database connections or clients to other services. //! //! The two most common ways of doing that are: +//! - Using the [`State`] extractor. //! - Using request extensions //! - Using closure captures //! +//! ## Using the [`State`] extractor +//! +//! ```rust,no_run +//! use axum::{ +//! extract::State, +//! routing::get, +//! Router, +//! }; +//! use std::sync::Arc; +//! +//! struct AppState { +//! // ... +//! } +//! +//! let shared_state = Arc::new(AppState { /* ... */ }); +//! +//! let app = Router::with_state(shared_state) +//! .route("/", get(handler)); +//! +//! async fn handler( +//! State(state): State>, +//! ) { +//! // ... +//! } +//! # async { +//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` +//! +//! You should prefer using [`State`] if possible since its more type safe. The downside is that +//! its less dynamic than request extensions. +//! +//! See [`State`] for more details about accessing state. +//! //! ## Using request extensions //! -//! The easiest way to extract state in handlers is using [`Extension`](crate::extract::Extension) -//! as layer and extractor: +//! Another way to extract state in handlers is using [`Extension`](crate::extract::Extension) as +//! layer and extractor: //! //! ```rust,no_run //! use axum::{ @@ -184,18 +219,18 @@ //! }; //! use std::sync::Arc; //! -//! struct State { +//! struct AppState { //! // ... //! } //! -//! let shared_state = Arc::new(State { /* ... */ }); +//! let shared_state = Arc::new(AppState { /* ... */ }); //! //! let app = Router::new() //! .route("/", get(handler)) //! .layer(Extension(shared_state)); //! //! async fn handler( -//! Extension(state): Extension>, +//! Extension(state): Extension>, //! ) { //! // ... //! } @@ -223,11 +258,11 @@ //! use std::sync::Arc; //! use serde::Deserialize; //! -//! struct State { +//! struct AppState { //! // ... //! } //! -//! let shared_state = Arc::new(State { /* ... */ }); +//! let shared_state = Arc::new(AppState { /* ... */ }); //! //! let app = Router::new() //! .route( @@ -245,11 +280,11 @@ //! }), //! ); //! -//! async fn get_user(Path(user_id): Path, state: Arc) { +//! async fn get_user(Path(user_id): Path, state: Arc) { //! // ... //! } //! -//! async fn create_user(Json(payload): Json, state: Arc) { +//! async fn create_user(Json(payload): Json, state: Arc) { //! // ... //! } //! @@ -263,7 +298,7 @@ //! ``` //! //! The downside to this approach is that it's a little more verbose than using -//! extensions. +//! [`State`] or extensions. //! //! # Building integrations for axum //! @@ -350,6 +385,7 @@ //! [`Infallible`]: std::convert::Infallible //! [load shed]: tower::load_shed //! [`axum-core`]: http://crates.io/crates/axum-core +//! [`State`]: crate::extract::State #![warn( clippy::all, @@ -384,7 +420,7 @@ future_incompatible, nonstandard_style, missing_debug_implementations, - // missing_docs + missing_docs )] #![deny(unreachable_pub, private_in_public)] #![allow(elided_lifetimes_in_paths, clippy::type_complexity)] diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 6b3a051ca2..b7f927ee6c 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -547,6 +547,7 @@ impl MethodRouter { } } + /// Provide the state. pub fn with_state(self, state: S) -> MethodRouterWithState { MethodRouterWithState { method_router: self, @@ -616,6 +617,21 @@ where chained_handler_fn!(put, PUT); chained_handler_fn!(trace, TRACE); + /// Add a fallback [`Handler`] to the router. + pub fn fallback(self, handler: H) -> Self + where + H: Handler, + T: 'static, + S: Clone + Send + Sync + 'static, + { + self.fallback_service(IntoServiceStateInExtension::new(handler)) + } +} + +impl MethodRouter<(), B, Infallible> +where + B: Send + 'static, +{ /// Convert the handler into a [`MakeService`]. /// /// This allows you to serve a single handler if you don't need any routing: @@ -683,15 +699,6 @@ where pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { IntoMakeServiceWithConnectInfo::new(self) } - - pub fn fallback(self, handler: H) -> Self - where - H: Handler, - T: 'static, - S: Clone + Send + Sync + 'static, - { - self.fallback_service(IntoServiceStateInExtension::new(handler)) - } } impl MethodRouter { @@ -1084,15 +1091,40 @@ where } } +/// A [`MethodRouter`] which has access to some state. +/// +/// The state can be extracted with [`State`](crate::extract::State). +/// +/// Created with [`MethodRouter::with_state`] pub struct MethodRouterWithState { method_router: MethodRouter, state: S, } impl MethodRouterWithState { + /// Get a reference to the state. pub fn state(&self) -> &S { &self.state } + + /// Convert the handler into a [`MakeService`]. + /// + /// See [`MethodRouter::into_make_service`] for more details. + /// + /// [`MakeService`]: tower::make::MakeService + pub fn into_make_service(self) -> IntoMakeService { + IntoMakeService::new(self) + } + + /// Convert the router into a [`MakeService`] which stores information + /// about the incoming connection. + /// + /// See [`MethodRouter::into_make_service_with_connect_info`] for more details. + /// + /// [`MakeService`]: tower::make::MakeService + pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo { + IntoMakeServiceWithConnectInfo::new(self) + } } impl Clone for MethodRouterWithState diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 6aad74ea15..953ba0d061 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -132,6 +132,12 @@ where B: HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { + /// Create a new `Router` with the given state. + /// + /// See [`State`](crate::extract::State) for more details about accessing state. + /// + /// Unless you add additional routes this will respond with `404 Not Found` to + /// all requests. pub fn with_state(state: S) -> Self { Self { state, @@ -430,7 +436,7 @@ where T: 'static, { let state = self.state.clone(); - self.fallback_service(handler.into_service_with(state)) + self.fallback_service(handler.with_state(state)) } /// Add a fallback [`Service`] to the router. @@ -535,6 +541,7 @@ where } } + /// Get a reference to the state. pub fn state(&self) -> &S { &self.state } From 96a41f47357ffc7137ad716698099cb992dab08f Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Jul 2022 18:08:34 +0200 Subject: [PATCH 14/45] docs clean ups --- axum/src/extract/state.rs | 6 +++--- axum/src/handler/into_service.rs | 4 +++- axum/src/handler/mod.rs | 6 ++---- axum/src/handler/with_state.rs | 31 +++++++++++++++++++++++------- axum/src/routing/method_routing.rs | 2 ++ axum/src/routing/mod.rs | 4 ++-- 6 files changed, 36 insertions(+), 17 deletions(-) diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index f4a81317bd..da85eb09dc 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -9,7 +9,7 @@ use std::{ /// /// # Examples /// -/// ## With [`Router`] +/// ## With `Router` /// /// ``` /// use axum::{Router, routing::get, extract::State}; @@ -83,7 +83,7 @@ use std::{ /// # let _: Router = app; /// ``` /// -/// ## With [`MethodRouter`] +/// ## With `MethodRouter` /// /// ``` /// use axum::{routing::get, extract::State}; @@ -105,7 +105,7 @@ use std::{ /// # }; /// ``` /// -/// ## With [`Handler`] +/// ## With `Handler` /// /// ``` /// use axum::{routing::get, handler::Handler, extract::State}; diff --git a/axum/src/handler/into_service.rs b/axum/src/handler/into_service.rs index d59e3d869c..151395dd61 100644 --- a/axum/src/handler/into_service.rs +++ b/axum/src/handler/into_service.rs @@ -11,7 +11,9 @@ use tower_service::Service; /// An adapter that makes a [`Handler`] into a [`Service`]. /// -/// Created with [`Handler::into_service_with`]. +/// Created with [`HandlerWithoutStateExt::into_service`]. +/// +/// [`HandlerWithoutStateExt::into_service`]: super::HandlerWithoutStateExt::into_service pub struct IntoService { handler: H, state: S, diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index dc25100db6..a40c6de069 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -255,13 +255,11 @@ where /// [`MakeService`]: tower::make::MakeService pub trait HandlerWithoutStateExt: Handler { /// Convert the handler into a [`Service`] and no state. - /// - /// See [`WithState::into_service_with_state`] for more details. fn into_service(self) -> WithState; /// Convert the handler into a [`MakeService`] and no state. /// - /// See [`WithState::into_make_service_with_state`] for more details. + /// See [`WithState::into_make_service`] for more details. /// /// [`MakeService`]: tower::make::MakeService fn into_make_service(self) -> IntoMakeService>; @@ -269,7 +267,7 @@ pub trait HandlerWithoutStateExt: Handler { /// Convert the handler into a [`MakeService`] which stores information /// about the incoming connection and has no state. /// - /// See [`WithState::into_make_service_with_connect_info_and_state`] for more details. + /// See [`WithState::into_make_service_with_connect_info`] for more details. /// /// [`MakeService`]: tower::make::MakeService fn into_make_service_with_connect_info( diff --git a/axum/src/handler/with_state.rs b/axum/src/handler/with_state.rs index a338642ae2..73123ab11c 100644 --- a/axum/src/handler/with_state.rs +++ b/axum/src/handler/with_state.rs @@ -20,17 +20,26 @@ impl WithState { /// /// ```rust /// use axum::{ - /// Server, handler::Handler, http::{Uri, Method}, response::IntoResponse, + /// Server, + /// handler::Handler, + /// extract::State, + /// http::{Uri, Method}, + /// response::IntoResponse, /// }; /// use std::net::SocketAddr; /// - /// async fn handler(method: Method, uri: Uri, body: String) -> String { - /// format!("received `{} {}` with body `{:?}`", method, uri, body) + /// #[derive(Clone)] + /// struct AppState {} + /// + /// async fn handler(State(state): State) { + /// // ... /// } /// + /// let app = handler.with_state(AppState {}); + /// /// # async { /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) - /// .serve(handler.with_state(()).into_make_service()) + /// .serve(app.into_make_service()) /// .await?; /// # Ok::<_, hyper::Error>(()) /// # }; @@ -51,17 +60,25 @@ impl WithState { /// Server, /// handler::Handler, /// response::IntoResponse, - /// extract::ConnectInfo, + /// extract::{ConnectInfo, State}, /// }; /// use std::net::SocketAddr; /// - /// async fn handler(ConnectInfo(addr): ConnectInfo) -> String { + /// #[derive(Clone)] + /// struct AppState {}; + /// + /// async fn handler( + /// ConnectInfo(addr): ConnectInfo, + /// State(state): State, + /// ) -> String { /// format!("Hello {}", addr) /// } /// + /// let app = handler.with_state(AppState {}); + /// /// # async { /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000))) - /// .serve(handler.with_state(()).into_make_service_with_connect_info::()) + /// .serve(app.into_make_service_with_connect_info::()) /// .await?; /// # Ok::<_, hyper::Error>(()) /// # }; diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index b7f927ee6c..04a06f1aec 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1,3 +1,5 @@ +//! Route to services and handlers based on HTTP methods. + use super::IntoMakeService; use crate::{ body::{boxed, Body, Bytes, Empty, HttpBody}, diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 953ba0d061..42245608a3 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -26,10 +26,10 @@ use tower_layer::Layer; use tower_service::Service; pub mod future; +pub mod method_routing; mod into_make_service; mod method_filter; -mod method_routing; mod not_found; mod route; mod strip_prefix; @@ -43,7 +43,7 @@ pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, pub use self::method_routing::{ any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service, options, options_service, patch, patch_service, post, post_service, put, put_service, trace, - trace_service, MethodRouter, MethodRouterWithState, + trace_service, MethodRouter, }; #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] From 838201c9f0d56034e463cd31dbd3c7a5b86f73db Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Jul 2022 18:35:29 +0200 Subject: [PATCH 15/45] more consistent naming --- axum/src/handler/into_service.rs | 7 +++++++ axum/src/handler/with_state.rs | 11 ++++++++++- axum/src/routing/method_routing.rs | 18 ++++++++++-------- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/axum/src/handler/into_service.rs b/axum/src/handler/into_service.rs index 151395dd61..c8e51a243c 100644 --- a/axum/src/handler/into_service.rs +++ b/axum/src/handler/into_service.rs @@ -20,6 +20,13 @@ pub struct IntoService { _marker: PhantomData (T, B)>, } +impl IntoService { + /// Get a reference to the state. + pub fn state(&self) -> &S { + &self.state + } +} + #[test] fn traits() { use crate::test_helpers::*; diff --git a/axum/src/handler/with_state.rs b/axum/src/handler/with_state.rs index 73123ab11c..4afc9b106a 100644 --- a/axum/src/handler/with_state.rs +++ b/axum/src/handler/with_state.rs @@ -4,15 +4,24 @@ use http::Request; use std::task::{Context, Poll}; use tower_service::Service; -/// A [`Handler`] with state provided. +/// A [`Handler`] which has access to some state. /// /// Implements [`Service`]. /// +/// The state can be extracted with [`State`](crate::extract::State). +/// /// Created with [`Handler::with_state`]. pub struct WithState { pub(super) service: IntoService, } +impl WithState { + /// Get a reference to the state. + pub fn state(&self) -> &S { + self.service.state() + } +} + impl WithState { /// Convert the handler into a [`MakeService`]. /// diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 04a06f1aec..bedb9bb06a 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -550,8 +550,8 @@ impl MethodRouter { } /// Provide the state. - pub fn with_state(self, state: S) -> MethodRouterWithState { - MethodRouterWithState { + pub fn with_state(self, state: S) -> WithState { + WithState { method_router: self, state, } @@ -1095,15 +1095,17 @@ where /// A [`MethodRouter`] which has access to some state. /// +/// Implements [`Service`]. +/// /// The state can be extracted with [`State`](crate::extract::State). /// /// Created with [`MethodRouter::with_state`] -pub struct MethodRouterWithState { +pub struct WithState { method_router: MethodRouter, state: S, } -impl MethodRouterWithState { +impl WithState { /// Get a reference to the state. pub fn state(&self) -> &S { &self.state @@ -1129,7 +1131,7 @@ impl MethodRouterWithState { } } -impl Clone for MethodRouterWithState +impl Clone for WithState where S: Clone, { @@ -1141,19 +1143,19 @@ where } } -impl fmt::Debug for MethodRouterWithState +impl fmt::Debug for WithState where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("MethodRouterWithState") + f.debug_struct("WithState") .field("method_router", &self.method_router) .field("state", &self.state) .finish() } } -impl Service> for MethodRouterWithState +impl Service> for WithState where B: HttpBody, S: Clone + Send + Sync + 'static, From ea8160c9227ef187662517e0390ee51b7114de51 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Jul 2022 18:46:24 +0200 Subject: [PATCH 16/45] when does MethodRouter implement Service? --- axum/src/routing/method_routing.rs | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index bedb9bb06a..f94049aa43 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -484,6 +484,33 @@ where /// A [`Service`] that accepts requests based on a [`MethodFilter`] and /// allows chaining additional handlers and services. +/// +/// # When does `MethodRouter` implement [`Service`]? +/// +/// Whether or not `MethodRouter` implements [`Service`] depends on the state type it requires. +/// +/// ``` +/// use tower::Service; +/// use axum::{routing::get, extract::State, body::Body, http::Request}; +/// +/// // this `MethodRouter` doesn't require any state, i.e. the state is `()`, +/// let method_router = get(|| async {}); +/// // and thus it implements `Service` +/// assert_service(method_router); +/// +/// // this requires a `String` and doesn't implement `Service` +/// let method_router = get(|_: State| async {}); +/// // until you provide the `String` with `.with_state(...)` +/// let method_router_with_state = method_router.with_state(String::new()); +/// // and then it implements `Service` +/// assert_service(method_router_with_state); +/// +/// // helper to check that a value implements `Service` +/// fn assert_service(service: S) +/// where +/// S: Service>, +/// {} +/// ``` pub struct MethodRouter { get: Option>, head: Option>, From 0f8047936adf4c54958c2d89e67e9fcc31c4b87e Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Jul 2022 18:48:43 +0200 Subject: [PATCH 17/45] add missing docs --- axum/src/routing/method_routing.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index f94049aa43..f4720ea67e 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -577,6 +577,8 @@ impl MethodRouter { } /// Provide the state. + /// + /// See [`State`](crate::extract::State) for more details about accessing state. pub fn with_state(self, state: S) -> WithState { WithState { method_router: self, From 7de141c0cb8c4ee2e41766e45aa55d305cd29265 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Jul 2022 18:58:45 +0200 Subject: [PATCH 18/45] use `Router`'s default state type param --- axum-extra/src/body/async_read_body.rs | 2 +- axum-extra/src/extract/cookie/mod.rs | 2 +- axum-extra/src/extract/cookie/private.rs | 2 +- axum-extra/src/extract/cookie/signed.rs | 2 +- axum-extra/src/routing/mod.rs | 2 +- axum-extra/src/routing/resource.rs | 2 +- axum-extra/src/routing/spa.rs | 6 +++--- axum-extra/src/routing/typed.rs | 2 +- axum/src/extract/multipart.rs | 2 +- axum/src/middleware/from_fn.rs | 8 ++++---- axum/src/routing/tests/mod.rs | 4 ++-- examples/cors/src/main.rs | 2 +- 12 files changed, 18 insertions(+), 18 deletions(-) diff --git a/axum-extra/src/body/async_read_body.rs b/axum-extra/src/body/async_read_body.rs index 0b39ff8079..6e66f0dfda 100644 --- a/axum-extra/src/body/async_read_body.rs +++ b/axum-extra/src/body/async_read_body.rs @@ -42,7 +42,7 @@ pin_project! { /// } /// /// let app = Router::new().route("/Cargo.toml", get(cargo_toml)); - /// # let _: Router<()> = app; + /// # let _: Router = app; /// ``` #[cfg(feature = "async-read-body")] #[derive(Debug)] diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index 601721c7b5..6b6f43a808 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -80,7 +80,7 @@ pub use cookie_lib::Key; /// let app = Router::new() /// .route("/sessions", post(create_session)) /// .route("/me", get(me)); -/// # let app: Router<()> = app; +/// # let app: Router = app; /// ``` #[derive(Debug)] pub struct CookieJar { diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index 7ef300f918..832bf2400d 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -54,7 +54,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// .route("/get", get(get_secret)) /// // add extension with the key so `PrivateCookieJar` can access it /// .layer(Extension(key)); -/// # let app: Router<()> = app; +/// # let app: Router = app; /// ``` pub struct PrivateCookieJar { jar: cookie_lib::CookieJar, diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index 2a9a4f96d7..fce4f413e0 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -72,7 +72,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// .route("/me", get(me)) /// // add extension with the key so `SignedCookieJar` can access it /// .layer(Extension(key)); -/// # let app: Router<()> = app; +/// # let app: Router = app; /// ``` pub struct SignedCookieJar { jar: cookie_lib::CookieJar, diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 3ed31f7870..67da57f106 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -156,7 +156,7 @@ pub trait RouterExt: sealed::Sealed { /// .route_with_tsr("/foo", get(|| async {})) /// // `/bar` will redirect to `/bar/` /// .route_with_tsr("/bar/", get(|| async {})); - /// # let _: Router<()> = app; + /// # let _: Router = app; /// ``` fn route_with_tsr(self, path: &str, method_router: MethodRouter) -> Self where diff --git a/axum-extra/src/routing/resource.rs b/axum-extra/src/routing/resource.rs index 189fa89c3d..e85663a47d 100644 --- a/axum-extra/src/routing/resource.rs +++ b/axum-extra/src/routing/resource.rs @@ -45,7 +45,7 @@ use tower_service::Service; /// ); /// /// let app = Router::new().merge(users); -/// # let _: Router<()> = app; +/// # let _: Router = app; /// ``` #[derive(Debug)] pub struct Resource { diff --git a/axum-extra/src/routing/spa.rs b/axum-extra/src/routing/spa.rs index b5c883a321..73b10baf00 100644 --- a/axum-extra/src/routing/spa.rs +++ b/axum-extra/src/routing/spa.rs @@ -36,7 +36,7 @@ use tower_service::Service; /// .merge(spa) /// // we can still add other routes /// .route("/api/foo", get(api_foo)); -/// # let _: Router<()> = app; +/// # let _: Router = app; /// /// async fn api_foo() {} /// ``` @@ -101,7 +101,7 @@ impl SpaRouter { /// .index_file("another_file.html"); /// /// let app = Router::new().merge(spa); - /// # let _: Router<()> = app; + /// # let _: Router = app; /// ``` pub fn index_file

(mut self, path: P) -> Self where @@ -136,7 +136,7 @@ impl SpaRouter { /// } /// /// let app = Router::new().merge(spa); - /// # let _: Router<()> = app; + /// # let _: Router = app; /// ``` pub fn handle_error(self, f: F2) -> SpaRouter { SpaRouter { diff --git a/axum-extra/src/routing/typed.rs b/axum-extra/src/routing/typed.rs index c927a279ee..159472a063 100644 --- a/axum-extra/src/routing/typed.rs +++ b/axum-extra/src/routing/typed.rs @@ -60,7 +60,7 @@ use http::Uri; /// async fn users_destroy(_: UsersCollection) { /* ... */ } /// /// # -/// # let app: Router<()> = app; +/// # let app: Router = app; /// ``` /// /// # Using `#[derive(TypedPath)]` diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index afcbbf1aba..a3cb9ea3e0 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -180,7 +180,7 @@ impl<'a> Field<'a> { /// } /// /// let app = Router::new().route("/upload", post(upload)); - /// # let _: Router<()> = app; + /// # let _: Router = app; /// ``` pub async fn chunk(&mut self) -> Result, MultipartError> { self.inner diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index 9ec59ebc8f..a4a55bd046 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -61,7 +61,7 @@ use tower_service::Service; /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .route_layer(middleware::from_fn(auth)); -/// # let app: Router<()> = app; +/// # let app: Router = app; /// ``` /// /// # Running extractors @@ -92,7 +92,7 @@ use tower_service::Service; /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .route_layer(middleware::from_fn(my_middleware)); -/// # let app: Router<()> = app; +/// # let app: Router = app; /// ``` /// /// # Passing state @@ -127,7 +127,7 @@ use tower_service::Service; /// .route_layer(middleware::from_fn(move |req, next| { /// my_middleware(req, next, state.clone()) /// })); -/// # let app: Router<()> = app; +/// # let app: Router = app; /// ``` /// /// Or via extensions: @@ -164,7 +164,7 @@ use tower_service::Service; /// .layer(Extension(state)) /// .layer(middleware::from_fn(my_middleware)), /// ); -/// # let app: Router<()> = app; +/// # let app: Router = app; /// ``` /// /// [extractors]: crate::extract::FromRequest diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 4b92c7f94f..fda9abfcf3 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -100,7 +100,7 @@ async fn routing() { #[tokio::test] async fn router_type_doesnt_change() { - let app: Router<()> = Router::new() + let app: Router = Router::new() .route( "/", on(MethodFilter::GET, |_: Request| async { @@ -365,7 +365,7 @@ async fn wildcard_with_trailing_slash() { path: String, } - let app: Router<()> = Router::new().route( + let app: Router = Router::new().route( "/:user/:repo/tree/*path", get(|Path(tree): Path| async move { Json(tree) }), ); diff --git a/examples/cors/src/main.rs b/examples/cors/src/main.rs index d7ae530706..a8e02cc559 100644 --- a/examples/cors/src/main.rs +++ b/examples/cors/src/main.rs @@ -38,7 +38,7 @@ async fn main() { tokio::join!(frontend, backend); } -async fn serve(app: Router<()>, port: u16) { +async fn serve(app: Router, port: u16) { let addr = SocketAddr::from(([127, 0, 0, 1], port)); axum::Server::bind(&addr) .serve(app.into_make_service()) From 9fd0c5efb54699ef64be5f4c0250c93973541876 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 26 Jul 2022 19:09:46 +0200 Subject: [PATCH 19/45] changelog --- axum/CHANGELOG.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index a510d28222..95640a9446 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -21,12 +21,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Added `debug_handler` which is an attribute macro that improves type errors when applied to handler function. It is re-exported from `axum-macros` +- **added:** Added new type safe `State` extractor. This can be used with + `Router::with_state` and gives compile errors for missing states, whereas + `Extension` would result in runtime errors ([#1155]) +- **breaking:** The following types or traits have a new `S` type param + (`()` by default) which represents the state ([#1155]): + - `FromRequest` + - `RequestParts` + - `Router` + - `MethodRouter` + - `Handler` + - `Resource` +- **breaking:** `RequestParts::new` takes the state and the request rather than + just the request ([#1155]) +- **breaking:** `Router::route` now only accepts `MethodRouter`s created with + `get`, `post`, etc ([#1155]) +- **added:** `Router::route_service` for routing to arbitrary `Service`s ([#1155]) [#1077]: https://github.com/tokio-rs/axum/pull/1077 [#1088]: https://github.com/tokio-rs/axum/pull/1088 [#1102]: https://github.com/tokio-rs/axum/pull/1102 [#1119]: https://github.com/tokio-rs/axum/pull/1119 [#1130]: https://github.com/tokio-rs/axum/pull/1130 +[#1155]: https://github.com/tokio-rs/axum/pull/1155 [#924]: https://github.com/tokio-rs/axum/pull/924 # 0.5.10 (28. June, 2022) From 619f267c7e2939ddf6a2d94c585ca1aae1bdbd1d Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 27 Jul 2022 06:40:55 +0200 Subject: [PATCH 20/45] don't use default type param for FromRequest and RequestParts probably safer for library authors so you don't accidentally forget --- axum-core/src/extract/mod.rs | 4 ++-- .../tests/debug_handler/fail/argument_not_extractor.stderr | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index 2be92cd96e..b052b3867d 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -62,7 +62,7 @@ mod tuple; /// [`http::Request`]: http::Request /// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html #[async_trait] -pub trait FromRequest: Sized { +pub trait FromRequest: Sized { /// If the extractor fails it'll use this "rejection" type. A rejection is /// a kind of error that can be converted into a response. type Rejection: IntoResponse; @@ -75,7 +75,7 @@ pub trait FromRequest: Sized { /// /// Has several convenience methods for getting owned parts of the request. #[derive(Debug)] -pub struct RequestParts { +pub struct RequestParts { state: S, method: Method, uri: Uri, diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr index 468c4e17f0..420005970c 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr @@ -1,8 +1,8 @@ -error[E0277]: the trait bound `bool: FromRequest` is not satisfied +error[E0277]: the trait bound `bool: FromRequest` is not satisfied --> tests/debug_handler/fail/argument_not_extractor.rs:4:23 | 4 | async fn handler(foo: bool) {} - | ^^^^ the trait `FromRequest` is not implemented for `bool` + | ^^^^ the trait `FromRequest` is not implemented for `bool` | = help: the following other types implement trait `FromRequest`: <() as FromRequest> From 32ace9566f69c3433599f5ca381f609023555a37 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 11 Aug 2022 15:27:41 +0200 Subject: [PATCH 21/45] fix examples --- .../consume-body-in-extractor-or-middleware/src/main.rs | 6 +++--- examples/oauth/src/main.rs | 4 ++-- examples/sessions/src/main.rs | 4 ++-- examples/sqlx-postgres/src/main.rs | 4 ++-- examples/tls-rustls/src/main.rs | 2 +- examples/tokio-postgres/src/main.rs | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/consume-body-in-extractor-or-middleware/src/main.rs b/examples/consume-body-in-extractor-or-middleware/src/main.rs index 11b7a6b28f..6057963161 100644 --- a/examples/consume-body-in-extractor-or-middleware/src/main.rs +++ b/examples/consume-body-in-extractor-or-middleware/src/main.rs @@ -80,13 +80,13 @@ async fn handler(_: PrintRequestBody, body: Bytes) { struct PrintRequestBody; #[async_trait] -impl FromRequest for PrintRequestBody +impl FromRequest for PrintRequestBody where S: Send + Clone, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let state = req.state().clone(); let request = Request::from_request(req) @@ -95,7 +95,7 @@ where let request = buffer_request_body(request).await?; - *req = RequestParts::new(state, request); + *req = RequestParts::new(request, state); Ok(Self) } diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index 303ca1647f..ba87b1cc95 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -223,14 +223,14 @@ impl IntoResponse for AuthRedirect { } #[async_trait] -impl FromRequest for User +impl FromRequest for User where B: Send, { // If anything goes wrong or no session is found, redirect to the auth page type Rejection = AuthRedirect; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let store = req.state().clone().store; let cookies = TypedHeader::::from_request(req) diff --git a/examples/sessions/src/main.rs b/examples/sessions/src/main.rs index cd0d41a1f6..50c1feb76f 100644 --- a/examples/sessions/src/main.rs +++ b/examples/sessions/src/main.rs @@ -80,13 +80,13 @@ enum UserIdFromSession { } #[async_trait] -impl FromRequest for UserIdFromSession +impl FromRequest for UserIdFromSession where B: Send, { type Rejection = (StatusCode, &'static str); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let store = req.state().clone(); let cookie = req.extract::>>().await.unwrap(); diff --git a/examples/sqlx-postgres/src/main.rs b/examples/sqlx-postgres/src/main.rs index 6548cdeb97..c76444079d 100644 --- a/examples/sqlx-postgres/src/main.rs +++ b/examples/sqlx-postgres/src/main.rs @@ -75,13 +75,13 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(sqlx::pool::PoolConnection); #[async_trait] -impl FromRequest for DatabaseConnection +impl FromRequest for DatabaseConnection where B: Send, { type Rejection = (StatusCode, String); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let pool = req.state().clone(); let conn = pool.acquire().await.map_err(internal_error)?; diff --git a/examples/tls-rustls/src/main.rs b/examples/tls-rustls/src/main.rs index 40008a9631..6eec0e9c2a 100644 --- a/examples/tls-rustls/src/main.rs +++ b/examples/tls-rustls/src/main.rs @@ -6,7 +6,7 @@ use axum::{ extract::Host, - handler::Handler, + handler::HandlerWithoutStateExt, http::{StatusCode, Uri}, response::Redirect, routing::get, diff --git a/examples/tokio-postgres/src/main.rs b/examples/tokio-postgres/src/main.rs index e0c60453e3..9b2310342c 100644 --- a/examples/tokio-postgres/src/main.rs +++ b/examples/tokio-postgres/src/main.rs @@ -68,14 +68,14 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager>); #[async_trait] -impl FromRequest for DatabaseConnection +impl FromRequest for DatabaseConnection where B: Send, { type Rejection = (StatusCode, String); async fn from_request( - req: &mut RequestParts, + req: &mut RequestParts, ) -> Result { let pool = req.state().clone(); From dec3e46e334c976e08a8fe82aa15fa313bdc1cfc Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 11 Aug 2022 16:20:16 +0200 Subject: [PATCH 22/45] minor docs tweaks --- axum/src/extract/state.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index da85eb09dc..0d811959c4 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -93,7 +93,7 @@ use std::{ /// /// let state = AppState {}; /// -/// let app = get(handler) +/// let method_router_with_state = get(handler) /// // provide the state so the handler can access it /// .with_state(state); /// @@ -101,7 +101,7 @@ use std::{ /// // use `state`... /// } /// # async { -/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # axum::Server::bind(&"".parse().unwrap()).serve(method_router_with_state.into_make_service()).await.unwrap(); /// # }; /// ``` /// @@ -115,15 +115,16 @@ use std::{ /// /// let state = AppState {}; /// -/// let app = handler.with_state(state); -/// /// async fn handler(State(state): State) { /// // use `state`... /// } /// +/// // provide the state so the handler can access it +/// let handler_with_state = handler.with_state(state); +/// /// # async { /// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) -/// .serve(app.into_make_service()) +/// .serve(handler_with_state.into_make_service()) /// .await /// .expect("server failed"); /// # }; From 703d7382b2ff572185a1c665244aece8f7f102df Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 11 Aug 2022 16:41:29 +0200 Subject: [PATCH 23/45] clarify how to convert handlers into services --- axum/src/handler/mod.rs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index b4c7f4729a..4f1e848783 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -62,6 +62,38 @@ pub use self::{into_service::IntoService, with_state::WithState}; /// /// See the [module docs](crate::handler) for more details. /// +/// # Converting `Handler`s into [`Service`]s +/// +/// To convert `Handler`s into [`Service`]s you have to call either +/// [`HandlerWithoutStateExt::into_service`] or [`WithState::into_service`]: +/// +/// ``` +/// use tower::Service; +/// use axum::{ +/// extract::State, +/// body::Body, +/// http::Request, +/// handler::{HandlerWithoutStateExt, Handler}, +/// }; +/// +/// // this handler doesn't require any state +/// async fn one() {} +/// // so it can be converted to a service with `HandlerWithoutStateExt::into_service` +/// assert_service(one.into_service()); +/// +/// // this handler requires state +/// async fn two(_: State) {} +/// // so we have to provide it +/// let handler_with_state = two.with_state(String::new()); +/// // which gives us a `Service` +/// assert_service(handler_with_state); +/// +/// // helper to check that a value implements `Service` +/// fn assert_service(service: S) +/// where +/// S: Service>, +/// {} +/// ``` #[doc = include_str!("../docs/debugging_handler_type_errors.md")] pub trait Handler: Clone + Send + Sized + 'static { /// The type of future calling this handler returns. From 28df3c5e5b7a915f2101adf9a3b4d6e8b718cc04 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 11 Aug 2022 16:41:47 +0200 Subject: [PATCH 24/45] group methods in one impl block --- axum/src/routing/method_routing.rs | 112 +++++++++++++++-------------- 1 file changed, 57 insertions(+), 55 deletions(-) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index a12f4ddc2c..9c32103863 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -2,7 +2,7 @@ use super::IntoMakeService; use crate::{ - body::{boxed, Body, Bytes, Empty, HttpBody}, + body::{Body, Bytes, HttpBody}, error_handling::{HandleError, HandleErrorLayer}, extract::connect_info::IntoMakeServiceWithConnectInfo, handler::{Handler, IntoServiceStateInExtension}, @@ -82,6 +82,7 @@ macro_rules! top_level_service_fn { T: Service> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, + B: Send + 'static, { on_service(MethodFilter::$method, svc) } @@ -324,6 +325,7 @@ where T: Service> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, + B: Send + 'static, { MethodRouter::new().on_service(filter, svc) } @@ -386,6 +388,7 @@ where T: Service> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, + B: Send + 'static, { MethodRouter::new() .fallback_service(svc) @@ -564,58 +567,6 @@ impl fmt::Debug for MethodRouter { } } -impl MethodRouter { - /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all - /// requests. - pub fn new() -> Self { - let fallback = Route::new(service_fn(|_: Request| async { - let mut response = Response::new(boxed(Empty::new())); - *response.status_mut() = StatusCode::METHOD_NOT_ALLOWED; - Ok(response) - })); - - Self { - get: None, - head: None, - delete: None, - options: None, - patch: None, - post: None, - put: None, - trace: None, - allow_header: AllowHeader::None, - fallback: Fallback::Default(fallback), - _marker: PhantomData, - } - } - - /// Provide the state. - /// - /// See [`State`](crate::extract::State) for more details about accessing state. - pub fn with_state(self, state: S) -> WithState { - WithState { - method_router: self, - state, - } - } - - pub(crate) fn downcast_state(self) -> MethodRouter { - MethodRouter { - get: self.get, - head: self.head, - delete: self.delete, - options: self.options, - patch: self.patch, - post: self.post, - put: self.put, - trace: self.trace, - fallback: self.fallback, - allow_header: self.allow_header, - _marker: PhantomData, - } - } -} - impl MethodRouter where B: Send + 'static, @@ -746,7 +697,58 @@ where } } -impl MethodRouter { +impl MethodRouter +where + B: Send + 'static, +{ + /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all + /// requests. + pub fn new() -> Self { + let fallback = Route::new(service_fn(|_: Request| async { + Ok(StatusCode::METHOD_NOT_ALLOWED.into_response()) + })); + + Self { + get: None, + head: None, + delete: None, + options: None, + patch: None, + post: None, + put: None, + trace: None, + allow_header: AllowHeader::None, + fallback: Fallback::Default(fallback), + _marker: PhantomData, + } + } + + /// Provide the state. + /// + /// See [`State`](crate::extract::State) for more details about accessing state. + pub fn with_state(self, state: S) -> WithState { + WithState { + method_router: self, + state, + } + } + + pub(crate) fn downcast_state(self) -> MethodRouter { + MethodRouter { + get: self.get, + head: self.head, + delete: self.delete, + options: self.options, + patch: self.patch, + post: self.post, + put: self.put, + trace: self.trace, + fallback: self.fallback, + allow_header: self.allow_header, + _marker: PhantomData, + } + } + /// Chain an additional service that will accept requests matching the given /// `MethodFilter`. /// @@ -1073,7 +1075,7 @@ fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) { impl Service> for MethodRouter<(), B, E> where - B: HttpBody, + B: HttpBody + Send + 'static, { type Response = Response; type Error = E; From b1c5f658c0b4b313a1c3a89dd10b6ee44312e951 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 11 Aug 2022 16:42:01 +0200 Subject: [PATCH 25/45] make sure merged `MethodRouter`s can access state --- axum/src/routing/method_routing.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 9c32103863..9c2dbd293e 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1515,6 +1515,22 @@ mod tests { assert_eq!(text, "state"); } + #[tokio::test] + async fn merge_accessing_state() { + let one = get(|State(state): State<&'static str>| async move { state }); + let two = post(|State(state): State<&'static str>| async move { state }); + + let mut svc = one.merge(two).with_state("state"); + + let (status, _, text) = call(Method::GET, &mut svc).await; + assert_eq!(status, StatusCode::OK); + assert_eq!(text, "state"); + + let (status, _, _) = call(Method::POST, &mut svc).await; + assert_eq!(status, StatusCode::OK); + assert_eq!(text, "state"); + } + async fn call(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String) where S: Service, Error = Infallible>, From 3f94c1c7ae2e3fb497e00c9817b6792a439284ed Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 11 Aug 2022 16:57:09 +0200 Subject: [PATCH 26/45] fix docs link --- axum/src/handler/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 4f1e848783..a685cbd49d 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -65,7 +65,7 @@ pub use self::{into_service::IntoService, with_state::WithState}; /// # Converting `Handler`s into [`Service`]s /// /// To convert `Handler`s into [`Service`]s you have to call either -/// [`HandlerWithoutStateExt::into_service`] or [`WithState::into_service`]: +/// [`HandlerWithoutStateExt::into_service`] or [`Handler::with_state`]: /// /// ``` /// use tower::Service; From ec1cfba20aaa7aa15848ca398a1523b62ca375ae Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 11 Aug 2022 17:01:19 +0200 Subject: [PATCH 27/45] test merge with same state type --- axum/src/routing/tests/merge.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/axum/src/routing/tests/merge.rs b/axum/src/routing/tests/merge.rs index 0f4206bca4..5f67dcd784 100644 --- a/axum/src/routing/tests/merge.rs +++ b/axum/src/routing/tests/merge.rs @@ -452,3 +452,28 @@ async fn merging_routes_different_method_different_states() { let res = client.post("/").send().await; assert_eq!(res.text().await, "post state"); } + +#[tokio::test] +async fn merging_routes_different_paths_different_states() { + let foo = Router::with_state("foo state").route( + "/foo", + get(|State(state): State<&'static str>| async move { state }), + ); + + let bar = Router::with_state("bar state").route( + "/bar", + get(|State(state): State<&'static str>| async move { state }), + ); + + let app = Router::new().merge(foo).merge(bar); + + let client = TestClient::new(app); + + let res = client.get("/foo").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "foo state"); + + let res = client.get("/bar").send().await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "bar state"); +} From babb073a8bb2e48d25b3e55f8a9d7742a867e0ca Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 12 Aug 2022 11:54:27 +0200 Subject: [PATCH 28/45] Document how to access state from middleware --- axum/src/docs/middleware.md | 120 +++++++++++++++++++++++++++++++++++- axum/src/extract/state.rs | 5 ++ 2 files changed, 122 insertions(+), 3 deletions(-) diff --git a/axum/src/docs/middleware.md b/axum/src/docs/middleware.md index c7bd4710cb..8d0e970c4f 100644 --- a/axum/src/docs/middleware.md +++ b/axum/src/docs/middleware.md @@ -6,7 +6,8 @@ - [Ordering](#ordering) - [Writing middleware](#writing-middleware) - [Routing to services/middleware and backpressure](#routing-to-servicesmiddleware-and-backpressure) -- [Sharing state between handlers and middleware](#sharing-state-between-handlers-and-middleware) +- [Accessing state in middleware](#accessing-state-in-middleware) +- [Passing state from middleware to handlers](#passing-state-from-middleware-to-handlers) # Intro @@ -386,9 +387,119 @@ Also note that handlers created from async functions don't care about backpressure and are always ready. So if you're not using any Tower middleware you don't have to worry about any of this. -# Sharing state between handlers and middleware +# Accessing state in middleware -State can be shared between middleware and handlers using [request extensions]: +Handlers can access state using the [`State`] extractor but this isn't available +to middleware. Instead you have to pass the state directly to middleware using +either closure captures (for [`axum::middleware::from_fn`]) or regular struct +fields (if you're implementing a [`tower::Layer`]) + +## Accessing state in `axum::middleware::from_fn` + +```rust +use axum::{ + Router, + routing::get, + middleware::{self, Next}, + response::Response, + extract::State, + http::Request, +}; + +#[derive(Clone)] +struct AppState {} + +async fn my_middleware( + state: AppState, + req: Request, + next: Next, +) -> Response { + next.run(req).await +} + +async fn handler(_: State) {} + +let state = AppState {}; + +let app = Router::with_state(state.clone()) + .route("/", get(handler)) + .layer(middleware::from_fn(move |req, next| { + my_middleware(state.clone(), req, next) + })); +# let _: Router<_> = app; +``` + +## Accessing state in custom `tower::Layer`s + +```rust +use axum::{ + Router, + routing::get, + middleware::{self, Next}, + response::Response, + extract::State, + http::Request, +}; +use tower::{Layer, Service}; +use std::task::{Context, Poll}; + +#[derive(Clone)] +struct AppState {} + +#[derive(Clone)] +struct MyLayer { + state: AppState, +} + +impl Layer for MyLayer { + type Service = MyService; + + fn layer(&self, inner: S) -> Self::Service { + MyService { + inner, + state: self.state.clone(), + } + } +} + +#[derive(Clone)] +struct MyService { + inner: S, + state: AppState, +} + +impl Service> for MyService +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + // do something with `self.state` + + self.inner.call(req) + } +} + +async fn handler(_: State) {} + +let state = AppState {}; + +let app = Router::with_state(state.clone()) + .route("/", get(handler)) + .layer(MyLayer { state }); +# let _: Router<_> = app; +``` + +# Passing state from middleware to handlers + +State can be passed from middleware to handlers using [request extensions]: ```rust use axum::{ @@ -415,6 +526,8 @@ async fn auth(mut req: Request, next: Next) -> Result Date: Fri, 12 Aug 2022 12:08:30 +0200 Subject: [PATCH 29/45] Port cookie extractors to use state to extract keys (#1250) --- axum-extra/CHANGELOG.md | 2 ++ axum-extra/src/extract/cookie/mod.rs | 40 ++++++++++++++++----- axum-extra/src/extract/cookie/private.rs | 44 ++++++++++++++++-------- axum-extra/src/extract/cookie/signed.rs | 44 ++++++++++++++++-------- 4 files changed, 92 insertions(+), 38 deletions(-) diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 932c55f5a7..e490aa48f9 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning]. literal `Response` - **added:** Support chaining handlers with `HandlerCallWithExtractors::or` ([#1170]) - **change:** axum-extra's MSRV is now 1.60 ([#1239]) +- **breaking:** `SignedCookieJar` and `PrivateCookieJar` now extracts the keys + from the router's state, rather than extensions [#1086]: https://github.com/tokio-rs/axum/pull/1086 [#1119]: https://github.com/tokio-rs/axum/pull/1119 diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index c380f376d7..887495d7dc 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -208,7 +208,7 @@ fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) { #[cfg(test)] mod tests { use super::*; - use axum::{body::Body, http::Request, routing::get, Extension, Router}; + use axum::{body::Body, http::Request, routing::get, Router}; use tower::ServiceExt; macro_rules! cookie_test { @@ -227,12 +227,15 @@ mod tests { jar.remove(Cookie::named("key")) } - let app = Router::<_, Body>::new() + let state = AppState { + key: Key::generate(), + custom_key: CustomKey(Key::generate()), + }; + + let app = Router::<_, Body>::with_state(state) .route("/set", get(set_cookie)) .route("/get", get(get_cookie)) - .route("/remove", get(remove_cookie)) - .layer(Extension(Key::generate())) - .layer(Extension(CustomKey(Key::generate()))); + .route("/remove", get(remove_cookie)); let res = app .clone() @@ -280,6 +283,24 @@ mod tests { cookie_test!(private_cookies, PrivateCookieJar); cookie_test!(private_cookies_with_custom_key, PrivateCookieJar); + #[derive(Clone)] + struct AppState { + key: Key, + custom_key: CustomKey, + } + + impl From for Key { + fn from(state: AppState) -> Key { + state.key + } + } + + impl From for CustomKey { + fn from(state: AppState) -> CustomKey { + state.custom_key + } + } + #[derive(Clone)] struct CustomKey(Key); @@ -295,9 +316,12 @@ mod tests { format!("{:?}", jar.get("key")) } - let app = Router::<_, Body>::new() - .route("/get", get(get_cookie)) - .layer(Extension(Key::generate())); + let state = AppState { + key: Key::generate(), + custom_key: CustomKey(Key::generate()), + }; + + let app = Router::<_, Body>::with_state(state).route("/get", get(get_cookie)); let res = app .clone() diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index d2bff242ac..4c77e14a35 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -3,7 +3,6 @@ use axum::{ async_trait, extract::{FromRequest, RequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, - Extension, }; use cookie::PrivateJar; use std::{convert::Infallible, fmt, marker::PhantomData}; @@ -22,7 +21,6 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// ```rust /// use axum::{ /// Router, -/// Extension, /// routing::{post, get}, /// extract::TypedHeader, /// response::{IntoResponse, Redirect}, @@ -44,22 +42,36 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// } /// } /// -/// // Generate a secure key -/// // -/// // You probably don't wanna generate a new one each time the app starts though -/// let key = Key::generate(); +/// // our application state +/// #[derive(Clone)] +/// struct AppState { +/// // that holds the key used to sign cookies +/// key: Key, +/// } +/// +/// // this impl tells `SignedCookieJar` how to access the key from our state +/// impl From for Key { +/// fn from(state: AppState) -> Self { +/// state.key +/// } +/// } +/// +/// let state = AppState { +/// // Generate a secure key +/// // +/// // You probably don't wanna generate a new one each time the app starts though +/// key: Key::generate(), +/// }; /// -/// let app = Router::new() +/// let app = Router::with_state(state) /// .route("/set", post(set_secret)) -/// .route("/get", get(get_secret)) -/// // add extension with the key so `PrivateCookieJar` can access it -/// .layer(Extension(key)); -/// # let app: Router = app; +/// .route("/get", get(get_secret)); +/// # let app: Router<_> = app; /// ``` pub struct PrivateCookieJar { jar: cookie::CookieJar, key: Key, - // The key used to extract the key extension. Allows users to use multiple keys for different + // The key used to extract the key. Allows users to use multiple keys for different // jars. Maybe a library wants its own key. _marker: PhantomData, } @@ -77,13 +89,15 @@ impl fmt::Debug for PrivateCookieJar { impl FromRequest for PrivateCookieJar where B: Send, - S: Send, + S: Into + Clone + Send, K: Into + Clone + Send + Sync + 'static, { - type Rejection = as FromRequest>::Rejection; + type Rejection = Infallible; async fn from_request(req: &mut RequestParts) -> Result { - let key = Extension::::from_request(req).await?.0.into(); + let state = req.state().clone(); + let key: K = state.into(); + let key: Key = key.into(); let mut jar = cookie::CookieJar::new(); let mut private_jar = jar.private_mut(&key); diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index 8279141393..dce9422453 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -3,7 +3,6 @@ use axum::{ async_trait, extract::{FromRequest, RequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, - Extension, }; use cookie::SignedJar; use cookie::{Cookie, Key}; @@ -23,7 +22,6 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// ```rust /// use axum::{ /// Router, -/// Extension, /// routing::{post, get}, /// extract::TypedHeader, /// response::{IntoResponse, Redirect}, @@ -62,22 +60,36 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// # todo!() /// } /// -/// // Generate a secure key -/// // -/// // You probably don't wanna generate a new one each time the app starts though -/// let key = Key::generate(); +/// // our application state +/// #[derive(Clone)] +/// struct AppState { +/// // that holds the key used to sign cookies +/// key: Key, +/// } +/// +/// // this impl tells `SignedCookieJar` how to access the key from our state +/// impl From for Key { +/// fn from(state: AppState) -> Self { +/// state.key +/// } +/// } +/// +/// let state = AppState { +/// // Generate a secure key +/// // +/// // You probably don't wanna generate a new one each time the app starts though +/// key: Key::generate(), +/// }; /// -/// let app = Router::new() +/// let app = Router::with_state(state) /// .route("/sessions", post(create_session)) -/// .route("/me", get(me)) -/// // add extension with the key so `SignedCookieJar` can access it -/// .layer(Extension(key)); -/// # let app: Router = app; +/// .route("/me", get(me)); +/// # let app: Router<_> = app; /// ``` pub struct SignedCookieJar { jar: cookie::CookieJar, key: Key, - // The key used to extract the key extension. Allows users to use multiple keys for different + // The key used to extract the key. Allows users to use multiple keys for different // jars. Maybe a library wants its own key. _marker: PhantomData, } @@ -95,13 +107,15 @@ impl fmt::Debug for SignedCookieJar { impl FromRequest for SignedCookieJar where B: Send, - S: Send, + S: Into + Clone + Send, K: Into + Clone + Send + Sync + 'static, { - type Rejection = as FromRequest>::Rejection; + type Rejection = Infallible; async fn from_request(req: &mut RequestParts) -> Result { - let key = Extension::::from_request(req).await?.0.into(); + let state = req.state().clone(); + let key: K = state.into(); + let key: Key = key.into(); let mut jar = cookie::CookieJar::new(); let mut signed_jar = jar.signed_mut(&key); From 47d8ef56173ee547bf6a431d9757ea143191554c Mon Sep 17 00:00:00 2001 From: Dani Pardo Date: Fri, 12 Aug 2022 12:02:52 +0200 Subject: [PATCH 30/45] Updates ECOSYSTEM with a new sample project (#1252) --- ECOSYSTEM.md | 1 + 1 file changed, 1 insertion(+) diff --git a/ECOSYSTEM.md b/ECOSYSTEM.md index b3f4c43319..2c74ef0a61 100644 --- a/ECOSYSTEM.md +++ b/ECOSYSTEM.md @@ -42,6 +42,7 @@ If your project isn't listed here and you would like it to be, please feel free - [sandbox_axum_observability](https://github.com/davidB/sandbox_axum_observability) A Sandbox/showcase project to experiment axum and observability (tracing, opentelemetry, jaeger, grafana tempo,...) - [axum_admin](https://github.com/lingdu1234/axum_admin): An admin panel built with **axum**, Sea-orm and Vue 3. - [rgit](https://git.inept.dev/~doyle/rgit.git/about): A blazingly fast Git repository browser, compatible with- and heavily inspired by cgit. +- [Petclinic](https://github.com/danipardo/petclinic): A port of Spring Framework's Petclinic showcase project to Axum [Realworld]: https://github.com/gothinkster/realworld [SQLx]: https://github.com/launchbadge/sqlx From e4b2075e3ef833b6f10c2ad2fa736bda14338b3c Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Fri, 12 Aug 2022 12:06:38 +0200 Subject: [PATCH 31/45] Avoid unhelpful compiler suggestion (#1251) --- axum-macros/src/debug_handler.rs | 33 +++++++++++-------- .../tests/debug_handler/fail/generics.rs | 2 +- .../tests/debug_handler/fail/generics.stderr | 10 +----- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 468d06cdaa..81b9c4439d 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -7,10 +7,26 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { let check_request_last_extractor = check_request_last_extractor(&item_fn); let check_path_extractor = check_path_extractor(&item_fn); let check_multiple_body_extractors = check_multiple_body_extractors(&item_fn); - - let check_inputs_impls_from_request = check_inputs_impls_from_request(&item_fn, &attr.body_ty); let check_output_impls_into_response = check_output_impls_into_response(&item_fn); - let check_future_send = check_future_send(&item_fn); + + // If the function is generic, we can't reliably check its inputs or whether the future it + // returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors. + let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() { + let check_inputs_impls_from_request = + check_inputs_impls_from_request(&item_fn, &attr.body_ty); + let check_future_send = check_future_send(&item_fn); + + quote! { + #check_inputs_impls_from_request + #check_future_send + } + } else { + syn::Error::new_spanned( + &item_fn.sig.generics, + "`#[axum_macros::debug_handler]` doesn't support generic functions", + ) + .into_compile_error() + }; quote! { #item_fn @@ -18,9 +34,8 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { #check_request_last_extractor #check_path_extractor #check_multiple_body_extractors - #check_inputs_impls_from_request #check_output_impls_into_response - #check_future_send + #check_inputs_and_future_send } } @@ -153,14 +168,6 @@ fn check_multiple_body_extractors(item_fn: &ItemFn) -> TokenStream { } fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStream { - if !item_fn.sig.generics.params.is_empty() { - return syn::Error::new_spanned( - &item_fn.sig.generics, - "`#[axum_macros::debug_handler]` doesn't support generic functions", - ) - .into_compile_error(); - } - item_fn .sig .inputs diff --git a/axum-macros/tests/debug_handler/fail/generics.rs b/axum-macros/tests/debug_handler/fail/generics.rs index 310de31867..dd15076761 100644 --- a/axum-macros/tests/debug_handler/fail/generics.rs +++ b/axum-macros/tests/debug_handler/fail/generics.rs @@ -1,6 +1,6 @@ use axum_macros::debug_handler; #[debug_handler] -async fn handler() {} +async fn handler(extract: T) {} fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/generics.stderr b/axum-macros/tests/debug_handler/fail/generics.stderr index 52b705983e..4a96a0e3cd 100644 --- a/axum-macros/tests/debug_handler/fail/generics.stderr +++ b/axum-macros/tests/debug_handler/fail/generics.stderr @@ -1,13 +1,5 @@ error: `#[axum_macros::debug_handler]` doesn't support generic functions --> tests/debug_handler/fail/generics.rs:4:17 | -4 | async fn handler() {} +4 | async fn handler(extract: T) {} | ^^^ - -error[E0282]: type annotations needed - --> tests/debug_handler/fail/generics.rs:4:10 - | -4 | async fn handler() {} - | ----- ^^^^^^^ cannot infer type for type parameter `T` declared on the function `handler` - | | - | consider giving `future` a type From a489cf7d79dcd32d3b669fe519af7016cc4bf3f8 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 12 Aug 2022 12:19:51 +0200 Subject: [PATCH 32/45] fix docs typo --- axum/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axum/src/lib.rs b/axum/src/lib.rs index 903f6f0a2f..a9b1bfb144 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -201,7 +201,7 @@ //! # }; //! ``` //! -//! You should prefer using [`State`] if possible since its more type safe. The downside is that +//! You should prefer using [`State`] if possible since it's more type safe. The downside is that //! its less dynamic than request extensions. //! //! See [`State`] for more details about accessing state. From 1ed3e0671182454eeb25856668081003935b3c10 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 15 Aug 2022 15:41:08 +0200 Subject: [PATCH 33/45] document how library authors should access state --- axum/src/extract/state.rs | 121 +++++++++++++++++++++++++------------- 1 file changed, 80 insertions(+), 41 deletions(-) diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index 66b5becf05..6e9aa224b3 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -12,9 +12,7 @@ use std::{ /// /// [state-from-middleware]: ../middleware/index.html#accessing-state-in-middleware /// -/// # Examples -/// -/// ## With `Router` +/// # With `Router` /// /// ``` /// use axum::{Router, routing::get, extract::State}; @@ -41,7 +39,54 @@ use std::{ /// # let _: Router = app; /// ``` /// -/// ### Substates +/// # With `MethodRouter` +/// +/// ``` +/// use axum::{routing::get, extract::State}; +/// +/// #[derive(Clone)] +/// struct AppState {} +/// +/// let state = AppState {}; +/// +/// let method_router_with_state = get(handler) +/// // provide the state so the handler can access it +/// .with_state(state); +/// +/// async fn handler(State(state): State) { +/// // use `state`... +/// } +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(method_router_with_state.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// # With `Handler` +/// +/// ``` +/// use axum::{routing::get, handler::Handler, extract::State}; +/// +/// #[derive(Clone)] +/// struct AppState {} +/// +/// let state = AppState {}; +/// +/// async fn handler(State(state): State) { +/// // use `state`... +/// } +/// +/// // provide the state so the handler can access it +/// let handler_with_state = handler.with_state(state); +/// +/// # async { +/// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) +/// .serve(handler_with_state.into_make_service()) +/// .await +/// .expect("server failed"); +/// # }; +/// ``` +/// +/// # Substates /// /// [`State`] only allows a single state type but you can use [`From`] to extract "substates": /// @@ -88,52 +133,46 @@ use std::{ /// # let _: Router = app; /// ``` /// -/// ## With `MethodRouter` +/// # For library authors /// -/// ``` -/// use axum::{routing::get, extract::State}; -/// -/// #[derive(Clone)] -/// struct AppState {} -/// -/// let state = AppState {}; -/// -/// let method_router_with_state = get(handler) -/// // provide the state so the handler can access it -/// .with_state(state); -/// -/// async fn handler(State(state): State) { -/// // use `state`... -/// } -/// # async { -/// # axum::Server::bind(&"".parse().unwrap()).serve(method_router_with_state.into_make_service()).await.unwrap(); -/// # }; -/// ``` +/// If you're writing a library that has an extractor that needs state, this is the recommended way +/// to do it: /// -/// ## With `Handler` +/// ```rust +/// use axum_core::extract::{FromRequest, RequestParts}; +/// use async_trait::async_trait; +/// use std::convert::Infallible; /// -/// ``` -/// use axum::{routing::get, handler::Handler, extract::State}; +/// // the extractor your library provides +/// struct MyLibraryExtractor; /// -/// #[derive(Clone)] -/// struct AppState {} +/// #[async_trait] +/// impl FromRequest for MyLibraryExtractor +/// where +/// B: Send, +/// // keep `S` generic but require that it can produce a `MyLibraryState` +/// // this means users will have to implement `From for MyLibraryState` +/// S: Into + Clone + Send, +/// { +/// type Rejection = Infallible; /// -/// let state = AppState {}; +/// async fn from_request(req: &mut RequestParts) -> Result { +/// // get a `MyLibraryState` from the shared application state +/// let state: MyLibraryState = req.state().clone().into(); /// -/// async fn handler(State(state): State) { -/// // use `state`... +/// // ... +/// # todo!() +/// } /// } /// -/// // provide the state so the handler can access it -/// let handler_with_state = handler.with_state(state); -/// -/// # async { -/// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) -/// .serve(handler_with_state.into_make_service()) -/// .await -/// .expect("server failed"); -/// # }; +/// // the state your library needs +/// struct MyLibraryState { +/// // ... +/// } /// ``` +/// +/// Note that you don't need to use the `State` extractor since you can access the state directly +/// from [`RequestParts`]. #[derive(Debug, Default, Clone, Copy)] pub struct State(pub S); From 94f050f03e99f2959334c02a014e5259d9bfbb91 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 16 Aug 2022 14:36:32 +0200 Subject: [PATCH 34/45] Add `RequestParts::with_state` --- axum-core/src/extract/mod.rs | 15 ++++++++++++++- axum-extra/src/extract/cached.rs | 2 +- axum-extra/src/handler/mod.rs | 2 +- axum-extra/src/handler/or.rs | 2 +- axum/CHANGELOG.md | 2 -- axum/src/error_handling/mod.rs | 2 +- axum/src/extract/query.rs | 2 +- axum/src/form.rs | 6 +++--- axum/src/handler/mod.rs | 2 +- axum/src/middleware/from_extractor.rs | 2 +- axum/src/middleware/from_fn.rs | 2 +- 11 files changed, 25 insertions(+), 14 deletions(-) diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index 7bcc8ea982..8622a87b23 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -85,6 +85,19 @@ pub struct RequestParts { body: Option, } +impl RequestParts { + /// Create a new `RequestParts` without any state. + /// + /// You generally shouldn't need to construct this type yourself, unless + /// using extractors outside of axum for example to implement a + /// [`tower::Service`]. + /// + /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html + pub fn new(req: Request) -> Self { + Self::with_state(req, ()) + } +} + impl RequestParts { /// Create a new `RequestParts` with the given state. /// @@ -93,7 +106,7 @@ impl RequestParts { /// [`tower::Service`]. /// /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html - pub fn new(req: Request, state: S) -> Self { + pub fn with_state(req: Request, state: S) -> Self { let ( http::request::Parts { method, diff --git a/axum-extra/src/extract/cached.rs b/axum-extra/src/extract/cached.rs index 2d7efafaa9..29d8f8d779 100644 --- a/axum-extra/src/extract/cached.rs +++ b/axum-extra/src/extract/cached.rs @@ -155,7 +155,7 @@ mod tests { } } - let mut req = RequestParts::new(Request::new(()), ()); + let mut req = RequestParts::with_state(Request::new(()), ()); let first = Cached::::from_request(&mut req).await.unwrap().0; assert_eq!(COUNTER.load(Ordering::SeqCst), 1); diff --git a/axum-extra/src/handler/mod.rs b/axum-extra/src/handler/mod.rs index 15c59d5386..31d464b548 100644 --- a/axum-extra/src/handler/mod.rs +++ b/axum-extra/src/handler/mod.rs @@ -178,7 +178,7 @@ where fn call(self, state: S, req: http::Request) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::new(req, state.clone()); + let mut req = RequestParts::with_state(req, state.clone()); match req.extract::().await { Ok(t) => self.handler.call(state, t).await, Err(rejection) => rejection.into_response(), diff --git a/axum-extra/src/handler/or.rs b/axum-extra/src/handler/or.rs index 518df1200b..df84e9a45e 100644 --- a/axum-extra/src/handler/or.rs +++ b/axum-extra/src/handler/or.rs @@ -71,7 +71,7 @@ where fn call(self, state: S, req: Request) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::new(req, state.clone()); + let mut req = RequestParts::with_state(req, state.clone()); if let Ok(lt) = req.extract::().await { return self.lhs.call(state, lt).await; diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 5a1a21aaad..92f40e66a5 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -46,8 +46,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `MethodRouter` - `Handler` - `Resource` -- **breaking:** `RequestParts::new` takes the state and the request rather than - just the request ([#1155]) - **breaking:** `Router::route` now only accepts `MethodRouter`s created with `get`, `post`, etc ([#1155]) - **added:** `Router::route_service` for routing to arbitrary `Service`s ([#1155]) diff --git a/axum/src/error_handling/mod.rs b/axum/src/error_handling/mod.rs index 9a887b511b..de0ee8bc02 100644 --- a/axum/src/error_handling/mod.rs +++ b/axum/src/error_handling/mod.rs @@ -181,7 +181,7 @@ macro_rules! impl_service { let inner = std::mem::replace(&mut self.inner, clone); let future = Box::pin(async move { - let mut req = RequestParts::new(req, ()); + let mut req = RequestParts::new(req); $( let $ty = match $ty::from_request(&mut req).await { diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index abd8c6d2b9..1050e6be68 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -83,7 +83,7 @@ mod tests { async fn check(uri: impl AsRef, value: T) { let req = Request::builder().uri(uri.as_ref()).body(()).unwrap(); - let mut req = RequestParts::new(req, ()); + let mut req = RequestParts::new(req); assert_eq!(Query::::from_request(&mut req).await.unwrap().0, value); } diff --git a/axum/src/form.rs b/axum/src/form.rs index f75584edf2..267cd608c2 100644 --- a/axum/src/form.rs +++ b/axum/src/form.rs @@ -130,7 +130,7 @@ mod tests { .uri(uri.as_ref()) .body(Empty::::new()) .unwrap(); - let mut req = RequestParts::new(req, ()); + let mut req = RequestParts::new(req); assert_eq!(Form::::from_request(&mut req).await.unwrap().0, value); } @@ -146,7 +146,7 @@ mod tests { serde_urlencoded::to_string(&value).unwrap().into(), )) .unwrap(); - let mut req = RequestParts::new(req, ()); + let mut req = RequestParts::new(req); assert_eq!(Form::::from_request(&mut req).await.unwrap().0, value); } @@ -216,7 +216,7 @@ mod tests { .into(), )) .unwrap(); - let mut req = RequestParts::new(req, ()); + let mut req = RequestParts::new(req); assert!(matches!( Form::::from_request(&mut req) .await diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index a685cbd49d..197bd6540f 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -187,7 +187,7 @@ macro_rules! impl_handler { fn call(self, state: S, req: Request) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::new(req, state); + let mut req = RequestParts::with_state(req, state); $( let $ty = match $ty::from_request(&mut req).await { diff --git a/axum/src/middleware/from_extractor.rs b/axum/src/middleware/from_extractor.rs index 49d523f934..4de38aa6ff 100644 --- a/axum/src/middleware/from_extractor.rs +++ b/axum/src/middleware/from_extractor.rs @@ -185,7 +185,7 @@ where fn call(&mut self, req: Request) -> Self::Future { let extract_future = Box::pin(async move { - let mut req = RequestParts::new(req, ()); + let mut req = RequestParts::new(req); let extracted = E::from_request(&mut req).await; (req, extracted) }); diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index 29256cbd21..4e26a760b9 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -280,7 +280,7 @@ macro_rules! impl_service { let mut f = self.f.clone(); let future = Box::pin(async move { - let mut parts = RequestParts::new(req, ()); + let mut parts = RequestParts::new(req); $( let $ty = match $ty::from_request(&mut parts).await { Ok(value) => value, From aeab0c46a3c366f212fc523ed3708a1a9206a78e Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 17 Aug 2022 09:52:56 +0200 Subject: [PATCH 35/45] fix example --- examples/consume-body-in-extractor-or-middleware/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/consume-body-in-extractor-or-middleware/src/main.rs b/examples/consume-body-in-extractor-or-middleware/src/main.rs index 6057963161..0c82440f9c 100644 --- a/examples/consume-body-in-extractor-or-middleware/src/main.rs +++ b/examples/consume-body-in-extractor-or-middleware/src/main.rs @@ -95,7 +95,7 @@ where let request = buffer_request_body(request).await?; - *req = RequestParts::new(request, state); + *req = RequestParts::with_state(request, state); Ok(Self) } From 02bc5bab238f8aa1196e9c3a901e6aa6ab97a680 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 17 Aug 2022 13:56:17 +0200 Subject: [PATCH 36/45] apply suggestions from review --- axum-extra/src/extract/cached.rs | 2 +- axum/src/handler/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/axum-extra/src/extract/cached.rs b/axum-extra/src/extract/cached.rs index 29d8f8d779..aa303aca18 100644 --- a/axum-extra/src/extract/cached.rs +++ b/axum-extra/src/extract/cached.rs @@ -155,7 +155,7 @@ mod tests { } } - let mut req = RequestParts::with_state(Request::new(()), ()); + let mut req = RequestParts::new(Request::new(())); let first = Cached::::from_request(&mut req).await.unwrap().0; assert_eq!(COUNTER.load(Ordering::SeqCst), 1); diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 197bd6540f..4d5e922066 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -277,7 +277,7 @@ where } } -/// Extension trait for [`Handler`]s who doesn't have state. +/// Extension trait for [`Handler`]s that don't have state. /// /// This provides convenience methods to convert the [`Handler`] into a [`Service`] or [`MakeService`]. /// From 9f63017c829d5e8a76f8dccc3dddd4f12b6671df Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 17 Aug 2022 13:58:53 +0200 Subject: [PATCH 37/45] add relevant changes to axum-extra and axum-core changelogs --- axum-core/CHANGELOG.md | 5 ++++- axum-extra/CHANGELOG.md | 2 ++ axum/CHANGELOG.md | 1 - 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/axum-core/CHANGELOG.md b/axum-core/CHANGELOG.md index da272e5f8a..a3cbc1fb32 100644 --- a/axum-core/CHANGELOG.md +++ b/axum-core/CHANGELOG.md @@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **breaking:** `FromRequest` and `RequestParts` has a new `S` type param which + represents the state ([#1155]) + +[#1155]: https://github.com/tokio-rs/axum/pull/1155 # 0.2.6 (18. June, 2022) diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index ef3a781c3c..54af89d02d 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -25,9 +25,11 @@ and this project adheres to [Semantic Versioning]. - **added:** `WithRejection` extractor for customizing other extractors' rejections ([#1262]) - **added:** Add sync constructors to `CookieJar`, `PrivateCookieJar`, and `SignedCookieJar` so they're easier to use in custom middleware +- **breaking:** `Resource` has a new `S` type param which represents the state ([#1155]) [#1086]: https://github.com/tokio-rs/axum/pull/1086 [#1119]: https://github.com/tokio-rs/axum/pull/1119 +[#1155]: https://github.com/tokio-rs/axum/pull/1155 [#1170]: https://github.com/tokio-rs/axum/pull/1170 [#1214]: https://github.com/tokio-rs/axum/pull/1214 [#1239]: https://github.com/tokio-rs/axum/pull/1239 diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 92f40e66a5..8321cc98ec 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -45,7 +45,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Router` - `MethodRouter` - `Handler` - - `Resource` - **breaking:** `Router::route` now only accepts `MethodRouter`s created with `get`, `post`, etc ([#1155]) - **added:** `Router::route_service` for routing to arbitrary `Service`s ([#1155]) From c1074b7056679a135601d0baff0e63e6dab70a69 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 17 Aug 2022 14:04:28 +0200 Subject: [PATCH 38/45] Add `route_service_with_tsr` --- axum-extra/CHANGELOG.md | 2 ++ axum-extra/src/routing/mod.rs | 37 +++++++++++++++++++++++++++++++++-- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 54af89d02d..8d6568f997 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning]. - **added:** Add sync constructors to `CookieJar`, `PrivateCookieJar`, and `SignedCookieJar` so they're easier to use in custom middleware - **breaking:** `Resource` has a new `S` type param which represents the state ([#1155]) +- **breaking:** `RouterExt::route_with_tsr` now only accepts `MethodRouter`s ([#1155]) +- **added:** `RouterExt::route_service_with_tsr` for routing to any `Service` ([#1155]) [#1086]: https://github.com/tokio-rs/axum/pull/1086 [#1119]: https://github.com/tokio-rs/axum/pull/1119 diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 67da57f106..e90578cd9e 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -2,11 +2,13 @@ use axum::{ handler::Handler, - response::Redirect, + http::Request, + response::{IntoResponse, Redirect}, routing::{any, MethodRouter}, Router, }; -use std::future::ready; +use std::{convert::Infallible, future::ready}; +use tower_service::Service; mod resource; @@ -161,6 +163,16 @@ pub trait RouterExt: sealed::Sealed { fn route_with_tsr(self, path: &str, method_router: MethodRouter) -> Self where Self: Sized; + + /// Add another route to the router with an additional "trailing slash redirect" route. + /// + /// This works like [`RouterExt::route_with_tsr`] but accepts any [`Service`]. + fn route_service_with_tsr(self, path: &str, service: T) -> Self + where + T: Service, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, + T::Future: Send + 'static, + Self: Sized; } impl RouterExt for Router @@ -265,6 +277,27 @@ where self.route(&format!("{}/", path), any(move || ready(redirect.clone()))) } } + + fn route_service_with_tsr(mut self, path: &str, service: T) -> Self + where + T: Service, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, + T::Future: Send + 'static, + Self: Sized, + { + self = self.route_service(path, service); + + let redirect = Redirect::permanent(path); + + if let Some(path_without_trailing_slash) = path.strip_suffix('/') { + self.route( + path_without_trailing_slash, + any(move || ready(redirect.clone())), + ) + } else { + self.route(&format!("{}/", path), any(move || ready(redirect.clone()))) + } + } } mod sealed { From 0668fd4a3888c16b5d351cd4c2638df52e7e7fce Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 17 Aug 2022 14:16:38 +0200 Subject: [PATCH 39/45] fix trybuild expectations --- .../from_request/fail/double_via_attr.rs | 3 +-- .../from_request/fail/double_via_attr.stderr | 14 +++------- .../from_request/fail/generic_without_via.rs | 4 +-- .../fail/generic_without_via.stderr | 27 ++++--------------- .../fail/generic_without_via_rejection.rs | 4 +-- .../fail/generic_without_via_rejection.stderr | 27 ++++--------------- .../generic_without_via_rejection_derive.rs | 4 +-- ...eneric_without_via_rejection_derive.stderr | 27 ++++--------------- .../fail/rejection_derive_and_via.rs | 3 +-- .../fail/rejection_derive_and_via.stderr | 12 ++------- .../fail/via_and_rejection_derive.rs | 3 +-- .../fail/via_and_rejection_derive.stderr | 14 +++------- .../fail/via_on_container_and_field.rs | 5 ++-- .../fail/via_on_container_and_field.stderr | 12 ++------- .../tests/from_request/pass/tuple_via.rs | 2 +- 15 files changed, 37 insertions(+), 124 deletions(-) diff --git a/axum-macros/tests/from_request/fail/double_via_attr.rs b/axum-macros/tests/from_request/fail/double_via_attr.rs index b78cec4653..e65406512c 100644 --- a/axum-macros/tests/from_request/fail/double_via_attr.rs +++ b/axum-macros/tests/from_request/fail/double_via_attr.rs @@ -1,8 +1,7 @@ use axum_macros::FromRequest; -use axum::extract::Extension; #[derive(FromRequest)] -struct Extractor(#[from_request(via(Extension), via(Extension))] State); +struct Extractor(#[from_request(via(axum::Extension), via(axum::Extension))] State); #[derive(Clone)] struct State; diff --git a/axum-macros/tests/from_request/fail/double_via_attr.stderr b/axum-macros/tests/from_request/fail/double_via_attr.stderr index e63d4502d2..9d0ff2490b 100644 --- a/axum-macros/tests/from_request/fail/double_via_attr.stderr +++ b/axum-macros/tests/from_request/fail/double_via_attr.stderr @@ -1,13 +1,5 @@ error: `via` specified more than once - --> tests/from_request/fail/double_via_attr.rs:5:49 + --> tests/from_request/fail/double_via_attr.rs:4:55 | -5 | struct Extractor(#[from_request(via(Extension), via(Extension))] State); - | ^^^ - -warning: unused import: `axum::extract::Extension` - --> tests/from_request/fail/double_via_attr.rs:2:5 - | -2 | use axum::extract::Extension; - | ^^^^^^^^^^^^^^^^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default +4 | struct Extractor(#[from_request(via(axum::Extension), via(axum::Extension))] State); + | ^^^ diff --git a/axum-macros/tests/from_request/fail/generic_without_via.rs b/axum-macros/tests/from_request/fail/generic_without_via.rs index 29ec609566..38eaa437a3 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via.rs +++ b/axum-macros/tests/from_request/fail/generic_without_via.rs @@ -1,4 +1,4 @@ -use axum::{body::Body, routing::get, Extension, Router}; +use axum::{body::Body, routing::get, Router}; use axum_macros::FromRequest; #[derive(FromRequest, Clone)] @@ -7,5 +7,5 @@ struct Extractor(T); async fn foo(_: Extractor<()>) {} fn main() { - Router::::new().route("/", get(foo)); + Router::<(), Body>::new().route("/", get(foo)); } diff --git a/axum-macros/tests/from_request/fail/generic_without_via.stderr b/axum-macros/tests/from_request/fail/generic_without_via.stderr index 953d4295a4..4b6b920a31 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via.stderr +++ b/axum-macros/tests/from_request/fail/generic_without_via.stderr @@ -4,30 +4,13 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_reque 5 | struct Extractor(T); | ^ -warning: unused import: `Extension` - --> tests/from_request/fail/generic_without_via.rs:1:38 - | -1 | use axum::{body::Body, routing::get, Extension, Router}; - | ^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default - -error[E0599]: no function or associated item named `new` found for struct `Router` in the current scope - --> tests/from_request/fail/generic_without_via.rs:10:21 - | -10 | Router::::new().route("/", get(foo)); - | ^^^ function or associated item not found in `Router` - | - = note: the function or associated item was found for - - `Router<(), B>` - error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _, _>` is not satisfied - --> tests/from_request/fail/generic_without_via.rs:10:42 + --> tests/from_request/fail/generic_without_via.rs:10:46 | -10 | Router::::new().route("/", get(foo)); - | --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future {foo}` - | | - | required by a bound introduced by this call +10 | Router::<(), Body>::new().route("/", get(foo)); + | --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future {foo}` + | | + | required by a bound introduced by this call | = help: the trait `Handler` is implemented for `Layered` note: required by a bound in `axum::routing::get` diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection.rs b/axum-macros/tests/from_request/fail/generic_without_via_rejection.rs index 99bdf38774..38d6b0910b 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection.rs +++ b/axum-macros/tests/from_request/fail/generic_without_via_rejection.rs @@ -1,4 +1,4 @@ -use axum::{body::Body, routing::get, Extension, Router}; +use axum::{body::Body, routing::get, Router}; use axum_macros::FromRequest; #[derive(FromRequest, Clone)] @@ -8,5 +8,5 @@ struct Extractor(T); async fn foo(_: Extractor<()>) {} fn main() { - Router::::new().route("/", get(foo)); + Router::<(), Body>::new().route("/", get(foo)); } diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr b/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr index b1b078e662..d470986cde 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr +++ b/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr @@ -4,30 +4,13 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_reque 6 | struct Extractor(T); | ^ -warning: unused import: `Extension` - --> tests/from_request/fail/generic_without_via_rejection.rs:1:38 - | -1 | use axum::{body::Body, routing::get, Extension, Router}; - | ^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default - -error[E0599]: no function or associated item named `new` found for struct `Router` in the current scope - --> tests/from_request/fail/generic_without_via_rejection.rs:11:21 - | -11 | Router::::new().route("/", get(foo)); - | ^^^ function or associated item not found in `Router` - | - = note: the function or associated item was found for - - `Router<(), B>` - error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _, _>` is not satisfied - --> tests/from_request/fail/generic_without_via_rejection.rs:11:42 + --> tests/from_request/fail/generic_without_via_rejection.rs:11:46 | -11 | Router::::new().route("/", get(foo)); - | --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future {foo}` - | | - | required by a bound introduced by this call +11 | Router::<(), Body>::new().route("/", get(foo)); + | --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future {foo}` + | | + | required by a bound introduced by this call | = help: the trait `Handler` is implemented for `Layered` note: required by a bound in `axum::routing::get` diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.rs b/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.rs index 8ded02831e..ec5bb80099 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.rs +++ b/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.rs @@ -1,4 +1,4 @@ -use axum::{body::Body, routing::get, Extension, Router}; +use axum::{body::Body, routing::get, Router}; use axum_macros::FromRequest; #[derive(FromRequest, Clone)] @@ -8,5 +8,5 @@ struct Extractor(T); async fn foo(_: Extractor<()>) {} fn main() { - Router::::new().route("/", get(foo)); + Router::<(), Body>::new().route("/", get(foo)); } diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.stderr b/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.stderr index d1b9b64311..10b674c150 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.stderr +++ b/axum-macros/tests/from_request/fail/generic_without_via_rejection_derive.stderr @@ -4,30 +4,13 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_reque 6 | struct Extractor(T); | ^ -warning: unused import: `Extension` - --> tests/from_request/fail/generic_without_via_rejection_derive.rs:1:38 - | -1 | use axum::{body::Body, routing::get, Extension, Router}; - | ^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default - -error[E0599]: no function or associated item named `new` found for struct `Router` in the current scope - --> tests/from_request/fail/generic_without_via_rejection_derive.rs:11:21 - | -11 | Router::::new().route("/", get(foo)); - | ^^^ function or associated item not found in `Router` - | - = note: the function or associated item was found for - - `Router<(), B>` - error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _, _>` is not satisfied - --> tests/from_request/fail/generic_without_via_rejection_derive.rs:11:42 + --> tests/from_request/fail/generic_without_via_rejection_derive.rs:11:46 | -11 | Router::::new().route("/", get(foo)); - | --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future {foo}` - | | - | required by a bound introduced by this call +11 | Router::<(), Body>::new().route("/", get(foo)); + | --- ^^^ the trait `Handler<_, _, _>` is not implemented for `fn(Extractor<()>) -> impl Future {foo}` + | | + | required by a bound introduced by this call | = help: the trait `Handler` is implemented for `Layered` note: required by a bound in `axum::routing::get` diff --git a/axum-macros/tests/from_request/fail/rejection_derive_and_via.rs b/axum-macros/tests/from_request/fail/rejection_derive_and_via.rs index bb658f1103..a369150d60 100644 --- a/axum-macros/tests/from_request/fail/rejection_derive_and_via.rs +++ b/axum-macros/tests/from_request/fail/rejection_derive_and_via.rs @@ -1,8 +1,7 @@ use axum_macros::FromRequest; -use axum::extract::Extension; #[derive(FromRequest, Clone)] -#[from_request(rejection_derive(!Error), via(Extension))] +#[from_request(rejection_derive(!Error), via(axum::Extension))] struct Extractor { config: String, } diff --git a/axum-macros/tests/from_request/fail/rejection_derive_and_via.stderr b/axum-macros/tests/from_request/fail/rejection_derive_and_via.stderr index 59c6b0bee3..3a50044f83 100644 --- a/axum-macros/tests/from_request/fail/rejection_derive_and_via.stderr +++ b/axum-macros/tests/from_request/fail/rejection_derive_and_via.stderr @@ -1,13 +1,5 @@ error: cannot use both `rejection_derive` and `via` - --> tests/from_request/fail/rejection_derive_and_via.rs:5:42 + --> tests/from_request/fail/rejection_derive_and_via.rs:4:42 | -5 | #[from_request(rejection_derive(!Error), via(Extension))] +4 | #[from_request(rejection_derive(!Error), via(axum::Extension))] | ^^^ - -warning: unused import: `axum::extract::Extension` - --> tests/from_request/fail/rejection_derive_and_via.rs:2:5 - | -2 | use axum::extract::Extension; - | ^^^^^^^^^^^^^^^^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default diff --git a/axum-macros/tests/from_request/fail/via_and_rejection_derive.rs b/axum-macros/tests/from_request/fail/via_and_rejection_derive.rs index 8c183a60d7..5f42ef0cf7 100644 --- a/axum-macros/tests/from_request/fail/via_and_rejection_derive.rs +++ b/axum-macros/tests/from_request/fail/via_and_rejection_derive.rs @@ -1,8 +1,7 @@ use axum_macros::FromRequest; -use axum::extract::Extension; #[derive(FromRequest, Clone)] -#[from_request(via(Extension), rejection_derive(!Error))] +#[from_request(via(axum::Extension), rejection_derive(!Error))] struct Extractor { config: String, } diff --git a/axum-macros/tests/from_request/fail/via_and_rejection_derive.stderr b/axum-macros/tests/from_request/fail/via_and_rejection_derive.stderr index 25f2011b28..af45e8f811 100644 --- a/axum-macros/tests/from_request/fail/via_and_rejection_derive.stderr +++ b/axum-macros/tests/from_request/fail/via_and_rejection_derive.stderr @@ -1,13 +1,5 @@ error: cannot use both `via` and `rejection_derive` - --> tests/from_request/fail/via_and_rejection_derive.rs:5:32 + --> tests/from_request/fail/via_and_rejection_derive.rs:4:38 | -5 | #[from_request(via(Extension), rejection_derive(!Error))] - | ^^^^^^^^^^^^^^^^ - -warning: unused import: `axum::extract::Extension` - --> tests/from_request/fail/via_and_rejection_derive.rs:2:5 - | -2 | use axum::extract::Extension; - | ^^^^^^^^^^^^^^^^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default +4 | #[from_request(via(axum::Extension), rejection_derive(!Error))] + | ^^^^^^^^^^^^^^^^ diff --git a/axum-macros/tests/from_request/fail/via_on_container_and_field.rs b/axum-macros/tests/from_request/fail/via_on_container_and_field.rs index f213857b03..8499659bde 100644 --- a/axum-macros/tests/from_request/fail/via_on_container_and_field.rs +++ b/axum-macros/tests/from_request/fail/via_on_container_and_field.rs @@ -1,9 +1,8 @@ use axum_macros::FromRequest; -use axum::extract::Extension; #[derive(FromRequest)] -#[from_request(via(Extension))] -struct Extractor(#[from_request(via(Extension))] State); +#[from_request(via(axum::Extension))] +struct Extractor(#[from_request(via(axum::Extension))] State); #[derive(Clone)] struct State; diff --git a/axum-macros/tests/from_request/fail/via_on_container_and_field.stderr b/axum-macros/tests/from_request/fail/via_on_container_and_field.stderr index 01b6cac073..ff63e37df0 100644 --- a/axum-macros/tests/from_request/fail/via_on_container_and_field.stderr +++ b/axum-macros/tests/from_request/fail/via_on_container_and_field.stderr @@ -1,13 +1,5 @@ error: `#[from_request(via(...))]` on a field cannot be used together with `#[from_request(...)]` on the container - --> tests/from_request/fail/via_on_container_and_field.rs:6:33 + --> tests/from_request/fail/via_on_container_and_field.rs:5:33 | -6 | struct Extractor(#[from_request(via(Extension))] State); +5 | struct Extractor(#[from_request(via(axum::Extension))] State); | ^^^ - -warning: unused import: `axum::extract::Extension` - --> tests/from_request/fail/via_on_container_and_field.rs:2:5 - | -2 | use axum::extract::Extension; - | ^^^^^^^^^^^^^^^^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default diff --git a/axum-macros/tests/from_request/pass/tuple_via.rs b/axum-macros/tests/from_request/pass/tuple_via.rs index d08c0f52ed..01b23bdb84 100644 --- a/axum-macros/tests/from_request/pass/tuple_via.rs +++ b/axum-macros/tests/from_request/pass/tuple_via.rs @@ -1,4 +1,4 @@ -use axum::extract::Extension; +use axum::Extension; use axum_macros::FromRequest; #[derive(FromRequest)] From b6a2f39ca3dca893f82bad7bc6a9bf3384d33d17 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 17 Aug 2022 14:18:54 +0200 Subject: [PATCH 40/45] make sure `SpaRouter` works with routers that have state --- axum-extra/src/routing/spa.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/axum-extra/src/routing/spa.rs b/axum-extra/src/routing/spa.rs index 0bbb54f12a..844c1a5d4c 100644 --- a/axum-extra/src/routing/spa.rs +++ b/axum-extra/src/routing/spa.rs @@ -266,4 +266,11 @@ mod tests { Router::<_, Body>::new().merge(spa); } + + #[allow(dead_code)] + fn works_with_router_with_state() { + let _: Router = Router::with_state(String::new()) + .merge(SpaRouter::new("/assets", "test_files")) + .route("/", get(|_: axum::extract::State| async {})); + } } From 02a1f8148366173eca02ab085d3878df0ec5603d Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 17 Aug 2022 14:46:10 +0200 Subject: [PATCH 41/45] Change order of type params on FromRequest and RequestParts --- axum-core/src/extract/mod.rs | 32 ++++++++--------- axum-core/src/extract/request_parts.rs | 34 +++++++++---------- axum-core/src/extract/tuple.rs | 10 +++--- axum-extra/src/either.rs | 8 ++--- axum-extra/src/extract/cached.rs | 18 +++++----- axum-extra/src/extract/cookie/mod.rs | 4 +-- axum-extra/src/extract/cookie/private.rs | 4 +-- axum-extra/src/extract/cookie/signed.rs | 4 +-- axum-extra/src/extract/form.rs | 6 ++-- axum-extra/src/extract/query.rs | 4 +-- axum-extra/src/extract/with_rejection.rs | 10 +++--- axum-extra/src/handler/mod.rs | 12 +++---- axum-extra/src/handler/or.rs | 6 ++-- axum-extra/src/json_lines.rs | 4 +-- axum-extra/src/protobuf.rs | 4 +-- axum-macros/src/debug_handler.rs | 2 +- axum-macros/src/from_request.rs | 32 ++++++++--------- axum-macros/src/lib.rs | 6 ++-- axum-macros/src/typed_path.rs | 14 ++++---- .../fail/argument_not_extractor.stderr | 22 ++++++------ .../debug_handler/fail/extract_self_mut.rs | 4 +-- .../debug_handler/fail/extract_self_ref.rs | 4 +-- .../pass/result_impl_into_response.rs | 4 +-- .../tests/debug_handler/pass/self_receiver.rs | 4 +-- .../tests/from_request/pass/container.rs | 2 +- .../tests/from_request/pass/derive_opt_out.rs | 4 +-- .../tests/from_request/pass/empty_named.rs | 2 +- .../tests/from_request/pass/empty_tuple.rs | 2 +- axum-macros/tests/from_request/pass/named.rs | 2 +- .../tests/from_request/pass/named_via.rs | 2 +- .../from_request/pass/override_rejection.rs | 4 +-- axum-macros/tests/from_request/pass/tuple.rs | 2 +- .../pass/tuple_same_type_twice.rs | 2 +- .../pass/tuple_same_type_twice_via.rs | 2 +- .../tests/from_request/pass/tuple_via.rs | 2 +- axum-macros/tests/from_request/pass/unit.rs | 2 +- .../typed_path/fail/not_deserialize.stderr | 2 +- axum/src/docs/extract.md | 8 ++--- axum/src/error_handling/mod.rs | 2 +- axum/src/extension.rs | 4 +-- axum/src/extract/connect_info.rs | 6 ++-- axum/src/extract/content_length_limit.rs | 6 ++-- axum/src/extract/host.rs | 4 +-- axum/src/extract/matched_path.rs | 6 ++-- axum/src/extract/mod.rs | 6 ++-- axum/src/extract/multipart.rs | 4 +-- axum/src/extract/path/mod.rs | 4 +-- axum/src/extract/query.rs | 4 +-- axum/src/extract/raw_query.rs | 4 +-- axum/src/extract/request_parts.rs | 12 +++---- axum/src/extract/state.rs | 8 ++--- axum/src/extract/ws.rs | 8 ++--- axum/src/form.rs | 4 +-- axum/src/handler/mod.rs | 4 +-- axum/src/json.rs | 6 ++-- axum/src/middleware/from_extractor.rs | 18 +++++----- axum/src/middleware/from_fn.rs | 2 +- axum/src/routing/method_routing.rs | 2 +- axum/src/routing/strip_prefix.rs | 2 +- axum/src/routing/tests/nest.rs | 2 +- axum/src/typed_header.rs | 4 +-- .../src/main.rs | 4 +-- .../customize-extractor-error/src/main.rs | 4 +-- examples/customize-path-rejection/src/main.rs | 4 +-- examples/jwt/src/main.rs | 4 +-- examples/oauth/src/main.rs | 4 +-- examples/sessions/src/main.rs | 4 +-- examples/sqlx-postgres/src/main.rs | 4 +-- examples/tokio-postgres/src/main.rs | 4 +-- examples/validator/src/main.rs | 4 +-- examples/versioning/src/main.rs | 4 +-- 71 files changed, 226 insertions(+), 226 deletions(-) diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index 8622a87b23..6425086f67 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -42,7 +42,7 @@ mod tuple; /// struct MyExtractor; /// /// #[async_trait] -/// impl FromRequest for MyExtractor +/// impl FromRequest for MyExtractor /// where /// // these bounds are required by `async_trait` /// B: Send, @@ -50,7 +50,7 @@ mod tuple; /// { /// type Rejection = http::StatusCode; /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: &mut RequestParts) -> Result { /// // ... /// # unimplemented!() /// } @@ -62,20 +62,20 @@ mod tuple; /// [`http::Request`]: http::Request /// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html #[async_trait] -pub trait FromRequest: Sized { +pub trait FromRequest: Sized { /// If the extractor fails it'll use this "rejection" type. A rejection is /// a kind of error that can be converted into a response. type Rejection: IntoResponse; /// Perform the extraction. - async fn from_request(req: &mut RequestParts) -> Result; + async fn from_request(req: &mut RequestParts) -> Result; } /// The type used with [`FromRequest`] to extract data from requests. /// /// Has several convenience methods for getting owned parts of the request. #[derive(Debug)] -pub struct RequestParts { +pub struct RequestParts { state: S, method: Method, uri: Uri, @@ -85,7 +85,7 @@ pub struct RequestParts { body: Option, } -impl RequestParts { +impl RequestParts<(), B> { /// Create a new `RequestParts` without any state. /// /// You generally shouldn't need to construct this type yourself, unless @@ -98,7 +98,7 @@ impl RequestParts { } } -impl RequestParts { +impl RequestParts { /// Create a new `RequestParts` with the given state. /// /// You generally shouldn't need to construct this type yourself, unless @@ -147,14 +147,14 @@ impl RequestParts { /// use http::{Method, Uri}; /// /// #[async_trait] - /// impl FromRequest for MyExtractor + /// impl FromRequest for MyExtractor /// where /// B: Send, /// S: Send, /// { /// type Rejection = Infallible; /// - /// async fn from_request(req: &mut RequestParts) -> Result { + /// async fn from_request(req: &mut RequestParts) -> Result { /// let method = req.extract::().await?; /// let path = req.extract::().await?.path().to_owned(); /// @@ -164,7 +164,7 @@ impl RequestParts { /// ``` pub async fn extract(&mut self) -> Result where - E: FromRequest, + E: FromRequest, { E::from_request(self).await } @@ -278,29 +278,29 @@ impl RequestParts { } #[async_trait] -impl FromRequest for Option +impl FromRequest for Option where - T: FromRequest, + T: FromRequest, B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result, Self::Rejection> { + async fn from_request(req: &mut RequestParts) -> Result, Self::Rejection> { Ok(T::from_request(req).await.ok()) } } #[async_trait] -impl FromRequest for Result +impl FromRequest for Result where - T: FromRequest, + T: FromRequest, B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(T::from_request(req).await) } } diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index e9027d2611..4faaf2d355 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -6,14 +6,14 @@ use http::{Extensions, HeaderMap, Method, Request, Uri, Version}; use std::convert::Infallible; #[async_trait] -impl FromRequest for Request +impl FromRequest for Request where B: Send, S: Clone + Send, { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let req = std::mem::replace( req, RequestParts { @@ -32,40 +32,40 @@ where } #[async_trait] -impl FromRequest for Method +impl FromRequest for Method where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.method().clone()) } } #[async_trait] -impl FromRequest for Uri +impl FromRequest for Uri where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.uri().clone()) } } #[async_trait] -impl FromRequest for Version +impl FromRequest for Version where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.version()) } } @@ -76,20 +76,20 @@ where /// /// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html #[async_trait] -impl FromRequest for HeaderMap +impl FromRequest for HeaderMap where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(req.headers().clone()) } } #[async_trait] -impl FromRequest for Bytes +impl FromRequest for Bytes where B: http_body::Body + Send, B::Data: Send, @@ -98,7 +98,7 @@ where { type Rejection = BytesRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; let bytes = crate::body::to_bytes(body) @@ -110,7 +110,7 @@ where } #[async_trait] -impl FromRequest for String +impl FromRequest for String where B: http_body::Body + Send, B::Data: Send, @@ -119,7 +119,7 @@ where { type Rejection = StringRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; let bytes = crate::body::to_bytes(body) @@ -134,14 +134,14 @@ where } #[async_trait] -impl FromRequest for http::request::Parts +impl FromRequest for http::request::Parts where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let method = unwrap_infallible(Method::from_request(req).await); let uri = unwrap_infallible(Uri::from_request(req).await); let version = unwrap_infallible(Version::from_request(req).await); @@ -168,6 +168,6 @@ fn unwrap_infallible(result: Result) -> T { } } -pub(crate) fn take_body(req: &mut RequestParts) -> Result { +pub(crate) fn take_body(req: &mut RequestParts) -> Result { req.take_body().ok_or(BodyAlreadyExtracted) } diff --git a/axum-core/src/extract/tuple.rs b/axum-core/src/extract/tuple.rs index 1ae56e032a..05e38bf004 100644 --- a/axum-core/src/extract/tuple.rs +++ b/axum-core/src/extract/tuple.rs @@ -4,14 +4,14 @@ use async_trait::async_trait; use std::convert::Infallible; #[async_trait] -impl FromRequest for () +impl FromRequest for () where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(_: &mut RequestParts) -> Result<(), Self::Rejection> { + async fn from_request(_: &mut RequestParts) -> Result<(), Self::Rejection> { Ok(()) } } @@ -22,15 +22,15 @@ macro_rules! impl_from_request { ( $($ty:ident),* $(,)? ) => { #[async_trait] #[allow(non_snake_case)] - impl FromRequest for ($($ty,)*) + impl FromRequest for ($($ty,)*) where - $( $ty: FromRequest + Send, )* + $( $ty: FromRequest + Send, )* B: Send, S: Send, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { $( let $ty = $ty::from_request(req).await.map_err(|err| err.into_response())?; )* Ok(($($ty,)*)) } diff --git a/axum-extra/src/either.rs b/axum-extra/src/either.rs index e052105ad0..84b2a91f65 100755 --- a/axum-extra/src/either.rs +++ b/axum-extra/src/either.rs @@ -190,16 +190,16 @@ macro_rules! impl_traits_for_either { $last:ident $(,)? ) => { #[async_trait] - impl FromRequest for $either<$($ident),*, $last> + impl FromRequest for $either<$($ident),*, $last> where - $($ident: FromRequest),*, - $last: FromRequest, + $($ident: FromRequest),*, + $last: FromRequest, B: Send, S: Send, { type Rejection = $last::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { $( if let Ok(value) = req.extract().await { return Ok(Self::$ident(value)); diff --git a/axum-extra/src/extract/cached.rs b/axum-extra/src/extract/cached.rs index aa303aca18..9545fe16dd 100644 --- a/axum-extra/src/extract/cached.rs +++ b/axum-extra/src/extract/cached.rs @@ -30,14 +30,14 @@ use std::ops::{Deref, DerefMut}; /// struct Session { /* ... */ } /// /// #[async_trait] -/// impl FromRequest for Session +/// impl FromRequest for Session /// where /// B: Send, /// S: Send, /// { /// type Rejection = (StatusCode, String); /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: &mut RequestParts) -> Result { /// // load session... /// # unimplemented!() /// } @@ -46,14 +46,14 @@ use std::ops::{Deref, DerefMut}; /// struct CurrentUser { /* ... */ } /// /// #[async_trait] -/// impl FromRequest for CurrentUser +/// impl FromRequest for CurrentUser /// where /// B: Send, /// S: Send, /// { /// type Rejection = Response; /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: &mut RequestParts) -> Result { /// // loading a `CurrentUser` requires first loading the `Session` /// // /// // by using `Cached` we avoid extracting the session more than @@ -90,15 +90,15 @@ pub struct Cached(pub T); struct CachedEntry(T); #[async_trait] -impl FromRequest for Cached +impl FromRequest for Cached where B: Send, S: Send, - T: FromRequest + Clone + Send + Sync + 'static, + T: FromRequest + Clone + Send + Sync + 'static, { type Rejection = T::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match Extension::>::from_request(req).await { Ok(Extension(CachedEntry(value))) => Ok(Self(value)), Err(_) => { @@ -142,14 +142,14 @@ mod tests { struct Extractor(Instant); #[async_trait] - impl FromRequest for Extractor + impl FromRequest for Extractor where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { COUNTER.fetch_add(1, Ordering::SeqCst); Ok(Self(Instant::now())) } diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index 897d5fd423..38ecbbe54b 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -88,14 +88,14 @@ pub struct CookieJar { } #[async_trait] -impl FromRequest for CookieJar +impl FromRequest for CookieJar where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(Self::from_headers(req.headers())) } } diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index d285230f22..7540b1764f 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -87,7 +87,7 @@ impl fmt::Debug for PrivateCookieJar { } #[async_trait] -impl FromRequest for PrivateCookieJar +impl FromRequest for PrivateCookieJar where B: Send, S: Into + Clone + Send, @@ -95,7 +95,7 @@ where { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let state = req.state().clone(); let key: K = state.into(); let key: Key = key.into(); diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index 94a7142bfa..d56a8eb567 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -105,7 +105,7 @@ impl fmt::Debug for SignedCookieJar { } #[async_trait] -impl FromRequest for SignedCookieJar +impl FromRequest for SignedCookieJar where B: Send, S: Into + Clone + Send, @@ -113,7 +113,7 @@ where { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let state = req.state().clone(); let key: K = state.into(); let key: Key = key.into(); diff --git a/axum-extra/src/extract/form.rs b/axum-extra/src/extract/form.rs index dde954489a..593bfba660 100644 --- a/axum-extra/src/extract/form.rs +++ b/axum-extra/src/extract/form.rs @@ -55,7 +55,7 @@ impl Deref for Form { } #[async_trait] -impl FromRequest for Form +impl FromRequest for Form where T: DeserializeOwned, B: HttpBody + Send, @@ -65,7 +65,7 @@ where { type Rejection = FormRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if req.method() == Method::GET { let query = req.uri().query().unwrap_or_default(); let value = serde_html_form::from_str(query) @@ -86,7 +86,7 @@ where } // this is duplicated in `axum/src/extract/mod.rs` -fn has_content_type(req: &RequestParts, expected_content_type: &mime::Mime) -> bool { +fn has_content_type(req: &RequestParts, expected_content_type: &mime::Mime) -> bool { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { content_type } else { diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index 9080c23446..debc6957a3 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -58,7 +58,7 @@ use std::ops::Deref; pub struct Query(pub T); #[async_trait] -impl FromRequest for Query +impl FromRequest for Query where T: DeserializeOwned, B: Send, @@ -66,7 +66,7 @@ where { type Rejection = QueryRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let query = req.uri().query().unwrap_or_default(); let value = serde_html_form::from_str(query) .map_err(FailedToDeserializeQueryString::__private_new)?; diff --git a/axum-extra/src/extract/with_rejection.rs b/axum-extra/src/extract/with_rejection.rs index 0cca28891d..e0d2135cc3 100644 --- a/axum-extra/src/extract/with_rejection.rs +++ b/axum-extra/src/extract/with_rejection.rs @@ -107,16 +107,16 @@ impl DerefMut for WithRejection { } #[async_trait] -impl FromRequest for WithRejection +impl FromRequest for WithRejection where B: Send, S: Send, - E: FromRequest, + E: FromRequest, R: From + IntoResponse, { type Rejection = R; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let extractor = req.extract::().await?; Ok(WithRejection(extractor, PhantomData)) } @@ -135,14 +135,14 @@ mod tests { struct TestRejection; #[async_trait] - impl FromRequest for TestExtractor + impl FromRequest for TestExtractor where B: Send, S: Send, { type Rejection = (); - async fn from_request(_: &mut RequestParts) -> Result { + async fn from_request(_: &mut RequestParts) -> Result { Err(()) } } diff --git a/axum-extra/src/handler/mod.rs b/axum-extra/src/handler/mod.rs index 31d464b548..56af0ea215 100644 --- a/axum-extra/src/handler/mod.rs +++ b/axum-extra/src/handler/mod.rs @@ -67,14 +67,14 @@ pub trait HandlerCallWithExtractors: Sized { /// struct AdminPermissions {} /// /// #[async_trait] - /// impl FromRequest for AdminPermissions + /// impl FromRequest for AdminPermissions /// where /// B: Send, /// S: Send, /// { /// // check for admin permissions... /// # type Rejection = (); - /// # async fn from_request(req: &mut axum::extract::RequestParts) -> Result { + /// # async fn from_request(req: &mut axum::extract::RequestParts) -> Result { /// # todo!() /// # } /// } @@ -82,14 +82,14 @@ pub trait HandlerCallWithExtractors: Sized { /// struct User {} /// /// #[async_trait] - /// impl FromRequest for User + /// impl FromRequest for User /// where /// B: Send, /// S: Send, /// { /// // check for a logged in user... /// # type Rejection = (); - /// # async fn from_request(req: &mut axum::extract::RequestParts) -> Result { + /// # async fn from_request(req: &mut axum::extract::RequestParts) -> Result { /// # todo!() /// # } /// } @@ -119,7 +119,7 @@ pub trait HandlerCallWithExtractors: Sized { macro_rules! impl_handler_call_with { ( $($ty:ident),* $(,)? ) => { #[allow(non_snake_case)] - impl HandlerCallWithExtractors<($($ty,)*), S, B> for F + impl HandlerCallWithExtractors<($($ty,)*), S, B> for F where F: FnOnce($($ty,)*) -> Fut, Fut: Future + Send + 'static, @@ -169,7 +169,7 @@ pub struct IntoHandler { impl Handler for IntoHandler where H: HandlerCallWithExtractors + Clone + Send + 'static, - T: FromRequest + Send + 'static, + T: FromRequest + Send + 'static, T::Rejection: Send, B: Send + 'static, S: Clone + Send + 'static, diff --git a/axum-extra/src/handler/or.rs b/axum-extra/src/handler/or.rs index 4116253dc7..c2599d582f 100644 --- a/axum-extra/src/handler/or.rs +++ b/axum-extra/src/handler/or.rs @@ -55,12 +55,12 @@ where } } -impl Handler<(Lt, Rt), S, B> for Or +impl Handler<(Lt, Rt), S, B> for Or where L: HandlerCallWithExtractors + Clone + Send + 'static, R: HandlerCallWithExtractors + Clone + Send + 'static, - Lt: FromRequest + Send + 'static, - Rt: FromRequest + Send + 'static, + Lt: FromRequest + Send + 'static, + Rt: FromRequest + Send + 'static, Lt::Rejection: Send, Rt::Rejection: Send, B: Send + 'static, diff --git a/axum-extra/src/json_lines.rs b/axum-extra/src/json_lines.rs index bd5b0b944a..242b43e70f 100644 --- a/axum-extra/src/json_lines.rs +++ b/axum-extra/src/json_lines.rs @@ -98,7 +98,7 @@ impl JsonLines { } #[async_trait] -impl FromRequest for JsonLines +impl FromRequest for JsonLines where B: HttpBody + Send + 'static, B::Data: Into, @@ -108,7 +108,7 @@ where { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { // `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead` // so we can call `AsyncRead::lines` and then convert it back to a `Stream` diff --git a/axum-extra/src/protobuf.rs b/axum-extra/src/protobuf.rs index 698e4b7985..a30421a033 100644 --- a/axum-extra/src/protobuf.rs +++ b/axum-extra/src/protobuf.rs @@ -97,7 +97,7 @@ use std::ops::{Deref, DerefMut}; pub struct ProtoBuf(pub T); #[async_trait] -impl FromRequest for ProtoBuf +impl FromRequest for ProtoBuf where T: Message + Default, B: HttpBody + Send, @@ -107,7 +107,7 @@ where { type Rejection = ProtoBufRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let mut bytes = Bytes::from_request(req).await?; match T::decode(&mut bytes) { diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 81b9c4439d..4972fc49d6 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -203,7 +203,7 @@ fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStr #[allow(warnings)] fn #name() where - #ty: ::axum::extract::FromRequest<#body_ty, ()> + Send, + #ty: ::axum::extract::FromRequest<(), #body_ty> + Send, {} } }) diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs index bf7bc7d547..9e84bd248f 100644 --- a/axum-macros/src/from_request.rs +++ b/axum-macros/src/from_request.rs @@ -218,7 +218,7 @@ fn impl_struct_by_extracting_each_field( Ok(quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, @@ -228,7 +228,7 @@ fn impl_struct_by_extracting_each_field( type Rejection = #rejection_ident; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: &mut ::axum::extract::RequestParts, ) -> ::std::result::Result { ::std::result::Result::Ok(Self { #(#extract_fields)* @@ -423,7 +423,7 @@ fn extract_each_field_rejection( Ok(quote_spanned! {ty_span=> #[allow(non_camel_case_types)] - #variant_name(<#extractor_ty as ::axum::extract::FromRequest<::axum::body::Body, ()>>::Rejection), + #variant_name(<#extractor_ty as ::axum::extract::FromRequest<(), ::axum::body::Body>>::Rejection), }) }) .collect::>>()?; @@ -610,26 +610,26 @@ fn impl_struct_by_extracting_all_at_once( quote! { #rejection } } else { quote! { - <#path as ::axum::extract::FromRequest>::Rejection + <#path as ::axum::extract::FromRequest>::Rejection } }; let rejection_bound = rejection.as_ref().map(|rejection| { if generic_ident.is_some() { quote! { - #rejection: ::std::convert::From<<#path as ::axum::extract::FromRequest>::Rejection>, + #rejection: ::std::convert::From<<#path as ::axum::extract::FromRequest>::Rejection>, } } else { quote! { - #rejection: ::std::convert::From<<#path as ::axum::extract::FromRequest>::Rejection>, + #rejection: ::std::convert::From<<#path as ::axum::extract::FromRequest>::Rejection>, } } }).unwrap_or_default(); let impl_generics = if generic_ident.is_some() { - quote! { B, S, T } + quote! { S, B, T } } else { - quote! { B, S } + quote! { S, B } }; let type_generics = generic_ident @@ -654,9 +654,9 @@ fn impl_struct_by_extracting_all_at_once( Ok(quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] - impl<#impl_generics> ::axum::extract::FromRequest for #ident #type_generics + impl<#impl_generics> ::axum::extract::FromRequest for #ident #type_generics where - #path<#via_type_generics>: ::axum::extract::FromRequest, + #path<#via_type_generics>: ::axum::extract::FromRequest, #rejection_bound B: ::std::marker::Send, S: ::std::marker::Send, @@ -664,9 +664,9 @@ fn impl_struct_by_extracting_all_at_once( type Rejection = #associated_rejection_type; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: &mut ::axum::extract::RequestParts, ) -> ::std::result::Result { - ::axum::extract::FromRequest::::from_request(req) + ::axum::extract::FromRequest::::from_request(req) .await .map(|#path(value)| #value_to_self) .map_err(::std::convert::From::from) @@ -711,7 +711,7 @@ fn impl_enum_by_extracting_all_at_once( quote! { #rejection } } else { quote! { - <#path as ::axum::extract::FromRequest>::Rejection + <#path as ::axum::extract::FromRequest>::Rejection } }; @@ -720,7 +720,7 @@ fn impl_enum_by_extracting_all_at_once( Ok(quote_spanned! {path_span=> #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, @@ -730,9 +730,9 @@ fn impl_enum_by_extracting_all_at_once( type Rejection = #associated_rejection_type; async fn from_request( - req: &mut ::axum::extract::RequestParts, + req: &mut ::axum::extract::RequestParts, ) -> ::std::result::Result { - ::axum::extract::FromRequest::::from_request(req) + ::axum::extract::FromRequest::::from_request(req) .await .map(|#path(inner)| inner) .map_err(::std::convert::From::from) diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index 7bebec7236..7b27a4bd83 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -125,7 +125,7 @@ mod typed_path; /// ``` /// pub struct ViaExtractor(pub T); /// -/// // impl FromRequest for ViaExtractor { ... } +/// // impl FromRequest for ViaExtractor { ... } /// ``` /// /// More complex via extractors are not supported and require writing a manual implementation. @@ -223,7 +223,7 @@ mod typed_path; /// struct OtherExtractor; /// /// #[async_trait] -/// impl FromRequest for OtherExtractor +/// impl FromRequest for OtherExtractor /// where /// B: Send, /// S: Send, @@ -231,7 +231,7 @@ mod typed_path; /// // this rejection doesn't implement `Display` and `Error` /// type Rejection = (StatusCode, String); /// -/// async fn from_request(_req: &mut RequestParts) -> Result { +/// async fn from_request(_req: &mut RequestParts) -> Result { /// // ... /// # unimplemented!() /// } diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs index 9e3c508d57..6a8f03c170 100644 --- a/axum-macros/src/typed_path.rs +++ b/axum-macros/src/typed_path.rs @@ -127,14 +127,14 @@ fn expand_named_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: Send, S: Send, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { + async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { ::axum::extract::Path::from_request(req) .await .map(|path| path.0) @@ -230,14 +230,14 @@ fn expand_unnamed_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: Send, S: Send, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { + async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { ::axum::extract::Path::from_request(req) .await .map(|path| path.0) @@ -312,14 +312,14 @@ fn expand_unit_fields( let from_request_impl = quote! { #[::axum::async_trait] #[automatically_derived] - impl ::axum::extract::FromRequest for #ident + impl ::axum::extract::FromRequest for #ident where B: Send, S: Send, { type Rejection = #rejection_assoc_type; - async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { + async fn from_request(req: &mut ::axum::extract::RequestParts) -> ::std::result::Result { if req.uri().path() == ::PATH { Ok(Self) } else { @@ -390,7 +390,7 @@ enum Segment { fn path_rejection() -> TokenStream { quote! { - <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection + <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection } } diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr index 420005970c..265258419e 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr @@ -1,17 +1,17 @@ -error[E0277]: the trait bound `bool: FromRequest` is not satisfied +error[E0277]: the trait bound `bool: FromRequest<(), Body>` is not satisfied --> tests/debug_handler/fail/argument_not_extractor.rs:4:23 | 4 | async fn handler(foo: bool) {} - | ^^^^ the trait `FromRequest` is not implemented for `bool` + | ^^^^ the trait `FromRequest<(), Body>` is not implemented for `bool` | - = help: the following other types implement trait `FromRequest`: - <() as FromRequest> - <(T1, T2) as FromRequest> - <(T1, T2, T3) as FromRequest> - <(T1, T2, T3, T4) as FromRequest> - <(T1, T2, T3, T4, T5) as FromRequest> - <(T1, T2, T3, T4, T5, T6) as FromRequest> - <(T1, T2, T3, T4, T5, T6, T7) as FromRequest> - <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequest> + = help: the following other types implement trait `FromRequest`: + <() as FromRequest> + <(T1, T2) as FromRequest> + <(T1, T2, T3) as FromRequest> + <(T1, T2, T3, T4) as FromRequest> + <(T1, T2, T3, T4, T5) as FromRequest> + <(T1, T2, T3, T4, T5, T6) as FromRequest> + <(T1, T2, T3, T4, T5, T6, T7) as FromRequest> + <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequest> and 34 others = help: see issue #48214 diff --git a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs index 910ba78ced..d38d5e0c4d 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs @@ -7,14 +7,14 @@ use axum_macros::debug_handler; struct A; #[async_trait] -impl FromRequest for A +impl FromRequest for A where B: Send, S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs index 75d8f5ae18..06b87f0a82 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs @@ -7,14 +7,14 @@ use axum_macros::debug_handler; struct A; #[async_trait] -impl FromRequest for A +impl FromRequest for A where B: Send, S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs index ebf02a2629..762809b62a 100644 --- a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs +++ b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs @@ -120,14 +120,14 @@ impl A { } #[async_trait] -impl FromRequest for A +impl FromRequest for A where B: Send, S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/debug_handler/pass/self_receiver.rs b/axum-macros/tests/debug_handler/pass/self_receiver.rs index 3939349836..a88382cf18 100644 --- a/axum-macros/tests/debug_handler/pass/self_receiver.rs +++ b/axum-macros/tests/debug_handler/pass/self_receiver.rs @@ -7,14 +7,14 @@ use axum_macros::debug_handler; struct A; #[async_trait] -impl FromRequest for A +impl FromRequest for A where B: Send, S: Send, { type Rejection = (); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/from_request/pass/container.rs b/axum-macros/tests/from_request/pass/container.rs index fe388e4806..e8eaa0a58a 100644 --- a/axum-macros/tests/from_request/pass/container.rs +++ b/axum-macros/tests/from_request/pass/container.rs @@ -15,7 +15,7 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest, + Extractor: FromRequest<(), Body, Rejection = JsonRejection>, { } diff --git a/axum-macros/tests/from_request/pass/derive_opt_out.rs b/axum-macros/tests/from_request/pass/derive_opt_out.rs index f852115361..e73d5a959c 100644 --- a/axum-macros/tests/from_request/pass/derive_opt_out.rs +++ b/axum-macros/tests/from_request/pass/derive_opt_out.rs @@ -14,14 +14,14 @@ struct Extractor { struct OtherExtractor; #[async_trait] -impl FromRequest for OtherExtractor +impl FromRequest for OtherExtractor where B: Send, S: Send, { type Rejection = OtherExtractorRejection; - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { unimplemented!() } } diff --git a/axum-macros/tests/from_request/pass/empty_named.rs b/axum-macros/tests/from_request/pass/empty_named.rs index c550f77a03..eec021d0f5 100644 --- a/axum-macros/tests/from_request/pass/empty_named.rs +++ b/axum-macros/tests/from_request/pass/empty_named.rs @@ -5,7 +5,7 @@ struct Extractor {} fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, { } diff --git a/axum-macros/tests/from_request/pass/empty_tuple.rs b/axum-macros/tests/from_request/pass/empty_tuple.rs index 6429b4f9f2..3d8bcd25c0 100644 --- a/axum-macros/tests/from_request/pass/empty_tuple.rs +++ b/axum-macros/tests/from_request/pass/empty_tuple.rs @@ -5,7 +5,7 @@ struct Extractor(); fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, { } diff --git a/axum-macros/tests/from_request/pass/named.rs b/axum-macros/tests/from_request/pass/named.rs index cd4b8649ac..89fb8da004 100644 --- a/axum-macros/tests/from_request/pass/named.rs +++ b/axum-macros/tests/from_request/pass/named.rs @@ -18,7 +18,7 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest, + Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>, { } diff --git a/axum-macros/tests/from_request/pass/named_via.rs b/axum-macros/tests/from_request/pass/named_via.rs index 44c633a097..8a81869d1a 100644 --- a/axum-macros/tests/from_request/pass/named_via.rs +++ b/axum-macros/tests/from_request/pass/named_via.rs @@ -25,7 +25,7 @@ struct Extractor { fn assert_from_request() where - Extractor: FromRequest, + Extractor: FromRequest<(), Body, Rejection = ExtractorRejection>, { } diff --git a/axum-macros/tests/from_request/pass/override_rejection.rs b/axum-macros/tests/from_request/pass/override_rejection.rs index 7339bbdb36..c308d61521 100644 --- a/axum-macros/tests/from_request/pass/override_rejection.rs +++ b/axum-macros/tests/from_request/pass/override_rejection.rs @@ -28,7 +28,7 @@ struct MyExtractor { struct OtherExtractor; #[async_trait] -impl FromRequest for OtherExtractor +impl FromRequest for OtherExtractor where B: Send + 'static, S: Send, @@ -36,7 +36,7 @@ where // this rejection doesn't implement `Display` and `Error` type Rejection = (StatusCode, String); - async fn from_request(_req: &mut RequestParts) -> Result { + async fn from_request(_req: &mut RequestParts) -> Result { todo!() } } diff --git a/axum-macros/tests/from_request/pass/tuple.rs b/axum-macros/tests/from_request/pass/tuple.rs index 7561f998a2..2af407d0f9 100644 --- a/axum-macros/tests/from_request/pass/tuple.rs +++ b/axum-macros/tests/from_request/pass/tuple.rs @@ -5,7 +5,7 @@ struct Extractor(axum::http::HeaderMap, String); fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs index 3ed6ad0853..00b6dd78df 100644 --- a/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs @@ -13,7 +13,7 @@ struct Payload {} fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs index 13ee5f259c..0b148ebc50 100644 --- a/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice_via.rs @@ -27,7 +27,7 @@ struct Payload {} fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } diff --git a/axum-macros/tests/from_request/pass/tuple_via.rs b/axum-macros/tests/from_request/pass/tuple_via.rs index 01b23bdb84..03a9e3610c 100644 --- a/axum-macros/tests/from_request/pass/tuple_via.rs +++ b/axum-macros/tests/from_request/pass/tuple_via.rs @@ -9,7 +9,7 @@ struct State; fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body>, { } diff --git a/axum-macros/tests/from_request/pass/unit.rs b/axum-macros/tests/from_request/pass/unit.rs index 76073d2777..3e5d986917 100644 --- a/axum-macros/tests/from_request/pass/unit.rs +++ b/axum-macros/tests/from_request/pass/unit.rs @@ -5,7 +5,7 @@ struct Extractor; fn assert_from_request() where - Extractor: axum::extract::FromRequest, + Extractor: axum::extract::FromRequest<(), axum::body::Body, Rejection = std::convert::Infallible>, { } diff --git a/axum-macros/tests/typed_path/fail/not_deserialize.stderr b/axum-macros/tests/typed_path/fail/not_deserialize.stderr index bc77a0d2ea..9aabf3625f 100644 --- a/axum-macros/tests/typed_path/fail/not_deserialize.stderr +++ b/axum-macros/tests/typed_path/fail/not_deserialize.stderr @@ -15,5 +15,5 @@ error[E0277]: the trait bound `for<'de> MyPath: serde::de::Deserialize<'de>` is (T0, T1, T2, T3) and 138 others = note: required because of the requirements on the impl of `serde::de::DeserializeOwned` for `MyPath` - = note: required because of the requirements on the impl of `FromRequest` for `axum::extract::Path` + = note: required because of the requirements on the impl of `FromRequest` for `axum::extract::Path` = note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum/src/docs/extract.md b/axum/src/docs/extract.md index c652093157..9c31d0e4d3 100644 --- a/axum/src/docs/extract.md +++ b/axum/src/docs/extract.md @@ -421,14 +421,14 @@ use http::{StatusCode, header::{HeaderValue, USER_AGENT}}; struct ExtractUserAgent(HeaderValue); #[async_trait] -impl FromRequest for ExtractUserAgent +impl FromRequest for ExtractUserAgent where B: Send, S: Send, { type Rejection = (StatusCode, &'static str); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if let Some(user_agent) = req.headers().get(USER_AGENT) { Ok(ExtractUserAgent(user_agent.clone())) } else { @@ -473,14 +473,14 @@ struct AuthenticatedUser { } #[async_trait] -impl FromRequest for AuthenticatedUser +impl FromRequest for AuthenticatedUser where B: Send, S: Send, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let TypedHeader(Authorization(token)) = TypedHeader::>::from_request(req) .await diff --git a/axum/src/error_handling/mod.rs b/axum/src/error_handling/mod.rs index de0ee8bc02..6a72d82069 100644 --- a/axum/src/error_handling/mod.rs +++ b/axum/src/error_handling/mod.rs @@ -161,7 +161,7 @@ macro_rules! impl_service { F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static, Fut: Future + Send, Res: IntoResponse, - $( $ty: FromRequest + Send,)* + $( $ty: FromRequest<(), B> + Send,)* B: Send + 'static, { type Response = Response; diff --git a/axum/src/extension.rs b/axum/src/extension.rs index 5d8d60ecfc..d040b9e4e5 100644 --- a/axum/src/extension.rs +++ b/axum/src/extension.rs @@ -73,7 +73,7 @@ use tower_service::Service; pub struct Extension(pub T); #[async_trait] -impl FromRequest for Extension +impl FromRequest for Extension where T: Clone + Send + Sync + 'static, B: Send, @@ -81,7 +81,7 @@ where { type Rejection = ExtensionRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let value = req .extensions() .get::() diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index ceedc116ac..3aa7684c81 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -128,15 +128,15 @@ opaque_future! { pub struct ConnectInfo(pub T); #[async_trait] -impl FromRequest for ConnectInfo +impl FromRequest for ConnectInfo where B: Send, S: Send, T: Clone + Send + Sync + 'static, { - type Rejection = as FromRequest>::Rejection; + type Rejection = as FromRequest>::Rejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let Extension(connect_info) = Extension::::from_request(req).await?; Ok(connect_info) } diff --git a/axum/src/extract/content_length_limit.rs b/axum/src/extract/content_length_limit.rs index 8584c737dc..f4c475437f 100644 --- a/axum/src/extract/content_length_limit.rs +++ b/axum/src/extract/content_length_limit.rs @@ -36,16 +36,16 @@ use std::ops::Deref; pub struct ContentLengthLimit(pub T); #[async_trait] -impl FromRequest for ContentLengthLimit +impl FromRequest for ContentLengthLimit where - T: FromRequest, + T: FromRequest, T::Rejection: IntoResponse, B: Send, S: Send, { type Rejection = ContentLengthLimitRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let content_length = req .headers() .get(http::header::CONTENT_LENGTH) diff --git a/axum/src/extract/host.rs b/axum/src/extract/host.rs index 71acb938ea..79ae13fc28 100644 --- a/axum/src/extract/host.rs +++ b/axum/src/extract/host.rs @@ -21,14 +21,14 @@ const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; pub struct Host(pub String); #[async_trait] -impl FromRequest for Host +impl FromRequest for Host where B: Send, S: Send, { type Rejection = HostRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if let Some(host) = parse_forwarded(req.headers()) { return Ok(Host(host.to_owned())); } diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index cf2b7e8b94..6413cf7714 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -64,14 +64,14 @@ impl MatchedPath { } #[async_trait] -impl FromRequest for MatchedPath +impl FromRequest for MatchedPath where B: Send, S: Send, { type Rejection = MatchedPathRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let matched_path = req .extensions() .get::() @@ -96,7 +96,7 @@ mod tests { #[derive(Clone)] struct SetMatchedPathExtension(S); - impl Service> for SetMatchedPathExtension + impl Service> for SetMatchedPathExtension where S: Service>, { diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 511447d41c..c4aeb3ad70 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -75,13 +75,13 @@ pub use self::ws::WebSocketUpgrade; #[doc(no_inline)] pub use crate::TypedHeader; -pub(crate) fn take_body(req: &mut RequestParts) -> Result { +pub(crate) fn take_body(req: &mut RequestParts) -> Result { req.take_body().ok_or_else(BodyAlreadyExtracted::default) } // this is duplicated in `axum-extra/src/extract/form.rs` -pub(super) fn has_content_type( - req: &RequestParts, +pub(super) fn has_content_type( + req: &RequestParts, expected_content_type: &mime::Mime, ) -> bool { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index a3cb9ea3e0..076f4db106 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -50,7 +50,7 @@ pub struct Multipart { } #[async_trait] -impl FromRequest for Multipart +impl FromRequest for Multipart where B: HttpBody + Default + Unpin + Send + 'static, B::Error: Into, @@ -58,7 +58,7 @@ where { type Rejection = MultipartRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let stream = BodyStream::from_request(req).await?; let headers = req.headers(); let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?; diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index e7d7c4e437..ca9e9fb605 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -163,7 +163,7 @@ impl DerefMut for Path { } #[async_trait] -impl FromRequest for Path +impl FromRequest for Path where T: DeserializeOwned + Send, B: Send, @@ -171,7 +171,7 @@ where { type Rejection = PathRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let params = match req.extensions_mut().get::() { Some(UrlParams::Params(params)) => params, Some(UrlParams::InvalidUtf8InPathParam { key }) => { diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index 1050e6be68..ce1f747cc1 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -49,7 +49,7 @@ use std::ops::Deref; pub struct Query(pub T); #[async_trait] -impl FromRequest for Query +impl FromRequest for Query where T: DeserializeOwned, B: Send, @@ -57,7 +57,7 @@ where { type Rejection = QueryRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let query = req.uri().query().unwrap_or_default(); let value = serde_urlencoded::from_str(query) .map_err(FailedToDeserializeQueryString::__private_new)?; diff --git a/axum/src/extract/raw_query.rs b/axum/src/extract/raw_query.rs index eeda0f44be..faf8df6e4c 100644 --- a/axum/src/extract/raw_query.rs +++ b/axum/src/extract/raw_query.rs @@ -27,14 +27,14 @@ use std::convert::Infallible; pub struct RawQuery(pub Option); #[async_trait] -impl FromRequest for RawQuery +impl FromRequest for RawQuery where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let query = req.uri().query().map(|query| query.to_owned()); Ok(Self(query)) } diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs index 5ab9ae53f0..c04ff5349b 100644 --- a/axum/src/extract/request_parts.rs +++ b/axum/src/extract/request_parts.rs @@ -86,14 +86,14 @@ pub struct OriginalUri(pub Uri); #[cfg(feature = "original-uri")] #[async_trait] -impl FromRequest for OriginalUri +impl FromRequest for OriginalUri where B: Send, S: Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let uri = Extension::::from_request(req) .await .unwrap_or_else(|_| Extension(OriginalUri(req.uri().clone()))) @@ -141,7 +141,7 @@ impl Stream for BodyStream { } #[async_trait] -impl FromRequest for BodyStream +impl FromRequest for BodyStream where B: HttpBody + Send + 'static, B::Data: Into, @@ -150,7 +150,7 @@ where { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)? .map_data(Into::into) .map_err(|err| Error::new(err.into())); @@ -198,14 +198,14 @@ fn body_stream_traits() { pub struct RawBody(pub B); #[async_trait] -impl FromRequest for RawBody +impl FromRequest for RawBody where B: Send, S: Send, { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; Ok(Self(body)) } diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index 6e9aa224b3..9991ad45c2 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -147,7 +147,7 @@ use std::{ /// struct MyLibraryExtractor; /// /// #[async_trait] -/// impl FromRequest for MyLibraryExtractor +/// impl FromRequest for MyLibraryExtractor /// where /// B: Send, /// // keep `S` generic but require that it can produce a `MyLibraryState` @@ -156,7 +156,7 @@ use std::{ /// { /// type Rejection = Infallible; /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: &mut RequestParts) -> Result { /// // get a `MyLibraryState` from the shared application state /// let state: MyLibraryState = req.state().clone().into(); /// @@ -177,14 +177,14 @@ use std::{ pub struct State(pub S); #[async_trait] -impl FromRequest for State +impl FromRequest for State where B: Send, OuterState: Clone + Into + Send, { type Rejection = Infallible; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let outer_state = req.state().clone(); let inner_state = outer_state.into(); Ok(Self(inner_state)) diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index ff057343ca..952ea13636 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -275,14 +275,14 @@ impl WebSocketUpgrade { } #[async_trait] -impl FromRequest for WebSocketUpgrade +impl FromRequest for WebSocketUpgrade where B: Send, S: Send, { type Rejection = WebSocketUpgradeRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if req.method() != Method::GET { return Err(MethodNotGet.into()); } @@ -321,7 +321,7 @@ where } } -fn header_eq(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { +fn header_eq(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { if let Some(header) = req.headers().get(&key) { header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) } else { @@ -329,7 +329,7 @@ fn header_eq(req: &RequestParts, key: HeaderName, value: &'static st } } -fn header_contains(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { +fn header_contains(req: &RequestParts, key: HeaderName, value: &'static str) -> bool { let header = if let Some(header) = req.headers().get(&key) { header } else { diff --git a/axum/src/form.rs b/axum/src/form.rs index 267cd608c2..8267b8efe5 100644 --- a/axum/src/form.rs +++ b/axum/src/form.rs @@ -56,7 +56,7 @@ use std::ops::Deref; pub struct Form(pub T); #[async_trait] -impl FromRequest for Form +impl FromRequest for Form where T: DeserializeOwned, B: HttpBody + Send, @@ -66,7 +66,7 @@ where { type Rejection = FormRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if req.method() == Method::GET { let query = req.uri().query().unwrap_or_default(); let value = serde_urlencoded::from_str(query) diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 4d5e922066..2a975d6bb3 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -174,14 +174,14 @@ where macro_rules! impl_handler { ( $($ty:ident),* $(,)? ) => { #[allow(non_snake_case)] - impl Handler<($($ty,)*), S, B> for F + impl Handler<($($ty,)*), S, B> for F where F: FnOnce($($ty,)*) -> Fut + Clone + Send + 'static, Fut: Future + Send, B: Send + 'static, S: Send + 'static, Res: IntoResponse, - $( $ty: FromRequest + Send,)* + $( $ty: FromRequest + Send,)* { type Future = Pin + Send>>; diff --git a/axum/src/json.rs b/axum/src/json.rs index 1a81c3867f..e35a1623ca 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -94,7 +94,7 @@ use std::ops::{Deref, DerefMut}; pub struct Json(pub T); #[async_trait] -impl FromRequest for Json +impl FromRequest for Json where T: DeserializeOwned, B: HttpBody + Send, @@ -104,7 +104,7 @@ where { type Rejection = JsonRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if json_content_type(req) { let bytes = Bytes::from_request(req).await?; @@ -137,7 +137,7 @@ where } } -fn json_content_type(req: &RequestParts) -> bool { +fn json_content_type(req: &RequestParts) -> bool { let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { content_type } else { diff --git a/axum/src/middleware/from_extractor.rs b/axum/src/middleware/from_extractor.rs index 4de38aa6ff..dfa3dfec82 100644 --- a/axum/src/middleware/from_extractor.rs +++ b/axum/src/middleware/from_extractor.rs @@ -45,14 +45,14 @@ use tower_service::Service; /// struct RequireAuth; /// /// #[async_trait] -/// impl FromRequest for RequireAuth +/// impl FromRequest for RequireAuth /// where /// B: Send, /// S: Send, /// { /// type Rejection = StatusCode; /// -/// async fn from_request(req: &mut RequestParts) -> Result { +/// async fn from_request(req: &mut RequestParts) -> Result { /// let auth_header = req /// .headers() /// .get(header::AUTHORIZATION) @@ -169,7 +169,7 @@ where impl Service> for FromExtractor where - E: FromRequest + 'static, + E: FromRequest<(), B> + 'static, B: Default + Send + 'static, S: Service> + Clone, S::Response: IntoResponse, @@ -204,7 +204,7 @@ pin_project! { #[allow(missing_debug_implementations)] pub struct ResponseFuture where - E: FromRequest, + E: FromRequest<(), B>, S: Service>, { #[pin] @@ -217,11 +217,11 @@ pin_project! { #[project = StateProj] enum State where - E: FromRequest, + E: FromRequest<(), B>, S: Service>, { Extracting { - future: BoxFuture<'static, (RequestParts, Result)>, + future: BoxFuture<'static, (RequestParts<(), B>, Result)>, }, Call { #[pin] future: S::Future }, } @@ -229,7 +229,7 @@ pin_project! { impl Future for ResponseFuture where - E: FromRequest, + E: FromRequest<(), B>, S: Service>, S::Response: IntoResponse, B: Default, @@ -280,14 +280,14 @@ mod tests { struct RequireAuth; #[async_trait::async_trait] - impl FromRequest for RequireAuth + impl FromRequest for RequireAuth where B: Send, S: Send, { type Rejection = StatusCode; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if let Some(auth) = req .headers() .get(header::AUTHORIZATION) diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index 4e26a760b9..0d37c61863 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -254,7 +254,7 @@ macro_rules! impl_service { impl Service> for FromFn where F: FnMut($($ty),*, Next) -> Fut + Clone + Send + 'static, - $( $ty: FromRequest + Send, )* + $( $ty: FromRequest<(), B> + Send, )* Fut: Future + Send + 'static, Out: IntoResponse + 'static, S: Service, Error = Infallible> diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 9c2dbd293e..a71dc2fd47 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -567,7 +567,7 @@ impl fmt::Debug for MethodRouter { } } -impl MethodRouter +impl MethodRouter where B: Send + 'static, { diff --git a/axum/src/routing/strip_prefix.rs b/axum/src/routing/strip_prefix.rs index feb5a5c4a3..ec0e232553 100644 --- a/axum/src/routing/strip_prefix.rs +++ b/axum/src/routing/strip_prefix.rs @@ -20,7 +20,7 @@ impl StripPrefix { } } -impl Service> for StripPrefix +impl Service> for StripPrefix where S: Service>, { diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index a931a032a2..f856b6e8bc 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -275,7 +275,7 @@ async fn outer_middleware_still_see_whole_url() { #[derive(Clone)] struct Uri(http::Uri); - impl Service> for SetUriExtension + impl Service> for SetUriExtension where S: Service>, { diff --git a/axum/src/typed_header.rs b/axum/src/typed_header.rs index 88f3bd55a3..c28a24a81a 100644 --- a/axum/src/typed_header.rs +++ b/axum/src/typed_header.rs @@ -52,7 +52,7 @@ use std::{convert::Infallible, ops::Deref}; pub struct TypedHeader(pub T); #[async_trait] -impl FromRequest for TypedHeader +impl FromRequest for TypedHeader where T: headers::Header, B: Send, @@ -60,7 +60,7 @@ where { type Rejection = TypedHeaderRejection; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match req.headers().typed_try_get::() { Ok(Some(value)) => Ok(Self(value)), Ok(None) => Err(TypedHeaderRejection { diff --git a/examples/consume-body-in-extractor-or-middleware/src/main.rs b/examples/consume-body-in-extractor-or-middleware/src/main.rs index 0c82440f9c..442a42b820 100644 --- a/examples/consume-body-in-extractor-or-middleware/src/main.rs +++ b/examples/consume-body-in-extractor-or-middleware/src/main.rs @@ -80,13 +80,13 @@ async fn handler(_: PrintRequestBody, body: Bytes) { struct PrintRequestBody; #[async_trait] -impl FromRequest for PrintRequestBody +impl FromRequest for PrintRequestBody where S: Send + Clone, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let state = req.state().clone(); let request = Request::from_request(req) diff --git a/examples/customize-extractor-error/src/main.rs b/examples/customize-extractor-error/src/main.rs index c435372fa5..20e3b4d482 100644 --- a/examples/customize-extractor-error/src/main.rs +++ b/examples/customize-extractor-error/src/main.rs @@ -56,7 +56,7 @@ struct User { struct Json(T); #[async_trait] -impl FromRequest for Json +impl FromRequest for Json where S: Send, // these trait bounds are copied from `impl FromRequest for axum::Json` @@ -67,7 +67,7 @@ where { type Rejection = (StatusCode, axum::Json); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match axum::Json::::from_request(req).await { Ok(value) => Ok(Self(value.0)), Err(rejection) => { diff --git a/examples/customize-path-rejection/src/main.rs b/examples/customize-path-rejection/src/main.rs index 7807249280..8330b95a93 100644 --- a/examples/customize-path-rejection/src/main.rs +++ b/examples/customize-path-rejection/src/main.rs @@ -52,7 +52,7 @@ struct Params { struct Path(T); #[async_trait] -impl FromRequest for Path +impl FromRequest for Path where // these trait bounds are copied from `impl FromRequest for axum::extract::path::Path` T: DeserializeOwned + Send, @@ -61,7 +61,7 @@ where { type Rejection = (StatusCode, axum::Json); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { match axum::extract::Path::::from_request(req).await { Ok(value) => Ok(Self(value.0)), Err(rejection) => { diff --git a/examples/jwt/src/main.rs b/examples/jwt/src/main.rs index ffed4c36a3..8725581da7 100644 --- a/examples/jwt/src/main.rs +++ b/examples/jwt/src/main.rs @@ -122,14 +122,14 @@ impl AuthBody { } #[async_trait] -impl FromRequest for Claims +impl FromRequest for Claims where S: Send, B: Send, { type Rejection = AuthError; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { // Extract the token from the authorization header let TypedHeader(Authorization(bearer)) = TypedHeader::>::from_request(req) diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index ba87b1cc95..303ca1647f 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -223,14 +223,14 @@ impl IntoResponse for AuthRedirect { } #[async_trait] -impl FromRequest for User +impl FromRequest for User where B: Send, { // If anything goes wrong or no session is found, redirect to the auth page type Rejection = AuthRedirect; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let store = req.state().clone().store; let cookies = TypedHeader::::from_request(req) diff --git a/examples/sessions/src/main.rs b/examples/sessions/src/main.rs index 50c1feb76f..cd0d41a1f6 100644 --- a/examples/sessions/src/main.rs +++ b/examples/sessions/src/main.rs @@ -80,13 +80,13 @@ enum UserIdFromSession { } #[async_trait] -impl FromRequest for UserIdFromSession +impl FromRequest for UserIdFromSession where B: Send, { type Rejection = (StatusCode, &'static str); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let store = req.state().clone(); let cookie = req.extract::>>().await.unwrap(); diff --git a/examples/sqlx-postgres/src/main.rs b/examples/sqlx-postgres/src/main.rs index c76444079d..6548cdeb97 100644 --- a/examples/sqlx-postgres/src/main.rs +++ b/examples/sqlx-postgres/src/main.rs @@ -75,13 +75,13 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(sqlx::pool::PoolConnection); #[async_trait] -impl FromRequest for DatabaseConnection +impl FromRequest for DatabaseConnection where B: Send, { type Rejection = (StatusCode, String); - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let pool = req.state().clone(); let conn = pool.acquire().await.map_err(internal_error)?; diff --git a/examples/tokio-postgres/src/main.rs b/examples/tokio-postgres/src/main.rs index 9b2310342c..e0c60453e3 100644 --- a/examples/tokio-postgres/src/main.rs +++ b/examples/tokio-postgres/src/main.rs @@ -68,14 +68,14 @@ async fn using_connection_pool_extractor( struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager>); #[async_trait] -impl FromRequest for DatabaseConnection +impl FromRequest for DatabaseConnection where B: Send, { type Rejection = (StatusCode, String); async fn from_request( - req: &mut RequestParts, + req: &mut RequestParts, ) -> Result { let pool = req.state().clone(); diff --git a/examples/validator/src/main.rs b/examples/validator/src/main.rs index efc966902d..8682eb85e5 100644 --- a/examples/validator/src/main.rs +++ b/examples/validator/src/main.rs @@ -60,7 +60,7 @@ async fn handler(ValidatedForm(input): ValidatedForm) -> Html pub struct ValidatedForm(pub T); #[async_trait] -impl FromRequest for ValidatedForm +impl FromRequest for ValidatedForm where T: DeserializeOwned + Validate, S: Send, @@ -70,7 +70,7 @@ where { type Rejection = ServerError; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let Form(value) = Form::::from_request(req).await?; value.validate()?; Ok(ValidatedForm(value)) diff --git a/examples/versioning/src/main.rs b/examples/versioning/src/main.rs index 4150a04385..cf8e15f280 100644 --- a/examples/versioning/src/main.rs +++ b/examples/versioning/src/main.rs @@ -48,14 +48,14 @@ enum Version { } #[async_trait] -impl FromRequest for Version +impl FromRequest for Version where B: Send, S: Send, { type Rejection = Response; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let params = Path::>::from_request(req) .await .map_err(IntoResponse::into_response)?; From e211b150f5a0de75650dc01e2e5c17bd6930b13c Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 17 Aug 2022 15:15:44 +0200 Subject: [PATCH 42/45] reverse order of `RequestParts::with_state` args to match type params --- axum-core/src/extract/mod.rs | 4 ++-- axum-extra/src/handler/mod.rs | 2 +- axum-extra/src/handler/or.rs | 2 +- axum/src/handler/mod.rs | 2 +- examples/consume-body-in-extractor-or-middleware/src/main.rs | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index 6425086f67..f9de0f399d 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -94,7 +94,7 @@ impl RequestParts<(), B> { /// /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html pub fn new(req: Request) -> Self { - Self::with_state(req, ()) + Self::with_state((), req) } } @@ -106,7 +106,7 @@ impl RequestParts { /// [`tower::Service`]. /// /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html - pub fn with_state(req: Request, state: S) -> Self { + pub fn with_state(state: S, req: Request) -> Self { let ( http::request::Parts { method, diff --git a/axum-extra/src/handler/mod.rs b/axum-extra/src/handler/mod.rs index 56af0ea215..6842327a01 100644 --- a/axum-extra/src/handler/mod.rs +++ b/axum-extra/src/handler/mod.rs @@ -178,7 +178,7 @@ where fn call(self, state: S, req: http::Request) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::with_state(req, state.clone()); + let mut req = RequestParts::with_state(state.clone(), req); match req.extract::().await { Ok(t) => self.handler.call(state, t).await, Err(rejection) => rejection.into_response(), diff --git a/axum-extra/src/handler/or.rs b/axum-extra/src/handler/or.rs index c2599d582f..6478b35d50 100644 --- a/axum-extra/src/handler/or.rs +++ b/axum-extra/src/handler/or.rs @@ -71,7 +71,7 @@ where fn call(self, state: S, req: Request) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::with_state(req, state.clone()); + let mut req = RequestParts::with_state(state.clone(), req); if let Ok(lt) = req.extract::().await { return self.lhs.call(state, lt).await; diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 2a975d6bb3..d51423468a 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -187,7 +187,7 @@ macro_rules! impl_handler { fn call(self, state: S, req: Request) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::with_state(req, state); + let mut req = RequestParts::with_state(state, req); $( let $ty = match $ty::from_request(&mut req).await { diff --git a/examples/consume-body-in-extractor-or-middleware/src/main.rs b/examples/consume-body-in-extractor-or-middleware/src/main.rs index 442a42b820..be948375d4 100644 --- a/examples/consume-body-in-extractor-or-middleware/src/main.rs +++ b/examples/consume-body-in-extractor-or-middleware/src/main.rs @@ -95,7 +95,7 @@ where let request = buffer_request_body(request).await?; - *req = RequestParts::with_state(request, state); + *req = RequestParts::with_state(state, request); Ok(Self) } From 96531b70fd1f75c5ac4eeb5a34a758fe005d975c Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 17 Aug 2022 16:59:40 +0200 Subject: [PATCH 43/45] Add `FromRef` trait (#1268) * Add `FromRef` trait * Remove unnecessary type params * format * fix docs link * format examples --- axum-core/src/extract/from_ref.rs | 23 ++++++++++++++++ axum-core/src/extract/mod.rs | 3 +++ axum-extra/src/extract/cookie/mod.rs | 14 +++++----- axum-extra/src/extract/cookie/private.rs | 19 +++++++------ axum-extra/src/extract/cookie/signed.rs | 19 +++++++------ axum/src/extract/mod.rs | 2 +- axum/src/extract/state.rs | 27 ++++++++++--------- axum/src/routing/tests/mod.rs | 8 +++--- examples/oauth/src/main.rs | 15 ++++++----- .../src/main.rs | 2 +- .../src/main.rs | 8 +++--- examples/testing/src/main.rs | 2 +- 12 files changed, 84 insertions(+), 58 deletions(-) create mode 100644 axum-core/src/extract/from_ref.rs diff --git a/axum-core/src/extract/from_ref.rs b/axum-core/src/extract/from_ref.rs new file mode 100644 index 0000000000..c0124140e5 --- /dev/null +++ b/axum-core/src/extract/from_ref.rs @@ -0,0 +1,23 @@ +/// Used to do reference-to-value conversions thus not consuming the input value. +/// +/// This is mainly used with [`State`] to extract "substates" from a reference to main application +/// state. +/// +/// See [`State`] for more details on how library authors should use this trait. +/// +/// [`State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html +// NOTE: This trait is defined in axum-core, even though it is mainly used with `State` which is +// defined in axum. That allows crate authors to use it when implementing extractors. +pub trait FromRef { + /// Converts to this type from a reference to the input type. + fn from_ref(input: &T) -> Self; +} + +impl FromRef for T +where + T: Clone, +{ + fn from_ref(input: &T) -> Self { + input.clone() + } +} diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index f9de0f399d..ade7bf0345 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -12,9 +12,12 @@ use std::convert::Infallible; pub mod rejection; +mod from_ref; mod request_parts; mod tuple; +pub use self::from_ref::FromRef; + /// Types that can be created from requests. /// /// See [`axum::extract`] for more details. diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index 38ecbbe54b..44842d039c 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -227,7 +227,7 @@ fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) { #[cfg(test)] mod tests { use super::*; - use axum::{body::Body, http::Request, routing::get, Router}; + use axum::{body::Body, extract::FromRef, http::Request, routing::get, Router}; use tower::ServiceExt; macro_rules! cookie_test { @@ -308,15 +308,15 @@ mod tests { custom_key: CustomKey, } - impl From for Key { - fn from(state: AppState) -> Key { - state.key + impl FromRef for Key { + fn from_ref(state: &AppState) -> Key { + state.key.clone() } } - impl From for CustomKey { - fn from(state: AppState) -> CustomKey { - state.custom_key + impl FromRef for CustomKey { + fn from_ref(state: &AppState) -> CustomKey { + state.custom_key.clone() } } diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index 7540b1764f..d3705fb2be 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -1,7 +1,7 @@ use super::{cookies_from_request, set_cookies, Cookie, Key}; use axum::{ async_trait, - extract::{FromRequest, RequestParts}, + extract::{FromRef, FromRequest, RequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use cookie::PrivateJar; @@ -23,7 +23,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// use axum::{ /// Router, /// routing::{post, get}, -/// extract::TypedHeader, +/// extract::{TypedHeader, FromRef}, /// response::{IntoResponse, Redirect}, /// headers::authorization::{Authorization, Bearer}, /// http::StatusCode, @@ -51,9 +51,9 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// } /// /// // this impl tells `SignedCookieJar` how to access the key from our state -/// impl From for Key { -/// fn from(state: AppState) -> Self { -/// state.key +/// impl FromRef for Key { +/// fn from_ref(state: &AppState) -> Self { +/// state.key.clone() /// } /// } /// @@ -90,15 +90,14 @@ impl fmt::Debug for PrivateCookieJar { impl FromRequest for PrivateCookieJar where B: Send, - S: Into + Clone + Send, - K: Into + Clone + Send + Sync + 'static, + S: Send, + K: FromRef + Into, { type Rejection = Infallible; async fn from_request(req: &mut RequestParts) -> Result { - let state = req.state().clone(); - let key: K = state.into(); - let key: Key = key.into(); + let k = K::from_ref(req.state()); + let key = k.into(); let PrivateCookieJar { jar, key, diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index d56a8eb567..74da2a11ae 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -1,7 +1,7 @@ use super::{cookies_from_request, set_cookies}; use axum::{ async_trait, - extract::{FromRequest, RequestParts}, + extract::{FromRef, FromRequest, RequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use cookie::SignedJar; @@ -24,7 +24,7 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// use axum::{ /// Router, /// routing::{post, get}, -/// extract::TypedHeader, +/// extract::{TypedHeader, FromRef}, /// response::{IntoResponse, Redirect}, /// headers::authorization::{Authorization, Bearer}, /// http::StatusCode, @@ -69,9 +69,9 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// } /// /// // this impl tells `SignedCookieJar` how to access the key from our state -/// impl From for Key { -/// fn from(state: AppState) -> Self { -/// state.key +/// impl FromRef for Key { +/// fn from_ref(state: &AppState) -> Self { +/// state.key.clone() /// } /// } /// @@ -108,15 +108,14 @@ impl fmt::Debug for SignedCookieJar { impl FromRequest for SignedCookieJar where B: Send, - S: Into + Clone + Send, - K: Into + Clone + Send + Sync + 'static, + S: Send, + K: FromRef + Into, { type Rejection = Infallible; async fn from_request(req: &mut RequestParts) -> Result { - let state = req.state().clone(); - let key: K = state.into(); - let key: Key = key.into(); + let k = K::from_ref(req.state()); + let key = k.into(); let SignedCookieJar { jar, key, diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index c4aeb3ad70..081793a83c 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -17,7 +17,7 @@ mod request_parts; mod state; #[doc(inline)] -pub use axum_core::extract::{FromRequest, RequestParts}; +pub use axum_core::extract::{FromRef, FromRequest, RequestParts}; #[doc(inline)] #[allow(deprecated)] diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index 9991ad45c2..94ccf5b1a6 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use axum_core::extract::{FromRequest, RequestParts}; +use axum_core::extract::{FromRef, FromRequest, RequestParts}; use std::{ convert::Infallible, ops::{Deref, DerefMut}, @@ -91,7 +91,7 @@ use std::{ /// [`State`] only allows a single state type but you can use [`From`] to extract "substates": /// /// ``` -/// use axum::{Router, routing::get, extract::State}; +/// use axum::{Router, routing::get, extract::{State, FromRef}}; /// /// // the application state /// #[derive(Clone)] @@ -105,9 +105,9 @@ use std::{ /// struct ApiState {} /// /// // support converting an `AppState` in an `ApiState` -/// impl From for ApiState { -/// fn from(app_state: AppState) -> ApiState { -/// app_state.api_state +/// impl FromRef for ApiState { +/// fn from_ref(app_state: &AppState) -> ApiState { +/// app_state.api_state.clone() /// } /// } /// @@ -139,7 +139,7 @@ use std::{ /// to do it: /// /// ```rust -/// use axum_core::extract::{FromRequest, RequestParts}; +/// use axum_core::extract::{FromRequest, RequestParts, FromRef}; /// use async_trait::async_trait; /// use std::convert::Infallible; /// @@ -151,14 +151,15 @@ use std::{ /// where /// B: Send, /// // keep `S` generic but require that it can produce a `MyLibraryState` -/// // this means users will have to implement `From for MyLibraryState` -/// S: Into + Clone + Send, +/// // this means users will have to implement `FromRef for MyLibraryState` +/// MyLibraryState: FromRef, +/// S: Send, /// { /// type Rejection = Infallible; /// /// async fn from_request(req: &mut RequestParts) -> Result { -/// // get a `MyLibraryState` from the shared application state -/// let state: MyLibraryState = req.state().clone().into(); +/// // get a `MyLibraryState` from a reference to the state +/// let state = MyLibraryState::from_ref(req.state()); /// /// // ... /// # todo!() @@ -180,13 +181,13 @@ pub struct State(pub S); impl FromRequest for State where B: Send, - OuterState: Clone + Into + Send, + InnerState: FromRef, + OuterState: Send, { type Rejection = Infallible; async fn from_request(req: &mut RequestParts) -> Result { - let outer_state = req.state().clone(); - let inner_state = outer_state.into(); + let inner_state = InnerState::from_ref(req.state()); Ok(Self(inner_state)) } } diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index a90ca2110c..fa46cdf86c 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -1,7 +1,7 @@ use crate::{ body::{Bytes, Empty}, error_handling::HandleErrorLayer, - extract::{self, Path, State}, + extract::{self, FromRef, Path, State}, handler::{Handler, HandlerWithoutStateExt}, response::IntoResponse, routing::{delete, get, get_service, on, on_service, patch, patch_service, post, MethodFilter}, @@ -654,9 +654,9 @@ async fn extracting_state() { value: i32, } - impl From for InnerState { - fn from(state: AppState) -> Self { - state.inner + impl FromRef for InnerState { + fn from_ref(state: &AppState) -> Self { + state.inner.clone() } } diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index 303ca1647f..a61113b97d 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -12,7 +12,8 @@ use async_session::{MemoryStore, Session, SessionStore}; use axum::{ async_trait, extract::{ - rejection::TypedHeaderRejectionReason, FromRequest, Query, RequestParts, State, TypedHeader, + rejection::TypedHeaderRejectionReason, FromRef, FromRequest, Query, RequestParts, State, + TypedHeader, }, http::{header::SET_COOKIE, HeaderMap}, response::{IntoResponse, Redirect, Response}, @@ -69,15 +70,15 @@ struct AppState { oauth_client: BasicClient, } -impl From for MemoryStore { - fn from(state: AppState) -> Self { - state.store +impl FromRef for MemoryStore { + fn from_ref(state: &AppState) -> Self { + state.store.clone() } } -impl From for BasicClient { - fn from(state: AppState) -> Self { - state.oauth_client +impl FromRef for BasicClient { + fn from_ref(state: &AppState) -> Self { + state.oauth_client.clone() } } diff --git a/examples/query-params-with-empty-strings/src/main.rs b/examples/query-params-with-empty-strings/src/main.rs index 7e9a08894c..0af20111d7 100644 --- a/examples/query-params-with-empty-strings/src/main.rs +++ b/examples/query-params-with-empty-strings/src/main.rs @@ -16,7 +16,7 @@ async fn main() { .unwrap(); } -fn app() -> Router<()> { +fn app() -> Router { Router::new().route("/", get(handler)) } diff --git a/examples/routes-and-handlers-close-together/src/main.rs b/examples/routes-and-handlers-close-together/src/main.rs index 6fc75c9e41..41aaa49db5 100644 --- a/examples/routes-and-handlers-close-together/src/main.rs +++ b/examples/routes-and-handlers-close-together/src/main.rs @@ -25,7 +25,7 @@ async fn main() { .unwrap(); } -fn root() -> Router<()> { +fn root() -> Router { async fn handler() -> &'static str { "Hello, World!" } @@ -33,7 +33,7 @@ fn root() -> Router<()> { route("/", get(handler)) } -fn get_foo() -> Router<()> { +fn get_foo() -> Router { async fn handler() -> &'static str { "Hi from `GET /foo`" } @@ -41,7 +41,7 @@ fn get_foo() -> Router<()> { route("/foo", get(handler)) } -fn post_foo() -> Router<()> { +fn post_foo() -> Router { async fn handler() -> &'static str { "Hi from `POST /foo`" } @@ -49,6 +49,6 @@ fn post_foo() -> Router<()> { route("/foo", post(handler)) } -fn route(path: &str, method_router: MethodRouter<()>) -> Router<()> { +fn route(path: &str, method_router: MethodRouter<()>) -> Router { Router::new().route(path, method_router) } diff --git a/examples/testing/src/main.rs b/examples/testing/src/main.rs index f31935188a..0bb9b352a0 100644 --- a/examples/testing/src/main.rs +++ b/examples/testing/src/main.rs @@ -34,7 +34,7 @@ async fn main() { /// Having a function that produces our app makes it easy to call it from tests /// without having to create an HTTP server. #[allow(dead_code)] -fn app() -> Router<()> { +fn app() -> Router { Router::new() .route("/", get(|| async { "Hello, World!" })) .route( From e23d81c7e8c1e00b41b0be7f85dee15bae5968ce Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 17 Aug 2022 17:03:45 +0200 Subject: [PATCH 44/45] Avoid unnecessary `MethodRouter` --- axum-extra/src/routing/mod.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index e90578cd9e..0d968b941c 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -1,13 +1,13 @@ //! Additional types for defining routes. use axum::{ - handler::Handler, + handler::{Handler, HandlerWithoutStateExt}, http::Request, response::{IntoResponse, Redirect}, routing::{any, MethodRouter}, Router, }; -use std::{convert::Infallible, future::ready}; +use std::{convert::Infallible, future::ready, sync::Arc}; use tower_service::Service; mod resource; @@ -266,15 +266,15 @@ where { self = self.route(path, method_router); - let redirect = Redirect::permanent(path); + let redirect_service = { + let path: Arc = path.into(); + (move || ready(Redirect::permanent(&path))).into_service() + }; if let Some(path_without_trailing_slash) = path.strip_suffix('/') { - self.route( - path_without_trailing_slash, - any(move || ready(redirect.clone())), - ) + self.route_service(path_without_trailing_slash, redirect_service) } else { - self.route(&format!("{}/", path), any(move || ready(redirect.clone()))) + self.route_service(&format!("{}/", path), redirect_service) } } From 34d24d28e6d5e79f90b54d9e64ff70a3879764d0 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 17 Aug 2022 17:04:52 +0200 Subject: [PATCH 45/45] apply suggestions from review --- axum/src/routing/tests/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index fa46cdf86c..68746a2fa1 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -642,7 +642,7 @@ async fn limited_body_with_streaming_body() { } #[tokio::test] -async fn extracting_state() { +async fn extract_state() { #[derive(Clone)] struct AppState { value: i32, @@ -678,7 +678,7 @@ async fn extracting_state() { } #[tokio::test] -async fn explicitly_setting_state() { +async fn explicitly_set_state() { let app = Router::with_state("...").route_service( "/", get(|State(state): State<&'static str>| async move { state }).with_state("foo"),