diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 0bb97ba03f..0f66190484 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,8 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +- **fixed:** Fixed performance regression with `Router::nest` introduced in + 0.6.0. `nest` now flattens the routes which performs better ([#1711]) +- **fixed:** Extracting `MatchedPath` in nested handlers now gives the full + matched path, including the nested path ([#1711]) - **added:** Implement `Deref` and `DerefMut` for built-in extractors ([#1922]) +[#1711]: https://github.com/tokio-rs/axum/pull/1711 [#1922]: https://github.com/tokio-rs/axum/pull/1922 # 0.6.12 (22. March, 2023) diff --git a/axum/src/boxed.rs b/axum/src/boxed.rs index 6aaea39a66..f8191f2e26 100644 --- a/axum/src/boxed.rs +++ b/axum/src/boxed.rs @@ -28,25 +28,6 @@ where into_route: |handler, state| Route::new(Handler::with_state(handler, state)), })) } - - pub(crate) fn from_router(router: Router) -> Self - where - B: HttpBody + Send + 'static, - S: Clone + Send + Sync + 'static, - { - Self(Box::new(MakeErasedRouter { - router, - into_route: |router, state| Route::new(router.with_state(state)), - })) - } - - pub(crate) fn call_with_state( - self, - request: Request, - state: S, - ) -> RouteFuture { - self.0.call_with_state(request, state) - } } impl BoxedIntoRoute { diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index e51b38a40c..57a17bee63 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -235,6 +235,26 @@ mod tests { req } + let app = Router::new() + .nest_service("/:a", Router::new().route("/:b", get(|| async move {}))) + .layer(map_request(extract_matched_path)); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[crate::test] + async fn can_extract_nested_matched_path_in_middleware_using_nest() { + async fn extract_matched_path( + matched_path: Option, + req: Request, + ) -> Request { + assert_eq!(matched_path.unwrap().as_str(), "/:a/:b"); + req + } + let app = Router::new() .nest("/:a", Router::new().route("/:b", get(|| async move {}))) .layer(map_request(extract_matched_path)); @@ -253,7 +273,7 @@ mod tests { } let app = Router::new() - .nest("/:a", Router::new().route("/:b", get(|| async move {}))) + .nest_service("/:a", Router::new().route("/:b", get(|| async move {}))) .layer(map_request(assert_no_matched_path)); let client = TestClient::new(app); @@ -262,6 +282,23 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); } + #[tokio::test] + async fn can_extract_nested_matched_path_in_middleware_via_extension_using_nest() { + async fn assert_matched_path(req: Request) -> Request { + assert!(req.extensions().get::().is_some()); + req + } + + let app = Router::new() + .nest("/:a", Router::new().route("/:b", get(|| async move {}))) + .layer(map_request(assert_matched_path)); + + let client = TestClient::new(app); + + let res = client.get("/foo/bar").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + #[crate::test] async fn can_extract_nested_matched_path_in_middleware_on_nested_router() { async fn extract_matched_path(matched_path: MatchedPath, req: Request) -> Request { diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 7322401c45..b7799d7b29 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -1,6 +1,6 @@ //! Routing between [`Service`]s and handlers. -use self::{future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix}; +use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter}; #[cfg(feature = "tokio")] use crate::extract::connect_info::IntoMakeServiceWithConnectInfo; use crate::{ @@ -11,12 +11,9 @@ use crate::{ }; use axum_core::response::{IntoResponse, Response}; use http::Request; -use matchit::MatchError; use std::{ - collections::HashMap, convert::Infallible, fmt, - sync::Arc, task::{Context, Poll}, }; use sync_wrapper::SyncWrapper; @@ -29,6 +26,7 @@ pub mod method_routing; mod into_make_service; mod method_filter; mod not_found; +pub(crate) mod path_router; mod route; mod strip_prefix; pub(crate) mod url_params; @@ -44,25 +42,32 @@ pub use self::method_routing::{ trace_service, MethodRouter, }; +macro_rules! panic_on_err { + ($expr:expr) => { + match $expr { + Ok(x) => x, + Err(err) => panic!("{err}"), + } + }; +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub(crate) struct RouteId(u32); /// The router type for composing handlers and services. #[must_use] pub struct Router { - routes: HashMap>, - node: Arc, - fallback: Fallback, - prev_route_id: RouteId, + path_router: PathRouter, + fallback_router: PathRouter, + default_fallback: bool, } impl Clone for Router { fn clone(&self) -> Self { Self { - routes: self.routes.clone(), - node: Arc::clone(&self.node), - fallback: self.fallback.clone(), - prev_route_id: self.prev_route_id, + path_router: self.path_router.clone(), + fallback_router: self.fallback_router.clone(), + default_fallback: self.default_fallback, } } } @@ -80,16 +85,16 @@ where impl fmt::Debug for Router { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Router") - .field("routes", &self.routes) - .field("node", &self.node) - .field("fallback", &self.fallback) - .field("prev_route_id", &self.prev_route_id) + .field("path_router", &self.path_router) + .field("fallback_router", &self.fallback_router) + .field("default_fallback", &self.default_fallback) .finish() } } pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param"; pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param"; +pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback"; impl Router where @@ -101,57 +106,25 @@ where /// Unless you add additional routes this will respond with `404 Not Found` to /// all requests. pub fn new() -> Self { - Self { - routes: Default::default(), - node: Default::default(), - fallback: Fallback::Default(Route::new(NotFound)), - prev_route_id: RouteId(0), - } + let mut this = Self { + path_router: Default::default(), + fallback_router: Default::default(), + default_fallback: true, + }; + this = this.fallback_service(NotFound); + this.default_fallback = true; + this } #[doc = include_str!("../docs/routing/route.md")] #[track_caller] pub fn route(mut self, path: &str, method_router: MethodRouter) -> Self { - #[track_caller] - fn validate_path(path: &str) { - if path.is_empty() { - panic!("Paths must start with a `/`. Use \"/\" for root routes"); - } else if !path.starts_with('/') { - panic!("Paths must start with a `/`"); - } - } - - validate_path(path); - - let id = self.next_route_id(); - - let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self - .node - .path_to_route_id - .get(path) - .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc))) - { - // if we're adding a new `MethodRouter` to a route that already has one just - // merge them. This makes `.route("/", get(_)).route("/", post(_))` work - let service = Endpoint::MethodRouter( - prev_method_router - .clone() - .merge_for_path(Some(path), method_router), - ); - self.routes.insert(route_id, service); - return self; - } else { - Endpoint::MethodRouter(method_router) - }; - - self.set_node(path, id); - self.routes.insert(id, endpoint); - + panic_on_err!(self.path_router.route(path, method_router)); self } #[doc = include_str!("../docs/routing/route_service.md")] - pub fn route_service(self, path: &str, service: T) -> Self + pub fn route_service(mut self, path: &str, service: T) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, @@ -164,104 +137,40 @@ where Use `Router::nest` instead" ); } - Err(svc) => svc, + Err(service) => service, }; - self.route_endpoint(path, Endpoint::Route(Route::new(service))) + panic_on_err!(self.path_router.route_service(path, service)); + self } + #[doc = include_str!("../docs/routing/nest.md")] #[track_caller] - fn route_endpoint(mut self, path: &str, endpoint: Endpoint) -> Self { - if path.is_empty() { - panic!("Paths must start with a `/`. Use \"/\" for root routes"); - } else if !path.starts_with('/') { - panic!("Paths must start with a `/`"); - } + pub fn nest(mut self, path: &str, router: Router) -> Self { + let Router { + path_router, + fallback_router, + default_fallback, + } = router; - let id = self.next_route_id(); - self.set_node(path, id); - self.routes.insert(id, endpoint); - self - } + panic_on_err!(self.path_router.nest(path, path_router)); - #[track_caller] - fn set_node(&mut self, path: &str, id: RouteId) { - let mut node = - Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone()); - if let Err(err) = node.insert(path, id) { - panic!("Invalid route {path:?}: {err}"); + if !default_fallback { + panic_on_err!(self.fallback_router.nest(path, fallback_router)); } - self.node = Arc::new(node); - } - #[doc = include_str!("../docs/routing/nest.md")] - #[track_caller] - pub fn nest(self, path: &str, router: Router) -> Self { - self.nest_endpoint(path, RouterOrService::<_, _, NotFound>::Router(router)) + self } /// Like [`nest`](Self::nest), but accepts an arbitrary `Service`. #[track_caller] - pub fn nest_service(self, path: &str, svc: T) -> Self + pub fn nest_service(mut self, path: &str, service: T) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { - self.nest_endpoint(path, RouterOrService::Service(svc)) - } - - #[track_caller] - fn nest_endpoint( - mut self, - mut path: &str, - router_or_service: RouterOrService, - ) -> Self - where - T: Service, Error = Infallible> + Clone + Send + 'static, - T::Response: IntoResponse, - T::Future: Send + 'static, - { - if path.is_empty() { - // nesting at `""` and `"/"` should mean the same thing - path = "/"; - } - - if path.contains('*') { - panic!("Invalid route: nested routes cannot contain wildcards (*)"); - } - - let prefix = path; - - let path = if path.ends_with('/') { - format!("{path}*{NEST_TAIL_PARAM}") - } else { - format!("{path}/*{NEST_TAIL_PARAM}") - }; - - let endpoint = match router_or_service { - RouterOrService::Router(router) => { - let prefix = prefix.to_owned(); - let boxed = BoxedIntoRoute::from_router(router) - .map(move |route| Route::new(StripPrefix::new(route, &prefix))); - Endpoint::NestedRouter(boxed) - } - RouterOrService::Service(svc) => { - Endpoint::Route(Route::new(StripPrefix::new(svc, prefix))) - } - }; - - self = self.route_endpoint(&path, endpoint.clone()); - - // `/*rest` is not matched by `/` so we need to also register a router at the - // prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself - // wouldn't match, which it should - self = self.route_endpoint(prefix, endpoint.clone()); - if !prefix.ends_with('/') { - // same goes for `/foo/`, that should also match - self = self.route_endpoint(&format!("{prefix}/"), endpoint); - } - + panic_on_err!(self.path_router.nest_service(path, service)); self } @@ -272,30 +181,32 @@ where R: Into>, { let Router { - routes, - node, - fallback, - prev_route_id: _, + path_router, + fallback_router: other_fallback, + default_fallback, } = other.into(); - for (id, route) in routes { - let path = node - .route_id_to_path - .get(&id) - .expect("no path for route id. This is a bug in axum. Please file an issue"); - self = match route { - Endpoint::MethodRouter(method_router) => self.route(path, method_router), - Endpoint::Route(route) => self.route_service(path, route), - Endpoint::NestedRouter(router) => { - self.route_endpoint(path, Endpoint::NestedRouter(router)) - } - }; - } + panic_on_err!(self.path_router.merge(path_router)); - self.fallback = self - .fallback - .merge(fallback) - .expect("Cannot merge two `Router`s that both have a fallback"); + match (self.default_fallback, default_fallback) { + // both have the default fallback + // use the one from other + (true, true) => { + self.fallback_router = other_fallback; + } + // self has default fallback, other has a custom fallback + (true, false) => { + self.fallback_router = other_fallback; + self.default_fallback = false; + } + // self has a custom fallback, other has a default + // nothing to do + (false, true) => {} + // both have a custom fallback, not allowed + (false, false) => { + panic!("Cannot merge two `Router`s that both have a fallback") + } + }; self } @@ -310,22 +221,10 @@ where >>::Future: Send + 'static, NewReqBody: HttpBody + 'static, { - let routes = self - .routes - .into_iter() - .map(|(id, endpoint)| { - let route = endpoint.layer(layer.clone()); - (id, route) - }) - .collect(); - - let fallback = self.fallback.map(|route| route.layer(layer)); - Router { - routes, - node: self.node, - fallback, - prev_route_id: self.prev_route_id, + path_router: self.path_router.layer(layer.clone()), + fallback_router: self.fallback_router.layer(layer), + default_fallback: self.default_fallback, } } @@ -339,79 +238,50 @@ where >>::Error: Into + 'static, >>::Future: Send + 'static, { - if self.routes.is_empty() { - panic!( - "Adding a route_layer before any routes is a no-op. \ - Add the routes you want the layer to apply to first." - ); - } - - let routes = self - .routes - .into_iter() - .map(|(id, endpoint)| { - let route = endpoint.layer(layer.clone()); - (id, route) - }) - .collect(); - Router { - routes, - node: self.node, - fallback: self.fallback, - prev_route_id: self.prev_route_id, + path_router: self.path_router.route_layer(layer), + fallback_router: self.fallback_router, + default_fallback: self.default_fallback, } } + #[track_caller] #[doc = include_str!("../docs/routing/fallback.md")] - pub fn fallback(mut self, handler: H) -> Self + pub fn fallback(self, handler: H) -> Self where H: Handler, T: 'static, { - self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler)); - self + let endpoint = Endpoint::MethodRouter(any(handler)); + self.fallback_endpoint(endpoint) } /// Add a fallback [`Service`] to the router. /// /// See [`Router::fallback`] for more details. - pub fn fallback_service(mut self, svc: T) -> Self + pub fn fallback_service(self, service: T) -> Self where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { - self.fallback = Fallback::Service(Route::new(svc)); + self.fallback_endpoint(Endpoint::Route(Route::new(service))) + } + + fn fallback_endpoint(mut self, endpoint: Endpoint) -> Self { + self.fallback_router.replace_endpoint("/", endpoint.clone()); + self.fallback_router + .replace_endpoint(&format!("/*{FALLBACK_PARAM}"), endpoint); + self.default_fallback = false; self } #[doc = include_str!("../docs/routing/with_state.md")] pub fn with_state(self, state: S) -> Router { - let routes = self - .routes - .into_iter() - .map(|(id, endpoint)| { - let endpoint: Endpoint = match endpoint { - Endpoint::MethodRouter(method_router) => { - Endpoint::MethodRouter(method_router.with_state(state.clone())) - } - Endpoint::Route(route) => Endpoint::Route(route), - Endpoint::NestedRouter(router) => { - Endpoint::Route(router.into_route(state.clone())) - } - }; - (id, endpoint) - }) - .collect(); - - let fallback = self.fallback.with_state(state); - Router { - routes, - node: self.node, - fallback, - prev_route_id: self.prev_route_id, + path_router: self.path_router.with_state(state.clone()), + fallback_router: self.fallback_router.with_state(state), + default_fallback: self.default_fallback, } } @@ -420,86 +290,43 @@ where mut req: Request, state: S, ) -> RouteFuture { - #[cfg(feature = "original-uri")] - { - use crate::extract::OriginalUri; - - if req.extensions().get::().is_none() { - let original_uri = OriginalUri(req.uri().clone()); - req.extensions_mut().insert(original_uri); - } + // required for opaque routers to still inherit the fallback + // TODO(david): remove this feature in 0.7 + if !self.default_fallback { + req.extensions_mut().insert(SuperFallback(SyncWrapper::new( + self.fallback_router.clone(), + ))); } - let path = req.uri().path().to_owned(); - - match self.node.at(&path) { - Ok(match_) => { - match &self.fallback { - Fallback::Default(_) => {} - Fallback::Service(fallback) => { - req.extensions_mut() - .insert(SuperFallback(SyncWrapper::new(fallback.clone()))); - } - Fallback::BoxedHandler(fallback) => { - req.extensions_mut().insert(SuperFallback(SyncWrapper::new( - fallback.clone().into_route(state.clone()), - ))); - } + match self.path_router.call_with_state(req, state) { + Ok(future) => { + println!("path_router hit"); + future + } + Err((mut req, state)) => { + let super_fallback = req + .extensions_mut() + .remove::>() + .map(|SuperFallback(path_router)| path_router.into_inner()); + + if let Some(mut super_fallback) = super_fallback { + return super_fallback + .call_with_state(req, state) + .unwrap_or_else(|_| unreachable!()); } - let id = *match_.value; - - #[cfg(feature = "matched-path")] - crate::extract::matched_path::set_matched_path_for_request( - id, - &self.node.route_id_to_path, - req.extensions_mut(), - ); - - url_params::insert_url_params(req.extensions_mut(), match_.params); - - let endpont = self - .routes - .get_mut(&id) - .expect("no route for id. This is a bug in axum. Please file an issue"); - - match endpont { - Endpoint::MethodRouter(method_router) => { - method_router.call_with_state(req, state) + match self.fallback_router.call_with_state(req, state) { + Ok(future) => future, + Err((_req, _state)) => { + unreachable!( + "the default fallback added in `Router::new` \ + matches everything" + ) } - Endpoint::Route(route) => route.call(req), - Endpoint::NestedRouter(router) => router.clone().call_with_state(req, state), } } - Err( - MatchError::NotFound - | MatchError::ExtraTrailingSlash - | MatchError::MissingTrailingSlash, - ) => match &mut self.fallback { - Fallback::Default(fallback) => { - if let Some(super_fallback) = req.extensions_mut().remove::>() - { - let mut super_fallback = super_fallback.0.into_inner(); - super_fallback.call(req) - } else { - fallback.call(req) - } - } - Fallback::Service(fallback) => fallback.call(req), - Fallback::BoxedHandler(handler) => handler.clone().into_route(state).call(req), - }, } } - - fn next_route_id(&mut self) -> RouteId { - let next_id = self - .prev_route_id - .0 - .checked_add(1) - .expect("Over `u32::MAX` routes created. If you need this, please file an issue."); - self.prev_route_id = RouteId(next_id); - self.prev_route_id - } } impl Router<(), B> @@ -563,47 +390,6 @@ where } } -/// Wrapper around `matchit::Router` that supports merging two `Router`s. -#[derive(Clone, Default)] -struct Node { - inner: matchit::Router, - route_id_to_path: HashMap>, - path_to_route_id: HashMap, RouteId>, -} - -impl Node { - fn insert( - &mut self, - path: impl Into, - val: RouteId, - ) -> Result<(), matchit::InsertError> { - let path = path.into(); - - self.inner.insert(&path, val)?; - - let shared_path: Arc = path.into(); - self.route_id_to_path.insert(val, shared_path.clone()); - self.path_to_route_id.insert(shared_path, val); - - Ok(()) - } - - fn at<'n, 'p>( - &'n self, - path: &'p str, - ) -> Result, MatchError> { - self.inner.at(path) - } -} - -impl fmt::Debug for Node { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Node") - .field("paths", &self.route_id_to_path) - .finish() - } -} - enum Fallback { Default(Route), Service(Route), @@ -671,7 +457,6 @@ impl fmt::Debug for Fallback { enum Endpoint { MethodRouter(MethodRouter), Route(Route), - NestedRouter(BoxedIntoRoute), } impl Endpoint @@ -693,9 +478,6 @@ where Endpoint::MethodRouter(method_router.layer(layer)) } Endpoint::Route(route) => Endpoint::Route(route.layer(layer)), - Endpoint::NestedRouter(router) => { - Endpoint::NestedRouter(router.map(|route| route.layer(layer))) - } } } } @@ -705,7 +487,6 @@ impl Clone for Endpoint { match self { Self::MethodRouter(inner) => Self::MethodRouter(inner.clone()), Self::Route(inner) => Self::Route(inner.clone()), - Self::NestedRouter(router) => Self::NestedRouter(router.clone()), } } } @@ -717,17 +498,11 @@ impl fmt::Debug for Endpoint { f.debug_tuple("MethodRouter").field(method_router).finish() } Self::Route(route) => f.debug_tuple("Route").field(route).finish(), - Self::NestedRouter(router) => f.debug_tuple("NestedRouter").field(router).finish(), } } } -enum RouterOrService { - Router(Router), - Service(T), -} - -struct SuperFallback(SyncWrapper>); +struct SuperFallback(SyncWrapper>); #[test] #[allow(warnings)] diff --git a/axum/src/routing/not_found.rs b/axum/src/routing/not_found.rs index dc3fec46ac..ffd185bfa1 100644 --- a/axum/src/routing/not_found.rs +++ b/axum/src/routing/not_found.rs @@ -29,6 +29,7 @@ where } fn call(&mut self, _req: Request) -> Self::Future { + println!("NotFound hit"); ready(Ok(StatusCode::NOT_FOUND.into_response())) } } diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs new file mode 100644 index 0000000000..1790b32d88 --- /dev/null +++ b/axum/src/routing/path_router.rs @@ -0,0 +1,445 @@ +use crate::body::{Body, HttpBody}; +use axum_core::response::IntoResponse; +use http::Request; +use matchit::MatchError; +use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc}; +use tower_layer::Layer; +use tower_service::Service; + +use super::{ + future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route, + RouteId, NEST_TAIL_PARAM, +}; + +pub(super) struct PathRouter { + routes: HashMap>, + node: Arc, + prev_route_id: RouteId, +} + +impl PathRouter +where + B: HttpBody + Send + 'static, + S: Clone + Send + Sync + 'static, +{ + pub(super) fn route( + &mut self, + path: &str, + method_router: MethodRouter, + ) -> Result<(), Cow<'static, str>> { + fn validate_path(path: &str) -> Result<(), &'static str> { + if path.is_empty() { + return Err("Paths must start with a `/`. Use \"/\" for root routes"); + } else if !path.starts_with('/') { + return Err("Paths must start with a `/`"); + } + + Ok(()) + } + + validate_path(path)?; + + let id = self.next_route_id(); + + let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self + .node + .path_to_route_id + .get(path) + .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc))) + { + // if we're adding a new `MethodRouter` to a route that already has one just + // merge them. This makes `.route("/", get(_)).route("/", post(_))` work + let service = Endpoint::MethodRouter( + prev_method_router + .clone() + .merge_for_path(Some(path), method_router), + ); + self.routes.insert(route_id, service); + return Ok(()); + } else { + Endpoint::MethodRouter(method_router) + }; + + self.set_node(path, id)?; + self.routes.insert(id, endpoint); + + Ok(()) + } + + pub(super) fn route_service( + &mut self, + path: &str, + service: T, + ) -> Result<(), Cow<'static, str>> + where + T: Service, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, + T::Future: Send + 'static, + { + self.route_endpoint(path, Endpoint::Route(Route::new(service))) + } + + pub(super) fn route_endpoint( + &mut self, + path: &str, + endpoint: Endpoint, + ) -> Result<(), Cow<'static, str>> { + if path.is_empty() { + return Err("Paths must start with a `/`. Use \"/\" for root routes".into()); + } else if !path.starts_with('/') { + return Err("Paths must start with a `/`".into()); + } + + let id = self.next_route_id(); + self.set_node(path, id)?; + self.routes.insert(id, endpoint); + + Ok(()) + } + + fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> { + let mut node = + Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone()); + if let Err(err) = node.insert(path, id) { + return Err(format!("Invalid route {path:?}: {err}")); + } + self.node = Arc::new(node); + Ok(()) + } + + pub(super) fn merge(&mut self, other: PathRouter) -> Result<(), Cow<'static, str>> { + let PathRouter { + routes, + node, + prev_route_id: _, + } = other; + + for (id, route) in routes { + let path = node + .route_id_to_path + .get(&id) + .expect("no path for route id. This is a bug in axum. Please file an issue"); + match route { + Endpoint::MethodRouter(method_router) => self.route(path, method_router)?, + Endpoint::Route(route) => self.route_service(path, route)?, + }; + } + + Ok(()) + } + + pub(super) fn nest( + &mut self, + path: &str, + router: PathRouter, + ) -> Result<(), Cow<'static, str>> { + let prefix = validate_nest_path(path); + + let PathRouter { + routes, + node, + prev_route_id: _, + } = router; + + for (id, endpoint) in routes { + let inner_path = node + .route_id_to_path + .get(&id) + .expect("no path for route id. This is a bug in axum. Please file an issue"); + + let path = path_for_nested_route(prefix, inner_path); + + match endpoint.layer(StripPrefix::layer(prefix)) { + Endpoint::MethodRouter(method_router) => { + self.route(&path, method_router)?; + } + Endpoint::Route(route) => { + self.route_endpoint(&path, Endpoint::Route(route))?; + } + } + } + + Ok(()) + } + + pub(super) fn nest_service(&mut self, path: &str, svc: T) -> Result<(), Cow<'static, str>> + where + T: Service, Error = Infallible> + Clone + Send + 'static, + T::Response: IntoResponse, + T::Future: Send + 'static, + { + let path = validate_nest_path(path); + let prefix = path; + + let path = if path.ends_with('/') { + format!("{path}*{NEST_TAIL_PARAM}") + } else { + format!("{path}/*{NEST_TAIL_PARAM}") + }; + + let endpoint = Endpoint::Route(Route::new(StripPrefix::new(svc, prefix))); + + self.route_endpoint(&path, endpoint.clone())?; + + // `/*rest` is not matched by `/` so we need to also register a router at the + // prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself + // wouldn't match, which it should + self.route_endpoint(prefix, endpoint.clone())?; + if !prefix.ends_with('/') { + // same goes for `/foo/`, that should also match + self.route_endpoint(&format!("{prefix}/"), endpoint)?; + } + + Ok(()) + } + + pub(super) fn layer(self, layer: L) -> PathRouter + where + L: Layer> + Clone + Send + 'static, + L::Service: Service> + Clone + Send + 'static, + >>::Response: IntoResponse + 'static, + >>::Error: Into + 'static, + >>::Future: Send + 'static, + NewReqBody: HttpBody + 'static, + { + let routes = self + .routes + .into_iter() + .map(|(id, endpoint)| { + let route = endpoint.layer(layer.clone()); + (id, route) + }) + .collect(); + + PathRouter { + routes, + node: self.node, + prev_route_id: self.prev_route_id, + } + } + + #[track_caller] + pub(super) fn route_layer(self, layer: L) -> Self + where + L: Layer> + Clone + Send + 'static, + L::Service: Service> + Clone + Send + 'static, + >>::Response: IntoResponse + 'static, + >>::Error: Into + 'static, + >>::Future: Send + 'static, + { + if self.routes.is_empty() { + panic!( + "Adding a route_layer before any routes is a no-op. \ + Add the routes you want the layer to apply to first." + ); + } + + let routes = self + .routes + .into_iter() + .map(|(id, endpoint)| { + let route = endpoint.layer(layer.clone()); + (id, route) + }) + .collect(); + + PathRouter { + routes, + node: self.node, + prev_route_id: self.prev_route_id, + } + } + + pub(super) fn with_state(self, state: S) -> PathRouter { + let routes = self + .routes + .into_iter() + .map(|(id, endpoint)| { + let endpoint: Endpoint = match endpoint { + Endpoint::MethodRouter(method_router) => { + Endpoint::MethodRouter(method_router.with_state(state.clone())) + } + Endpoint::Route(route) => Endpoint::Route(route), + }; + (id, endpoint) + }) + .collect(); + + PathRouter { + routes, + node: self.node, + prev_route_id: self.prev_route_id, + } + } + + pub(super) fn call_with_state( + &mut self, + mut req: Request, + state: S, + ) -> Result, (Request, S)> { + #[cfg(feature = "original-uri")] + { + use crate::extract::OriginalUri; + + if req.extensions().get::().is_none() { + let original_uri = OriginalUri(req.uri().clone()); + req.extensions_mut().insert(original_uri); + } + } + + let path = req.uri().path().to_owned(); + + match self.node.at(&path) { + Ok(match_) => { + let id = *match_.value; + + #[cfg(feature = "matched-path")] + crate::extract::matched_path::set_matched_path_for_request( + id, + &self.node.route_id_to_path, + req.extensions_mut(), + ); + + url_params::insert_url_params(req.extensions_mut(), match_.params); + + let endpont = self + .routes + .get_mut(&id) + .expect("no route for id. This is a bug in axum. Please file an issue"); + + match endpont { + Endpoint::MethodRouter(method_router) => { + Ok(method_router.call_with_state(req, state)) + } + Endpoint::Route(route) => Ok(route.clone().call(req)), + } + } + // explicitly handle all variants in case matchit adds + // new ones we need to handle differently + Err( + MatchError::NotFound + | MatchError::ExtraTrailingSlash + | MatchError::MissingTrailingSlash, + ) => Err((req, state)), + } + } + + pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint) { + match self.node.at(path) { + Ok(match_) => { + let id = *match_.value; + self.routes.insert(id, endpoint); + } + Err(_) => self + .route_endpoint(path, endpoint) + .expect("path wasn't matched so endpoint shouldn't exist"), + } + } + + fn next_route_id(&mut self) -> RouteId { + let next_id = self + .prev_route_id + .0 + .checked_add(1) + .expect("Over `u32::MAX` routes created. If you need this, please file an issue."); + self.prev_route_id = RouteId(next_id); + self.prev_route_id + } +} + +impl Default for PathRouter { + fn default() -> Self { + Self { + routes: Default::default(), + node: Default::default(), + prev_route_id: RouteId(0), + } + } +} + +impl fmt::Debug for PathRouter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PathRouter") + .field("routes", &self.routes) + .field("node", &self.node) + .finish() + } +} + +impl Clone for PathRouter { + fn clone(&self) -> Self { + Self { + routes: self.routes.clone(), + node: self.node.clone(), + prev_route_id: self.prev_route_id, + } + } +} + +/// Wrapper around `matchit::Router` that supports merging two `Router`s. +#[derive(Clone, Default)] +struct Node { + inner: matchit::Router, + route_id_to_path: HashMap>, + path_to_route_id: HashMap, RouteId>, +} + +impl Node { + fn insert( + &mut self, + path: impl Into, + val: RouteId, + ) -> Result<(), matchit::InsertError> { + let path = path.into(); + + self.inner.insert(&path, val)?; + + let shared_path: Arc = path.into(); + self.route_id_to_path.insert(val, shared_path.clone()); + self.path_to_route_id.insert(shared_path, val); + + Ok(()) + } + + fn at<'n, 'p>( + &'n self, + path: &'p str, + ) -> Result, MatchError> { + self.inner.at(path) + } +} + +impl fmt::Debug for Node { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Node") + .field("paths", &self.route_id_to_path) + .finish() + } +} + +#[track_caller] +fn validate_nest_path(path: &str) -> &str { + if path.is_empty() { + // nesting at `""` and `"/"` should mean the same thing + return "/"; + } + + if path.contains('*') { + panic!("Invalid route: nested routes cannot contain wildcards (*)"); + } + + path +} + +pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str> { + debug_assert!(prefix.starts_with('/')); + debug_assert!(path.starts_with('/')); + + if prefix.ends_with('/') { + format!("{prefix}{}", path.trim_start_matches('/')).into() + } else if path == "/" { + prefix.into() + } else { + format!("{prefix}{path}").into() + } +} diff --git a/axum/src/routing/strip_prefix.rs b/axum/src/routing/strip_prefix.rs index 7daef140b7..671c4de773 100644 --- a/axum/src/routing/strip_prefix.rs +++ b/axum/src/routing/strip_prefix.rs @@ -3,6 +3,8 @@ use std::{ sync::Arc, task::{Context, Poll}, }; +use tower::Layer; +use tower_layer::layer_fn; use tower_service::Service; #[derive(Clone)] @@ -18,6 +20,14 @@ impl StripPrefix { prefix: prefix.into(), } } + + pub(super) fn layer(prefix: &str) -> impl Layer + Clone { + let prefix = Arc::from(prefix); + layer_fn(move |inner| Self { + inner, + prefix: Arc::clone(&prefix), + }) + } } impl Service> for StripPrefix diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index 65a6791ace..ac72de5980 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -93,6 +93,10 @@ async fn doesnt_inherit_fallback_if_overriden() { let res = client.get("/foo/bar").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner"); + + let res = client.get("/").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "outer"); } #[crate::test] @@ -203,3 +207,21 @@ async fn fallback_inherited_into_nested_opaque_service() { assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } + +#[crate::test] +async fn nest_fallback_on_inner() { + let app = Router::new() + .nest( + "/foo", + Router::new() + .route("/", get(|| async {})) + .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }), + ) + .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") }); + + let client = TestClient::new(app); + + let res = client.get("/foo/not-found").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + assert_eq!(res.text().await, "inner fallback"); +} diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index bed48415d3..b69dbc8f44 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -4,7 +4,10 @@ use crate::{ extract::{self, DefaultBodyLimit, FromRef, Path, State}, handler::{Handler, HandlerWithoutStateExt}, response::IntoResponse, - routing::{delete, get, get_service, on, on_service, patch, patch_service, post, MethodFilter}, + routing::{ + delete, get, get_service, on, on_service, patch, patch_service, + path_router::path_for_nested_route, post, MethodFilter, + }, test_helpers::*, BoxError, Json, Router, }; @@ -601,7 +604,10 @@ async fn head_with_middleware_applied() { use tower_http::compression::{predicate::SizeAbove, CompressionLayer}; let app = Router::new() - .route("/", get(|| async { "Hello, World!" })) + .nest( + "/", + Router::new().route("/", get(|| async { "Hello, World!" })), + ) .layer(CompressionLayer::new().compress_when(SizeAbove::new(0))); let client = TestClient::new(app); @@ -841,6 +847,21 @@ fn method_router_fallback_with_state() { .with_state(state); } +#[test] +fn test_path_for_nested_route() { + assert_eq!(path_for_nested_route("/", "/"), "/"); + + assert_eq!(path_for_nested_route("/a", "/"), "/a"); + assert_eq!(path_for_nested_route("/", "/b"), "/b"); + assert_eq!(path_for_nested_route("/a/", "/"), "/a/"); + assert_eq!(path_for_nested_route("/", "/b/"), "/b/"); + + assert_eq!(path_for_nested_route("/a", "/b"), "/a/b"); + assert_eq!(path_for_nested_route("/a/", "/b"), "/a/b"); + assert_eq!(path_for_nested_route("/a", "/b/"), "/a/b/"); + assert_eq!(path_for_nested_route("/a/", "/b/"), "/a/b/"); +} + #[crate::test] async fn state_isnt_cloned_too_much() { static SETUP_DONE: AtomicBool = AtomicBool::new(false); diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index e0fb6b6e83..0544f8be59 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -230,6 +230,13 @@ async fn nested_multiple_routes() { #[test] #[should_panic = "Invalid route \"/\": insertion failed due to conflict with previously registered route: /*__private__axum_nest_tail_param"] +fn nested_service_at_root_with_other_routes() { + let _: Router = Router::new() + .nest_service("/", Router::new().route("/users", get(|| async {}))) + .route("/", get(|| async {})); +} + +#[test] fn nested_at_root_with_other_routes() { let _: Router = Router::new() .nest("/", Router::new().route("/users", get(|| async {}))) @@ -343,42 +350,40 @@ async fn nest_with_and_without_trailing() { assert_eq!(res.status(), StatusCode::OK); } -#[crate::test] +#[tokio::test] async fn nesting_with_root_inner_router() { - let app = Router::new().nest( - "/foo", - Router::new().route("/", get(|| async { "inner route" })), - ); + let app = Router::new() + .nest_service("/service", Router::new().route("/", get(|| async {}))) + .nest("/router", Router::new().route("/", get(|| async {}))) + .nest("/router-slash/", Router::new().route("/", get(|| async {}))); let client = TestClient::new(app); - // `/foo/` does match the `/foo` prefix and the remaining path is technically + // `/service/` does match the `/service` prefix and the remaining path is technically // empty, which is the same as `/` which matches `.route("/", _)` - let res = client.get("/foo").send().await; + let res = client.get("/service").send().await; assert_eq!(res.status(), StatusCode::OK); - // `/foo/` does match the `/foo` prefix and the remaining path is `/` + // `/service/` does match the `/service` prefix and the remaining path is `/` // which matches `.route("/", _)` - let res = client.get("/foo/").send().await; + // + // this is perhaps a little surprising but don't think there is much we can do + let res = client.get("/service/").send().await; assert_eq!(res.status(), StatusCode::OK); -} -#[crate::test] -async fn fallback_on_inner() { - let app = Router::new() - .nest( - "/foo", - Router::new() - .route("/", get(|| async {})) - .fallback(|| async { (StatusCode::NOT_FOUND, "inner fallback") }), - ) - .fallback(|| async { (StatusCode::NOT_FOUND, "outer fallback") }); + // at least it does work like you'd expect when using `nest` - let client = TestClient::new(app); + let res = client.get("/router").send().await; + assert_eq!(res.status(), StatusCode::OK); + + let res = client.get("/router/").send().await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); - let res = client.get("/foo/not-found").send().await; + let res = client.get("/router-slash").send().await; assert_eq!(res.status(), StatusCode::NOT_FOUND); - assert_eq!(res.text().await, "inner fallback"); + + let res = client.get("/router-slash/").send().await; + assert_eq!(res.status(), StatusCode::OK); } macro_rules! nested_route_test { diff --git a/axum/src/routing/url_params.rs b/axum/src/routing/url_params.rs index c9f05bb6ec..6243d379c0 100644 --- a/axum/src/routing/url_params.rs +++ b/axum/src/routing/url_params.rs @@ -19,6 +19,7 @@ pub(super) fn insert_url_params(extensions: &mut Extensions, params: Params) { let params = params .iter() .filter(|(key, _)| !key.starts_with(super::NEST_TAIL_PARAM)) + .filter(|(key, _)| !key.starts_with(super::FALLBACK_PARAM)) .map(|(k, v)| { if let Some(decoded) = PercentDecodedStr::new(v) { Ok((Arc::from(k), decoded))