Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't allow extracting MatchedPath in fallbacks #1934

Merged
merged 2 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- None.
- **fixed:** Don't allow extracting `MatchedPath` in fallbacks ([#1934])
- **fixed:** Fix panic if `Router` with something nested at `/` was used as a fallback ([#1934])

[#1934]: https://github.com/tokio-rs/axum/pull/1934

# 0.6.15 (12. April, 2023)

Expand Down
16 changes: 16 additions & 0 deletions axum/src/extract/matched_path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ mod tests {
Router,
};
use http::{Request, StatusCode};
use hyper::Body;

#[crate::test]
async fn extracting_on_handler() {
Expand Down Expand Up @@ -353,4 +354,19 @@ mod tests {
let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}

#[crate::test]
async fn cant_extract_in_fallback() {
async fn handler(path: Option<MatchedPath>, req: Request<Body>) {
assert!(path.is_none());
assert!(req.extensions().get::<MatchedPath>().is_none());
}

let app = Router::new().fallback(handler);

let client = TestClient::new(app);

let res = client.get("/foo/bar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
}
12 changes: 8 additions & 4 deletions axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
//! Routing between [`Service`]s and handlers.

use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter};
use self::{
future::RouteFuture,
not_found::NotFound,
path_router::{IsFallback, IsNotFallback, PathRouter},
};
#[cfg(feature = "tokio")]
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{
Expand Down Expand Up @@ -57,8 +61,8 @@ pub(crate) struct RouteId(u32);
/// The router type for composing handlers and services.
#[must_use]
pub struct Router<S = (), B = Body> {
path_router: PathRouter<S, B>,
fallback_router: PathRouter<S, B>,
path_router: PathRouter<IsNotFallback, S, B>,
fallback_router: PathRouter<IsFallback, S, B>,
default_fallback: bool,
}

Expand Down Expand Up @@ -499,7 +503,7 @@ impl<S, B> fmt::Debug for Endpoint<S, B> {
}
}

struct SuperFallback<S, B>(SyncWrapper<PathRouter<S, B>>);
struct SuperFallback<S, B>(SyncWrapper<PathRouter<IsFallback, S, B>>);

#[test]
#[allow(warnings)]
Expand Down
67 changes: 50 additions & 17 deletions axum/src/routing/path_router.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use crate::body::{Body, HttpBody};
use crate::body::HttpBody;
use axum_core::response::IntoResponse;
use http::Request;
use matchit::MatchError;
use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
use std::{
borrow::Cow, collections::HashMap, convert::Infallible, fmt, marker::PhantomData, sync::Arc,
};
use tower_layer::Layer;
use tower_service::Service;

Expand All @@ -11,14 +13,16 @@ use super::{
RouteId, NEST_TAIL_PARAM,
};

pub(super) struct PathRouter<S = (), B = Body> {
pub(super) struct PathRouter<F, S, B> {
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
routes: HashMap<RouteId, Endpoint<S, B>>,
node: Arc<Node>,
prev_route_id: RouteId,
_is_it_a_fallback: PhantomData<F>,
}

impl<S, B> PathRouter<S, B>
impl<F, S, B> PathRouter<F, S, B>
where
F: IsItAFallback,
B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
Expand Down Expand Up @@ -107,11 +111,12 @@ where
Ok(())
}

pub(super) fn merge(&mut self, other: PathRouter<S, B>) -> Result<(), Cow<'static, str>> {
pub(super) fn merge(&mut self, other: PathRouter<F, S, B>) -> Result<(), Cow<'static, str>> {
let PathRouter {
routes,
node,
prev_route_id: _,
_is_it_a_fallback: _,
} = other;

for (id, route) in routes {
Expand All @@ -131,14 +136,15 @@ where
pub(super) fn nest(
&mut self,
path: &str,
router: PathRouter<S, B>,
router: PathRouter<F, S, B>,
) -> Result<(), Cow<'static, str>> {
let prefix = validate_nest_path(path);

let PathRouter {
routes,
node,
prev_route_id: _,
_is_it_a_fallback: _,
} = router;

for (id, endpoint) in routes {
Expand Down Expand Up @@ -193,7 +199,7 @@ where
Ok(())
}

pub(super) fn layer<L, NewReqBody>(self, layer: L) -> PathRouter<S, NewReqBody>
pub(super) fn layer<L, NewReqBody>(self, layer: L) -> PathRouter<F, S, NewReqBody>
where
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
Expand All @@ -215,6 +221,7 @@ where
routes,
node: self.node,
prev_route_id: self.prev_route_id,
_is_it_a_fallback: self._is_it_a_fallback,
}
}

Expand Down Expand Up @@ -247,10 +254,11 @@ where
routes,
node: self.node,
prev_route_id: self.prev_route_id,
_is_it_a_fallback: self._is_it_a_fallback,
}
}

pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, B> {
pub(super) fn with_state<S2>(self, state: S) -> PathRouter<F, S2, B> {
let routes = self
.routes
.into_iter()
Expand All @@ -269,6 +277,7 @@ where
routes,
node: self.node,
prev_route_id: self.prev_route_id,
_is_it_a_fallback: self._is_it_a_fallback,
}
}

Expand All @@ -293,12 +302,14 @@ where
Ok(match_) => {
let id = *match_.value;

#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
req.extensions_mut(),
);
if !F::FALLBACK {
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
req.extensions_mut(),
);
}

