diff --git a/CHANGELOG.md b/CHANGELOG.md index 47702ddf63..03a04574aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Breaking changes +- Add `RoutingDsl::or` for combining routes. ([#108](https://github.com/tokio-rs/axum/pull/108)) - Ensure a `HandleError` service created from `axum::ServiceExt::handle_error` _does not_ implement `RoutingDsl` as that could lead to confusing routing behavior. ([#120](https://github.com/tokio-rs/axum/pull/120)) diff --git a/src/lib.rs b/src/lib.rs index 0717f6c796..23c1a0674d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -125,6 +125,8 @@ //! Routes can also be dynamic like `/users/:id`. See [extractors](#extractors) //! for more details. //! +//! You can also define routes separately and merge them with [`RoutingDsl::or`]. +//! //! ## Precedence //! //! Note that routes are matched _bottom to top_ so routes that should have @@ -662,6 +664,7 @@ //! [`IntoResponse`]: crate::response::IntoResponse //! [`Timeout`]: tower::timeout::Timeout //! [examples]: https://github.com/tokio-rs/axum/tree/main/examples +//! [`RoutingDsl::or`]: crate::routing::RoutingDsl::or //! [`axum::Server`]: hyper::server::Server #![warn( diff --git a/src/routing.rs b/src/routing.rs index 72b6c2e825..38d16fedf2 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -28,6 +28,7 @@ use tower::{ use tower_http::map_response_body::MapResponseBodyLayer; pub mod future; +pub mod or; /// A filter that matches one or more HTTP methods. #[derive(Debug, Copy, Clone)] @@ -354,6 +355,40 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { { IntoMakeServiceWithConnectInfo::new(self) } + + /// Merge two routers into one. + /// + /// This is useful for breaking apps into smaller pieces and combining them + /// into one. + /// + /// ``` + /// use axum::prelude::*; + /// # + /// # async fn users_list() {} + /// # async fn users_show() {} + /// # async fn teams_list() {} + /// + /// // define some routes separately + /// let user_routes = route("/users", get(users_list)) + /// .route("/users/:id", get(users_show)); + /// + /// let team_routes = route("/teams", get(teams_list)); + /// + /// // combine them into one + /// let app = user_routes.or(team_routes); + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); + /// # }; + /// ``` + fn or(self, other: S) -> or::Or + where + S: RoutingDsl, + { + or::Or { + first: self, + second: other, + } + } } impl RoutingDsl for Route {} @@ -448,7 +483,10 @@ impl RoutingDsl for EmptyRouter {} impl crate::sealed::Sealed for EmptyRouter {} -impl Service> for EmptyRouter { +impl Service> for EmptyRouter +where + B: Send + Sync + 'static, +{ type Response = Response; type Error = E; type Future = EmptyRouterFuture; @@ -457,8 +495,9 @@ impl Service> for EmptyRouter { Poll::Ready(Ok(())) } - fn call(&mut self, _req: Request) -> Self::Future { + fn call(&mut self, request: Request) -> Self::Future { let mut res = Response::new(crate::body::empty()); + res.extensions_mut().insert(FromEmptyRouter { request }); *res.status_mut() = self.status; EmptyRouterFuture { future: futures_util::future::ok(res), @@ -466,6 +505,16 @@ impl Service> for EmptyRouter { } } +/// Response extension used by [`EmptyRouter`] to send the request back to [`Or`] so +/// the other service can be called. +/// +/// Without this we would loose ownership of the request when calling the first +/// service in [`Or`]. We also wouldn't be able to identify if the response came +/// from [`EmptyRouter`] and therefore can be discarded in [`Or`]. +struct FromEmptyRouter { + request: Request, +} + #[derive(Debug, Clone)] pub(crate) struct PathPattern(Arc); diff --git a/src/routing/or.rs b/src/routing/or.rs new file mode 100644 index 0000000000..0c884e5a2b --- /dev/null +++ b/src/routing/or.rs @@ -0,0 +1,124 @@ +//! [`Or`] used to combine two services into one. + +use super::{FromEmptyRouter, RoutingDsl}; +use crate::body::BoxBody; +use futures_util::ready; +use http::{Request, Response}; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tower::{util::Oneshot, Service, ServiceExt}; + +/// [`tower::Service`] that is the combination of two routers. +/// +/// See [`RoutingDsl::or`] for more details. +/// +/// [`RoutingDsl::or`]: super::RoutingDsl::or +#[derive(Debug, Clone, Copy)] +pub struct Or { + pub(super) first: A, + pub(super) second: B, +} + +impl RoutingDsl for Or {} + +impl crate::sealed::Sealed for Or {} + +#[allow(warnings)] +impl Service> for Or +where + A: Service, Response = Response> + Clone, + B: Service, Response = Response, Error = A::Error> + Clone, + ReqBody: Send + Sync + 'static, + A: Send + 'static, + B: Send + 'static, + A::Future: Send + 'static, + B::Future: Send + 'static, +{ + type Response = Response; + type Error = A::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + ResponseFuture { + state: State::FirstFuture { + f: self.first.clone().oneshot(req), + }, + second: Some(self.second.clone()), + } + } +} + +pin_project! { + /// Response future for [`Or`]. + pub struct ResponseFuture + where + A: Service>, + B: Service>, + { + #[pin] + state: State, + second: Option, + } +} + +pin_project! { + #[project = StateProj] + enum State + where + A: Service>, + B: Service>, + { + FirstFuture { #[pin] f: Oneshot> }, + SecondFuture { + #[pin] + f: Oneshot>, + } + } +} + +impl Future for ResponseFuture +where + A: Service, Response = Response>, + B: Service, Response = Response, Error = A::Error>, + ReqBody: Send + Sync + 'static, +{ + type Output = Result, A::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + let mut this = self.as_mut().project(); + + let new_state = match this.state.as_mut().project() { + StateProj::FirstFuture { f } => { + let mut response = ready!(f.poll(cx)?); + + let req = if let Some(ext) = response + .extensions_mut() + .remove::>() + { + ext.request + } else { + return Poll::Ready(Ok(response)); + }; + + let second = this.second.take().expect("future polled after completion"); + + State::SecondFuture { + f: second.oneshot(req), + } + } + StateProj::SecondFuture { f } => return f.poll(cx), + }; + + this.state.set(new_state); + } + } +} diff --git a/src/tests.rs b/src/tests.rs index 7f984376b9..07fe3990ef 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,3 +1,5 @@ +#![allow(clippy::blacklisted_name)] + use crate::{ extract::RequestParts, handler::on, prelude::*, routing::nest, routing::MethodFilter, service, }; @@ -18,6 +20,7 @@ use tower::{make::Shared, service_fn, BoxError, Service, ServiceBuilder}; use tower_http::{compression::CompressionLayer, trace::TraceLayer}; mod nest; +mod or; #[tokio::test] async fn hello_world() { diff --git a/src/tests/or.rs b/src/tests/or.rs new file mode 100644 index 0000000000..3c43275146 --- /dev/null +++ b/src/tests/or.rs @@ -0,0 +1,203 @@ +use tower::{limit::ConcurrencyLimitLayer, timeout::TimeoutLayer}; + +use super::*; + +#[tokio::test] +async fn basic() { + let one = route("/foo", get(|| async {})).route("/bar", get(|| async {})); + let two = route("/baz", get(|| async {})); + let app = one.or(two); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .get(format!("http://{}/bar", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .get(format!("http://{}/baz", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .get(format!("http://{}/qux", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); +} + +#[tokio::test] +async fn layer() { + let one = route("/foo", get(|| async {})); + let two = route("/bar", get(|| async {})).layer(ConcurrencyLimitLayer::new(10)); + let app = one.or(two); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .get(format!("http://{}/bar", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn layer_and_handle_error() { + let one = route("/foo", get(|| async {})); + let two = route("/time-out", get(futures::future::pending::<()>)) + .layer(TimeoutLayer::new(Duration::from_millis(10))) + .handle_error(|_| Ok(StatusCode::REQUEST_TIMEOUT)); + let app = one.or(two); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/time-out", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); +} + +#[tokio::test] +async fn nesting() { + let one = route("/foo", get(|| async {})); + let two = nest("/bar", route("/baz", get(|| async {}))); + let app = one.or(two); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/bar/baz", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn boxed() { + let one = route("/foo", get(|| async {})).boxed(); + let two = route("/bar", get(|| async {})).boxed(); + let app = one.or(two); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/bar", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn many_ors() { + let app = route("/r1", get(|| async {})) + .or(route("/r2", get(|| async {}))) + .or(route("/r3", get(|| async {}))) + .or(route("/r4", get(|| async {}))) + .or(route("/r5", get(|| async {}))) + .or(route("/r6", get(|| async {}))) + .or(route("/r7", get(|| async {}))); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + for n in 1..=7 { + let res = client + .get(format!("http://{}/r{}", addr, n)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + } + + let res = client + .get(format!("http://{}/r8", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); +} + +#[tokio::test] +async fn services() { + let app = route( + "/foo", + crate::service::get(service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::empty())) + })), + ) + .or(route( + "/bar", + crate::service::get(service_fn(|_: Request| async { + Ok::<_, Infallible>(Response::new(Body::empty())) + })), + )); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .get(format!("http://{}/bar", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); +} + +// TODO(david): can we make this not compile? +// #[tokio::test] +// async fn foo() { +// let svc_one = service_fn(|_: Request| async { +// Ok::<_, hyper::Error>(Response::new(Body::empty())) +// }) +// .handle_error::<_, _, hyper::Error>(|_| Ok(StatusCode::INTERNAL_SERVER_ERROR)); + +// let svc_two = svc_one.clone(); + +// let app = svc_one.or(svc_two); + +// let addr = run_in_background(app).await; +// }