diff --git a/examples/src/interceptor/server.rs b/examples/src/interceptor/server.rs index b73a15d9f..a79d98e2e 100644 --- a/examples/src/interceptor/server.rs +++ b/examples/src/interceptor/server.rs @@ -16,6 +16,9 @@ impl Greeter for MyGreeter { &self, request: Request, ) -> Result, Status> { + let extension = request.extensions().get::().unwrap(); + println!("extension data = {}", extension.some_piece_of_data); + let reply = hello_world::HelloReply { message: format!("Hello {}!", request.into_inner().name), }; @@ -40,7 +43,17 @@ async fn main() -> Result<(), Box> { /// This function will get called on each inbound request, if a `Status` /// is returned, it will cancel the request and return that status to the /// client. -fn intercept(req: Request<()>) -> Result, Status> { +fn intercept(mut req: Request<()>) -> Result, Status> { println!("Intercepting request: {:?}", req); + + // Set an extension that can be retrieved by `say_hello` + req.extensions_mut().insert(MyExtension { + some_piece_of_data: "foo".to_string(), + }); + Ok(req) } + +struct MyExtension { + some_piece_of_data: String, +} diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index 2e6836862..1b4acef2a 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -17,6 +17,9 @@ bytes = "1.0" [dev-dependencies] tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] } tokio-stream = { version = "0.1.5", features = ["net"] } +tower-service = "0.3" +hyper = "0.14" +futures = "0.3" [build-dependencies] tonic-build = { path = "../../tonic-build" } diff --git a/tests/integration_tests/tests/extensions.rs b/tests/integration_tests/tests/extensions.rs new file mode 100644 index 000000000..68c5f0375 --- /dev/null +++ b/tests/integration_tests/tests/extensions.rs @@ -0,0 +1,150 @@ +use futures_util::FutureExt; +use hyper::{Body, Request as HyperRequest, Response as HyperResponse}; +use integration_tests::pb::{test_client, test_server, Input, Output}; +use std::{ + task::{Context, Poll}, + time::Duration, +}; +use tokio::sync::oneshot; +use tonic::{ + body::BoxBody, + transport::{Endpoint, NamedService, Server}, + Request, Response, Status, +}; +use tower_service::Service; + +struct ExtensionValue(i32); + +#[tokio::test] +async fn setting_extension_from_interceptor() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, req: Request) -> Result, Status> { + let value = req.extensions().get::().unwrap(); + assert_eq!(value.0, 42); + + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::with_interceptor(Svc, |mut req: Request<()>| { + req.extensions_mut().insert(ExtensionValue(42)); + Ok(req) + }); + + let (tx, rx) = oneshot::channel::<()>(); + + let jh = tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1323".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let channel = Endpoint::from_static("http://127.0.0.1:1323") + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + match client.unary_call(Input {}).await { + Ok(_) => {} + Err(status) => panic!("{}", status.message()), + } + + tx.send(()).unwrap(); + + jh.await.unwrap(); +} + +#[tokio::test] +async fn setting_extension_from_tower() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, req: Request) -> Result, Status> { + let value = req.extensions().get::().unwrap(); + assert_eq!(value.0, 42); + + Ok(Response::new(Output {})) + } + } + + let svc = InterceptedService { + inner: test_server::TestServer::new(Svc), + }; + + let (tx, rx) = oneshot::channel::<()>(); + + let jh = tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1324".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let channel = Endpoint::from_static("http://127.0.0.1:1324") + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + match client.unary_call(Input {}).await { + Ok(_) => {} + Err(status) => panic!("{}", status.message()), + } + + tx.send(()).unwrap(); + + jh.await.unwrap(); +} + +#[derive(Debug, Clone)] +struct InterceptedService { + inner: S, +} + +impl Service> for InterceptedService +where + S: Service, Response = HyperResponse> + + NamedService + + Clone + + Send + + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = futures::future::BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: HyperRequest) -> Self::Future { + let clone = self.inner.clone(); + let mut inner = std::mem::replace(&mut self.inner, clone); + + req.extensions_mut().insert(ExtensionValue(42)); + + Box::pin(async move { + let response = inner.call(req).await?; + Ok(response) + }) + } +} + +impl NamedService for InterceptedService { + const NAME: &'static str = S::NAME; +} diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 2e3c2de85..4142be8d0 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -97,7 +97,8 @@ impl Grpc { M1: Send + Sync + 'static, M2: Send + Sync + 'static, { - let (mut parts, body) = self.streaming(request, path, codec).await?.into_parts(); + let (mut parts, body, extensions) = + self.streaming(request, path, codec).await?.into_parts(); futures_util::pin_mut!(body); @@ -114,7 +115,7 @@ impl Grpc { parts.merge(trailers); } - Ok(Response::from_parts(parts, message)) + Ok(Response::from_parts(parts, message, extensions)) } /// Send a server side streaming gRPC request. diff --git a/tonic/src/extensions.rs b/tonic/src/extensions.rs new file mode 100644 index 000000000..7b897e406 --- /dev/null +++ b/tonic/src/extensions.rs @@ -0,0 +1,67 @@ +use std::fmt; + +/// A type map of protocol extensions. +/// +/// `Extensions` can be used by [`Interceptor`] and [`Request`] to store extra data derived from +/// the underlying protocol. +/// +/// [`Interceptor`]: crate::Interceptor +/// [`Request`]: crate::Request +pub struct Extensions(http::Extensions); + +impl Extensions { + pub(crate) fn new() -> Self { + Self(http::Extensions::new()) + } + + /// Insert a type into this `Extensions`. + /// + /// If a extension of this type already existed, it will + /// be returned. + #[inline] + pub fn insert(&mut self, val: T) -> Option { + self.0.insert(val) + } + + /// Get a reference to a type previously inserted on this `Extensions`. + #[inline] + pub fn get(&self) -> Option<&T> { + self.0.get() + } + + /// Get a mutable reference to a type previously inserted on this `Extensions`. + #[inline] + pub fn get_mut(&mut self) -> Option<&mut T> { + self.0.get_mut() + } + + /// Remove a type from this `Extensions`. + /// + /// If a extension of this type existed, it will be returned. + #[inline] + pub fn remove(&mut self) -> Option { + self.0.remove() + } + + /// Clear the `Extensions` of all inserted extensions. + #[inline] + pub fn clear(&mut self) { + self.0.clear() + } + + #[inline] + pub(crate) fn from_http(http: http::Extensions) -> Self { + Self(http) + } + + #[inline] + pub(crate) fn into_http(self) -> http::Extensions { + self.0 + } +} + +impl fmt::Debug for Extensions { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Extensions").finish() + } +} diff --git a/tonic/src/lib.rs b/tonic/src/lib.rs index 266a45c50..d5eb61ec3 100644 --- a/tonic/src/lib.rs +++ b/tonic/src/lib.rs @@ -62,6 +62,7 @@ //! [`transport`]: transport/index.html #![recursion_limit = "256"] +#![allow(clippy::inconsistent_struct_constructor)] #![warn( missing_debug_implementations, missing_docs, @@ -87,6 +88,7 @@ pub mod server; #[cfg_attr(docsrs, doc(cfg(feature = "transport")))] pub mod transport; +mod extensions; mod interceptor; mod macros; mod request; @@ -100,6 +102,7 @@ pub use async_trait::async_trait; #[doc(inline)] pub use codec::Streaming; +pub use extensions::Extensions; pub use interceptor::Interceptor; pub use request::{IntoRequest, IntoStreamingRequest, Request}; pub use response::Response; diff --git a/tonic/src/request.rs b/tonic/src/request.rs index 7d8f80260..e0a033ef6 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -1,8 +1,8 @@ use crate::metadata::{MetadataMap, MetadataValue}; #[cfg(feature = "transport")] use crate::transport::Certificate; +use crate::Extensions; use futures_core::Stream; -use http::Extensions; #[cfg(feature = "transport")] use std::sync::Arc; use std::{net::SocketAddr, time::Duration}; @@ -116,7 +116,7 @@ impl Request { Request { metadata: MetadataMap::new(), message, - extensions: Extensions::default(), + extensions: Extensions::new(), } } @@ -161,7 +161,7 @@ impl Request { Request { metadata: MetadataMap::from_headers(parts.headers), message, - extensions: parts.extensions, + extensions: Extensions::from_http(parts.extensions), } } @@ -178,7 +178,7 @@ impl Request { *request.method_mut() = http::Method::POST; *request.uri_mut() = uri; *request.headers_mut() = self.metadata.into_sanitized_headers(); - *request.extensions_mut() = self.extensions; + *request.extensions_mut() = self.extensions.into_http(); request } @@ -193,7 +193,7 @@ impl Request { Request { metadata: self.metadata, message, - extensions: Extensions::default(), + extensions: Extensions::new(), } } @@ -254,6 +254,60 @@ impl Request { self.metadata_mut() .insert(crate::metadata::GRPC_TIMEOUT_HEADER, value); } + + /// Returns a reference to the associated extensions. + pub fn extensions(&self) -> &Extensions { + &self.extensions + } + + /// Returns a mutable reference to the associated extensions. + /// + /// # Example + /// + /// Extensions can be set in interceptors: + /// + /// ```no_run + /// use tonic::{Request, Interceptor}; + /// + /// struct MyExtension { + /// some_piece_of_data: String, + /// } + /// + /// Interceptor::new(|mut request: Request<()>| { + /// request.extensions_mut().insert(MyExtension { + /// some_piece_of_data: "foo".to_string(), + /// }); + /// + /// Ok(request) + /// }); + /// ``` + /// + /// And picked up by RPCs: + /// + /// ```no_run + /// use tonic::{async_trait, Status, Request, Response}; + /// # + /// # struct Output {} + /// # struct Input; + /// # struct MyService; + /// # struct MyExtension; + /// # #[async_trait] + /// # trait TestService { + /// # async fn handler(&self, req: Request) -> Result, Status>; + /// # } + /// + /// #[async_trait] + /// impl TestService for MyService { + /// async fn handler(&self, req: Request) -> Result, Status> { + /// let value: &MyExtension = req.extensions().get::().unwrap(); + /// + /// Ok(Response::new(Output {})) + /// } + /// } + /// ``` + pub fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions + } } impl IntoRequest for T { diff --git a/tonic/src/response.rs b/tonic/src/response.rs index ab6eaa66a..87f59b4e4 100644 --- a/tonic/src/response.rs +++ b/tonic/src/response.rs @@ -1,10 +1,11 @@ -use crate::metadata::MetadataMap; +use crate::{metadata::MetadataMap, Extensions}; /// A gRPC response and metadata from an RPC call. #[derive(Debug)] pub struct Response { metadata: MetadataMap, message: T, + extensions: Extensions, } impl Response { @@ -24,6 +25,7 @@ impl Response { Response { metadata: MetadataMap::new(), message, + extensions: Extensions::new(), } } @@ -52,12 +54,16 @@ impl Response { self.message } - pub(crate) fn into_parts(self) -> (MetadataMap, T) { - (self.metadata, self.message) + pub(crate) fn into_parts(self) -> (MetadataMap, T, Extensions) { + (self.metadata, self.message, self.extensions) } - pub(crate) fn from_parts(metadata: MetadataMap, message: T) -> Self { - Self { metadata, message } + pub(crate) fn from_parts(metadata: MetadataMap, message: T, extensions: Extensions) -> Self { + Self { + metadata, + message, + extensions, + } } pub(crate) fn from_http(res: http::Response) -> Self { @@ -65,6 +71,7 @@ impl Response { Response { metadata: MetadataMap::from_headers(head.headers), message, + extensions: Extensions::from_http(head.extensions), } } @@ -73,6 +80,7 @@ impl Response { *res.version_mut() = http::Version::HTTP_2; *res.headers_mut() = self.metadata.into_sanitized_headers(); + *res.extensions_mut() = self.extensions.into_http(); res } @@ -86,8 +94,19 @@ impl Response { Response { metadata: self.metadata, message, + extensions: self.extensions, } } + + /// Returns a reference to the associated extensions. + pub fn extensions(&self) -> &Extensions { + &self.extensions + } + + /// Returns a mutable reference to the associated extensions. + pub fn extensions_mut(&mut self) -> &mut Extensions { + &mut self.extensions + } } #[cfg(test)]