url_params::insert_url_params(req.extensions_mut(), match_.params);

Expand Down Expand Up @@ -347,17 +358,18 @@ where
}
}

impl<B, S> Default for PathRouter<S, B> {
impl<F, B, S> Default for PathRouter<F, S, B> {
fn default() -> Self {
Self {
routes: Default::default(),
node: Default::default(),
prev_route_id: RouteId(0),
_is_it_a_fallback: PhantomData,
}
}
}

impl<S, B> fmt::Debug for PathRouter<S, B> {
impl<F, S, B> fmt::Debug for PathRouter<F, S, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PathRouter")
.field("routes", &self.routes)
Expand All @@ -366,12 +378,13 @@ impl<S, B> fmt::Debug for PathRouter<S, B> {
}
}

impl<S, B> Clone for PathRouter<S, B> {
impl<F, S, B> Clone for PathRouter<F, S, B> {
fn clone(&self) -> Self {
Self {
routes: self.routes.clone(),
node: self.node.clone(),
prev_route_id: self.prev_route_id,
_is_it_a_fallback: self._is_it_a_fallback,
}
}
}
Expand Down Expand Up @@ -443,3 +456,23 @@ pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'
format!("{prefix}{path}").into()
}
}

/// Used to statically enforce that we don't merge/nest a fallback `PathRouter` into a non-fallback
/// `PathRouter`.
pub(super) trait IsItAFallback {
const FALLBACK: bool;
}

#[derive(Copy, Clone)]
pub(super) struct IsFallback;

#[derive(Copy, Clone)]
pub(super) struct IsNotFallback;

impl IsItAFallback for IsFallback {
const FALLBACK: bool = true;
}

impl IsItAFallback for IsNotFallback {
const FALLBACK: bool = false;
}
16 changes: 16 additions & 0 deletions axum/src/routing/tests/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,19 @@ async fn nest_fallback_on_inner() {
assert_eq!(res.status(), StatusCode::NOT_FOUND);
assert_eq!(res.text().await, "inner fallback");
}

// https://github.com/tokio-rs/axum/issues/1931
#[crate::test]
async fn doesnt_panic_if_used_with_nested_router() {
async fn handler() {}

let routes_static =
Router::new().nest_service("/", crate::routing::get_service(handler.into_service()));

let routes_all = Router::new().fallback_service(routes_static);

let client = TestClient::new(routes_all);

let res = client.get("/foobar").send().await;
assert_eq!(res.status(), StatusCode::OK);
}