From ceed9c1c328b39d453c91be43627b00e343b178f Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Tue, 25 Oct 2022 12:50:31 +0200 Subject: [PATCH] Bonus: removed unnecessary boxing Signed-off-by: slinkydeveloper --- tonic-web/src/lib.rs | 3 -- tonic-web/src/service.rs | 90 ++++++++++++++++++++++++++++++++++------ 2 files changed, 78 insertions(+), 15 deletions(-) diff --git a/tonic-web/src/lib.rs b/tonic-web/src/lib.rs index 2ac3e2fde..394c86dd4 100644 --- a/tonic-web/src/lib.rs +++ b/tonic-web/src/lib.rs @@ -95,8 +95,6 @@ mod layer; mod service; use http::header::HeaderName; -use std::future::Future; -use std::pin::Pin; use std::time::Duration; use tonic::body::BoxBody; use tower_http::cors::{AllowOrigin, Cors, CorsLayer}; @@ -110,7 +108,6 @@ const DEFAULT_ALLOW_HEADERS: [&str; 4] = ["x-grpc-web", "content-type", "x-user-agent", "grpc-timeout"]; type BoxError = Box; -type BoxFuture = Pin> + Send>>; /// Enable a tonic service to handle grpc-web requests with the default configuration. /// diff --git a/tonic-web/src/service.rs b/tonic-web/src/service.rs index ac4ef9a91..78defd80a 100644 --- a/tonic-web/src/service.rs +++ b/tonic-web/src/service.rs @@ -1,7 +1,11 @@ +use futures_core::ready; +use std::future::Future; +use std::pin::Pin; use std::task::{Context, Poll}; use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version}; use hyper::Body; +use pin_project::pin_project; use tonic::body::{empty_body, BoxBody}; use tonic::transport::NamedService; use tower_service::Service; @@ -9,7 +13,7 @@ use tracing::{debug, trace}; use crate::call::content_types::is_grpc_web; use crate::call::{Encoding, GrpcWebCall}; -use crate::{BoxError, BoxFuture}; +use crate::BoxError; const GRPC: &str = "application/grpc"; @@ -47,13 +51,17 @@ impl GrpcWebService where S: Service, Response = Response> + Send + 'static, { - fn response(&self, status: StatusCode) -> BoxFuture { - Box::pin(async move { - Ok(Response::builder() - .status(status) - .body(empty_body()) - .unwrap()) - }) + fn response(&self, status: StatusCode) -> ResponseFuture { + ResponseFuture { + case: Case::ImmediateResponse { + res: Some( + Response::builder() + .status(status) + .body(empty_body()) + .unwrap(), + ), + }, + } } } @@ -65,7 +73,7 @@ where { type Response = S::Response; type Error = S::Error; - type Future = BoxFuture; + type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) @@ -89,8 +97,12 @@ where } => { trace!(kind = "simple", path = ?req.uri().path(), ?encoding, ?accept); - let fut = self.inner.call(coerce_request(req, encoding)); - Box::pin(async move { Ok(coerce_response(fut.await?, accept)) }) + ResponseFuture { + case: Case::GrpcWeb { + future: self.inner.call(coerce_request(req, encoding)), + accept, + }, + } } // The request's content-type matches one of the 4 supported grpc-web @@ -105,7 +117,11 @@ where // whatever they are. RequestKind::Other(Version::HTTP_2) => { debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE)); - Box::pin(self.inner.call(req)) + ResponseFuture { + case: Case::Other { + future: self.inner.call(req), + }, + } } // Return HTTP 400 for all other requests. @@ -117,6 +133,54 @@ where } } +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct ResponseFuture { + #[pin] + case: Case, +} + +#[pin_project(project = CaseProj)] +enum Case { + GrpcWeb { + #[pin] + future: F, + accept: Encoding, + }, + Other { + #[pin] + future: F, + }, + ImmediateResponse { + res: Option>, + }, +} + +impl Future for ResponseFuture +where + F: Future, E>> + Send + 'static, + E: Into + Send, +{ + type Output = Result, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + match this.case.as_mut().project() { + CaseProj::GrpcWeb { future, accept } => { + let res = match ready!(future.poll(cx)) { + Ok(b) => b, + Err(e) => return Poll::Ready(Err(e)), + }; + + Poll::Ready(Ok(coerce_response(res, *accept))) + } + CaseProj::Other { future } => future.poll(cx), + CaseProj::ImmediateResponse { res } => Poll::Ready(Ok(res.take().unwrap())), + } + } +} + impl NamedService for GrpcWebService { const NAME: &'static str = S::NAME; } @@ -177,6 +241,8 @@ mod tests { ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, CONTENT_TYPE, ORIGIN, }; + type BoxFuture = Pin> + Send>>; + #[derive(Debug, Clone)] struct Svc;