Skip to content

Commit

Permalink
Add type safe state extractor (#1155)
Browse files Browse the repository at this point in the history
* begin threading the state through

* Pass state to extractors

* make state extractor work

* make sure nesting with different states work

* impl Service for MethodRouter<()>

* Fix some of axum-macro's tests

* Implement more traits for `State`

* Update examples to use `State`

* consistent naming of request body param

* swap type params

* Default the state param to ()

* fix docs references

* Docs and handler state refactoring

* docs clean ups

* more consistent naming

* when does MethodRouter implement Service?

* add missing docs

* use `Router`'s default state type param

* changelog

* don't use default type param for FromRequest and RequestParts

probably safer for library authors so you don't accidentally forget

* fix examples

* minor docs tweaks

* clarify how to convert handlers into services

* group methods in one impl block

* make sure merged `MethodRouter`s can access state

* fix docs link

* test merge with same state type

* Document how to access state from middleware

* Port cookie extractors to use state to extract keys (#1250)

* Updates ECOSYSTEM with a new sample project (#1252)

* Avoid unhelpful compiler suggestion (#1251)

* fix docs typo

* document how library authors should access state

* Add `RequestParts::with_state`

* fix example

* apply suggestions from review

* add relevant changes to axum-extra and axum-core changelogs

* Add `route_service_with_tsr`

* fix trybuild expectations

* make sure `SpaRouter` works with routers that have state

* Change order of type params on FromRequest and RequestParts

* reverse order of `RequestParts::with_state` args to match type params

* Add `FromRef` trait (#1268)

* Add `FromRef` trait

* Remove unnecessary type params

* format

* fix docs link

* format examples

* Avoid unnecessary `MethodRouter`

* apply suggestions from review

Co-authored-by: Dani Pardo <[email protected]>
Co-authored-by: Jonas Platte <[email protected]>
  • Loading branch information
3 people authored Aug 17, 2022
1 parent 90dbd52 commit 423308d
Show file tree
Hide file tree
Showing 132 changed files with 2,404 additions and 1,126 deletions.
5 changes: 4 additions & 1 deletion axum-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- None.
- **breaking:** `FromRequest` and `RequestParts` has a new `S` type param which
represents the state ([#1155])

[#1155]: https://github.com/tokio-rs/axum/pull/1155

# 0.2.6 (18. June, 2022)

Expand Down
23 changes: 23 additions & 0 deletions axum-core/src/extract/from_ref.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/// Used to do reference-to-value conversions thus not consuming the input value.
///
/// This is mainly used with [`State`] to extract "substates" from a reference to main application
/// state.
///
/// See [`State`] for more details on how library authors should use this trait.
///
/// [`State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html
// NOTE: This trait is defined in axum-core, even though it is mainly used with `State` which is
// defined in axum. That allows crate authors to use it when implementing extractors.
pub trait FromRef<T> {
/// Converts to this type from a reference to the input type.
fn from_ref(input: &T) -> Self;
}

impl<T> FromRef<T> for T
where
T: Clone,
{
fn from_ref(input: &T) -> Self {
input.clone()
}
}
69 changes: 52 additions & 17 deletions axum-core/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ use std::convert::Infallible;

pub mod rejection;

mod from_ref;
mod request_parts;
mod tuple;

pub use self::from_ref::FromRef;

/// Types that can be created from requests.
///
/// See [`axum::extract`] for more details.
Expand Down Expand Up @@ -42,13 +45,15 @@ mod tuple;
/// struct MyExtractor;
///
/// #[async_trait]
/// impl<B> FromRequest<B> for MyExtractor
/// impl<S, B> FromRequest<S, B> for MyExtractor
/// where
/// B: Send, // required by `async_trait`
/// // these bounds are required by `async_trait`
/// B: Send,
/// S: Send,
/// {
/// type Rejection = http::StatusCode;
///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
/// // ...
/// # unimplemented!()
/// }
Expand All @@ -60,20 +65,21 @@ mod tuple;
/// [`http::Request<B>`]: http::Request
/// [`axum::extract`]: https://docs.rs/axum/latest/axum/extract/index.html
#[async_trait]
pub trait FromRequest<B>: Sized {
pub trait FromRequest<S, B>: Sized {
/// If the extractor fails it'll use this "rejection" type. A rejection is
/// a kind of error that can be converted into a response.
type Rejection: IntoResponse;

/// Perform the extraction.
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection>;
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection>;
}

/// The type used with [`FromRequest`] to extract data from requests.
///
/// Has several convenience methods for getting owned parts of the request.
#[derive(Debug)]
pub struct RequestParts<B> {
pub struct RequestParts<S, B> {
state: S,
method: Method,
uri: Uri,
version: Version,
Expand All @@ -82,15 +88,28 @@ pub struct RequestParts<B> {
body: Option<B>,
}

impl<B> RequestParts<B> {
/// Create a new `RequestParts`.
impl<B> RequestParts<(), B> {
/// Create a new `RequestParts` without any state.
///
/// You generally shouldn't need to construct this type yourself, unless
/// using extractors outside of axum for example to implement a
/// [`tower::Service`].
///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn new(req: Request<B>) -> Self {
Self::with_state((), req)
}
}

impl<S, B> RequestParts<S, B> {
/// Create a new `RequestParts` with the given state.
///
/// You generally shouldn't need to construct this type yourself, unless
/// using extractors outside of axum for example to implement a
/// [`tower::Service`].
///
/// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html
pub fn with_state(state: S, req: Request<B>) -> Self {
let (
http::request::Parts {
method,
Expand All @@ -104,6 +123,7 @@ impl<B> RequestParts<B> {
) = req.into_parts();

RequestParts {
state,
method,
uri,
version,
Expand All @@ -130,18 +150,25 @@ impl<B> RequestParts<B> {
/// use http::{Method, Uri};
///
/// #[async_trait]
/// impl<B: Send> FromRequest<B> for MyExtractor {
/// impl<S, B> FromRequest<S, B> for MyExtractor
/// where
/// B: Send,
/// S: Send,
/// {
/// type Rejection = Infallible;
///
/// async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Infallible> {
/// async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Infallible> {
/// let method = req.extract::<Method>().await?;
/// let path = req.extract::<Uri>().await?.path().to_owned();
///
/// todo!()
/// }
/// }
/// ```
pub async fn extract<E: FromRequest<B>>(&mut self) -> Result<E, E::Rejection> {
pub async fn extract<E>(&mut self) -> Result<E, E::Rejection>
where
E: FromRequest<S, B>,
{
E::from_request(self).await
}

Expand All @@ -153,6 +180,7 @@ impl<B> RequestParts<B> {
/// [`take_body`]: RequestParts::take_body
pub fn try_into_request(self) -> Result<Request<B>, BodyAlreadyExtracted> {
let Self {
state: _,
method,
uri,
version,
Expand Down Expand Up @@ -245,30 +273,37 @@ impl<B> RequestParts<B> {
pub fn take_body(&mut self) -> Option<B> {
self.body.take()
}

/// Get a reference to the state.
pub fn state(&self) -> &S {
&self.state
}
}

#[async_trait]
impl<T, B> FromRequest<B> for Option<T>
impl<S, T, B> FromRequest<S, B> for Option<T>
where
T: FromRequest<B>,
T: FromRequest<S, B>,
B: Send,
S: Send,
{
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<B>) -> Result<Option<T>, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Option<T>, Self::Rejection> {
Ok(T::from_request(req).await.ok())
}
}

#[async_trait]
impl<T, B> FromRequest<B> for Result<T, T::Rejection>
impl<S, T, B> FromRequest<S, B> for Result<T, T::Rejection>
where
T: FromRequest<B>,
T: FromRequest<S, B>,
B: Send,
S: Send,
{
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(T::from_request(req).await)
}
}
43 changes: 26 additions & 17 deletions axum-core/src/extract/request_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@ use http::{Extensions, HeaderMap, Method, Request, Uri, Version};
use std::convert::Infallible;

#[async_trait]
impl<B> FromRequest<B> for Request<B>
impl<S, B> FromRequest<S, B> for Request<B>
where
B: Send,
S: Clone + Send,
{
type Rejection = BodyAlreadyExtracted;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let req = std::mem::replace(
req,
RequestParts {
state: req.state().clone(),
method: req.method.clone(),
version: req.version,
uri: req.uri.clone(),
Expand All @@ -30,37 +32,40 @@ where
}

#[async_trait]
impl<B> FromRequest<B> for Method
impl<S, B> FromRequest<S, B> for Method
where
B: Send,
S: Send,
{
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.method().clone())
}
}

#[async_trait]
impl<B> FromRequest<B> for Uri
impl<S, B> FromRequest<S, B> for Uri
where
B: Send,
S: Send,
{
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.uri().clone())
}
}

#[async_trait]
impl<B> FromRequest<B> for Version
impl<S, B> FromRequest<S, B> for Version
where
B: Send,
S: Send,
{
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.version())
}
}
Expand All @@ -71,27 +76,29 @@ where
///
/// [`TypedHeader`]: https://docs.rs/axum/latest/axum/extract/struct.TypedHeader.html
#[async_trait]
impl<B> FromRequest<B> for HeaderMap
impl<S, B> FromRequest<S, B> for HeaderMap
where
B: Send,
S: Send,
{
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
Ok(req.headers().clone())
}
}

#[async_trait]
impl<B> FromRequest<B> for Bytes
impl<S, B> FromRequest<S, B> for Bytes
where
B: http_body::Body + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
{
type Rejection = BytesRejection;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?;

let bytes = crate::body::to_bytes(body)
Expand All @@ -103,15 +110,16 @@ where
}

#[async_trait]
impl<B> FromRequest<B> for String
impl<S, B> FromRequest<S, B> for String
where
B: http_body::Body + Send,
B::Data: Send,
B::Error: Into<BoxError>,
S: Send,
{
type Rejection = StringRejection;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let body = take_body(req)?;

let bytes = crate::body::to_bytes(body)
Expand All @@ -126,13 +134,14 @@ where
}

#[async_trait]
impl<B> FromRequest<B> for http::request::Parts
impl<S, B> FromRequest<S, B> for http::request::Parts
where
B: Send,
S: Send,
{
type Rejection = Infallible;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
let method = unwrap_infallible(Method::from_request(req).await);
let uri = unwrap_infallible(Uri::from_request(req).await);
let version = unwrap_infallible(Version::from_request(req).await);
Expand All @@ -159,6 +168,6 @@ fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
}
}

pub(crate) fn take_body<B>(req: &mut RequestParts<B>) -> Result<B, BodyAlreadyExtracted> {
pub(crate) fn take_body<S, B>(req: &mut RequestParts<S, B>) -> Result<B, BodyAlreadyExtracted> {
req.take_body().ok_or(BodyAlreadyExtracted)
}
12 changes: 7 additions & 5 deletions axum-core/src/extract/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ use async_trait::async_trait;
use std::convert::Infallible;

#[async_trait]
impl<B> FromRequest<B> for ()
impl<S, B> FromRequest<S, B> for ()
where
B: Send,
S: Send,
{
type Rejection = Infallible;

async fn from_request(_: &mut RequestParts<B>) -> Result<(), Self::Rejection> {
async fn from_request(_: &mut RequestParts<S, B>) -> Result<(), Self::Rejection> {
Ok(())
}
}
Expand All @@ -21,14 +22,15 @@ macro_rules! impl_from_request {
( $($ty:ident),* $(,)? ) => {
#[async_trait]
#[allow(non_snake_case)]
impl<B, $($ty,)*> FromRequest<B> for ($($ty,)*)
impl<S, B, $($ty,)*> FromRequest<S, B> for ($($ty,)*)
where
$( $ty: FromRequest<B> + Send, )*
$( $ty: FromRequest<S, B> + Send, )*
B: Send,
S: Send,
{
type Rejection = Response;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
async fn from_request(req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
$( let $ty = $ty::from_request(req).await.map_err(|err| err.into_response())?; )*
Ok(($($ty,)*))
}
Expand Down
Loading

0 comments on commit 423308d

Please sign in to comment.