From b4fb623cc26ea43ab6a617fb3990003404c92726 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 25 Jun 2021 15:24:22 +0200 Subject: [PATCH 01/29] Initial compression support --- tests/integration_tests/tests/compression.rs | 80 +++++++++++++ tonic/Cargo.toml | 3 + tonic/src/client/grpc.rs | 6 +- tonic/src/codec/compression.rs | 118 +++++++++++++++++++ tonic/src/codec/decode.rs | 76 ++++++++---- tonic/src/codec/encode.rs | 40 +++++-- tonic/src/codec/mod.rs | 1 + tonic/src/codec/prost.rs | 3 +- tonic/src/metadata/map.rs | 3 +- tonic/src/server/grpc.rs | 59 ++++++---- tonic/src/transport/channel/endpoint.rs | 17 +++ tonic/src/transport/channel/mod.rs | 37 ++++-- 12 files changed, 378 insertions(+), 65 deletions(-) create mode 100644 tests/integration_tests/tests/compression.rs create mode 100644 tonic/src/codec/compression.rs diff --git a/tests/integration_tests/tests/compression.rs b/tests/integration_tests/tests/compression.rs new file mode 100644 index 000000000..f3a4d26cc --- /dev/null +++ b/tests/integration_tests/tests/compression.rs @@ -0,0 +1,80 @@ +use integration_tests::pb::{test_client, test_server, Input, Output}; +use std::net::SocketAddr; +use tokio::net::TcpListener; +use tonic::{ + transport::{Channel, Server}, + Code, Request, Response, Status, +}; +use tower::Service; + +// TODO(david): client copmressing messages +// TODO(david): client streaming +// TODO(david): server streaming +// TODO(david): bidirectional streaming + +#[tokio::test] +async fn server_compressing_messages() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, _req: Request) -> Result, Status> { + Ok(Response::new(Output {})) + } + } + + #[derive(Clone)] + struct Middleware(S); + + impl Service> for Middleware + where + S: Service>, + { + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.0.poll_ready(cx) + } + + fn call(&mut self, req: http::Request) -> Self::Future { + assert_eq!( + req.headers().get("grpc-accept-encoding").unwrap(), + "gzip,identity" + ); + + self.0.call(req) + } + } + + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + Server::builder() + .layer(tower::layer::layer_fn(Middleware)) + // .gzip() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .gzip() + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + let res = client.unary_call(Request::new(Input {})).await.unwrap(); + + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); +} diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 989b8f8c0..1bf098d13 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -30,6 +30,7 @@ transport = [ "hyper", "tokio", "tower", + "tower-http", "tracing-futures", "tokio/macros", "tokio/time", @@ -59,6 +60,7 @@ tokio-util = { version = "0.6", features = ["codec"] } async-stream = "0.3" http-body = "0.4.2" pin-project = "1.0" +flate2 = "1.0" # prost prost1 = { package = "prost", version = "0.7", optional = true } @@ -73,6 +75,7 @@ hyper = { version = "0.14.2", features = ["full"], optional = true } tokio = { version = "1.0.1", features = ["net"], optional = true } 
tokio-stream = "0.1" tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true } +tower-http = { version = "0.1", features = ["set-header"], optional = true } tracing-futures = { version = "0.2", optional = true } # rustls diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 72a803e3f..56d628ea4 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -1,7 +1,7 @@ use crate::{ body::BoxBody, client::GrpcService, - codec::{encode_client, Codec, Streaming}, + codec::{compression::Encoding, encode_client, Codec, Streaming}, Code, Request, Response, Status, }; use futures_core::Stream; @@ -166,6 +166,8 @@ impl Grpc { .await .map_err(|err| Status::from_error(err.into()))?; + let encoding = Encoding::from_encoding_header(response.headers()); + let status_code = response.status(); let trailers_only_status = Status::from_header_map(response.headers()); @@ -183,7 +185,7 @@ impl Grpc { let response = response.map(|body| { if expect_additional_trailers { - Streaming::new_response(codec.decoder(), body, status_code) + Streaming::new_response(codec.decoder(), body, status_code, encoding) } else { Streaming::new_empty(codec.decoder(), body) } diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs new file mode 100644 index 000000000..2c1fe0b69 --- /dev/null +++ b/tonic/src/codec/compression.rs @@ -0,0 +1,118 @@ +use super::encode::BUFFER_SIZE; +use bytes::{Buf, BufMut, BytesMut}; +use flate2::read::{GzDecoder, GzEncoder}; + +pub(crate) const ENCODING_HEADER: &str = "grpc-encoding"; +pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding"; + +#[derive(Debug, Default, Clone, Copy)] +pub(crate) struct AcceptEncoding { + gzip: bool, +} + +impl AcceptEncoding { + pub(crate) fn gzip(self) -> Self { + AcceptEncoding { gzip: true } + } + + pub(crate) fn into_header_value(self) -> http::HeaderValue { + if self.gzip { + http::HeaderValue::from_static("gzip,identity") + } else { + http::HeaderValue::from_static("identity") + } + } +} + +#[derive(Clone, Copy, Debug)] +pub(crate) enum Encoding { + Gzip, +} + +impl Encoding { + pub(crate) fn from_accept_encoding_header(map: &http::HeaderMap) -> Option { + let header_value = map.get(ACCEPT_ENCODING_HEADER)?; + let header_value_str = header_value.to_str().ok()?; + + header_value_str + .trim() + .split(',') + .map(|value| value.trim()) + .find_map(|value| match value { + "gzip" => Some(Encoding::Gzip), + _ => None, + }) + } + + pub(crate) fn from_encoding_header(map: &http::HeaderMap) -> Option { + let header_value = map.get(ENCODING_HEADER)?; + let header_value_str = header_value.to_str().ok()?; + + match header_value_str { + "gzip" => Some(Encoding::Gzip), + _ => None, + } + } + + pub(crate) fn into_header_value(self) -> http::HeaderValue { + match self { + Encoding::Gzip => http::HeaderValue::from_static("gzip"), + } + } +} + +/// Compress `len` bytes from `in_buffer` into `out_buffer`. +pub(crate) fn compress( + encoding: Encoding, + in_buffer: &mut BytesMut, + out_buffer: &mut BytesMut, + len: usize, +) -> Result<(), std::io::Error> { + let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE; + out_buffer.reserve(capacity); + + // compressor.compress(in_buffer, out_buffer, len)?; + + match encoding { + Encoding::Gzip => { + let mut gzip_decoder = GzEncoder::new( + &in_buffer[0..len], + // TODO(david): what should compression level be? 
+ flate2::Compression::new(6), + ); + let mut out_writer = out_buffer.writer(); + + // TODO(david): use spawn blocking here + std::io::copy(&mut gzip_decoder, &mut out_writer)?; + } + } + + in_buffer.advance(len); + + Ok(()) +} + +pub(crate) fn decompress( + encoding: Encoding, + in_buffer: &mut BytesMut, + out_buffer: &mut BytesMut, + len: usize, +) -> Result<(), std::io::Error> { + let estimate_decompressed_len = len * 2; + let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE; + out_buffer.reserve(capacity); + + match encoding { + Encoding::Gzip => { + let mut gzip_decoder = GzDecoder::new(&in_buffer[0..len]); + let mut out_writer = out_buffer.writer(); + + // TODO(david): use spawn blocking here + std::io::copy(&mut gzip_decoder, &mut out_writer)?; + } + } + + in_buffer.advance(len); + + Ok(()) +} diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index f8f0c23f1..d19110fa6 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -1,4 +1,4 @@ -use super::{DecodeBuf, Decoder}; +use super::{compression::Encoding, DecodeBuf, Decoder}; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use futures_core::Stream; @@ -24,7 +24,9 @@ pub struct Streaming { state: State, direction: Direction, buf: BytesMut, + decompress_buf: BytesMut, trailers: Option, + encoding: Option, } impl Unpin for Streaming {} @@ -43,13 +45,18 @@ enum Direction { } impl Streaming { - pub(crate) fn new_response(decoder: D, body: B, status_code: StatusCode) -> Self + pub(crate) fn new_response( + decoder: D, + body: B, + status_code: StatusCode, + encoding: Option, + ) -> Self where B: Body + Send + Sync + 'static, B::Error: Into, D: Decoder + Send + Sync + 'static, { - Self::new(decoder, body, Direction::Response(status_code)) + Self::new(decoder, body, Direction::Response(status_code), encoding) } pub(crate) fn new_empty(decoder: D, body: B) -> Self @@ -58,7 +65,7 @@ impl Streaming { B::Error: Into, D: Decoder + Send + Sync + 'static, { - Self::new(decoder, body, Direction::EmptyResponse) + Self::new(decoder, body, Direction::EmptyResponse, None) } #[doc(hidden)] @@ -68,10 +75,10 @@ impl Streaming { B::Error: Into, D: Decoder + Send + Sync + 'static, { - Self::new(decoder, body, Direction::Request) + Self::new(decoder, body, Direction::Request, None) } - fn new(decoder: D, body: B, direction: Direction) -> Self + fn new(decoder: D, body: B, direction: Direction, encoding: Option) -> Self where B: Body + Send + Sync + 'static, B::Error: Into, @@ -86,7 +93,9 @@ impl Streaming { state: State::ReadHeader, direction, buf: BytesMut::with_capacity(BUFFER_SIZE), + decompress_buf: BytesMut::new(), trailers: None, + encoding, } } } @@ -162,13 +171,7 @@ impl Streaming { let is_compressed = match self.buf.get_u8() { 0 => false, - 1 => { - trace!("message compressed, compression not supported yet"); - return Err(Status::new( - Code::Unimplemented, - "Message compressed, compression not supported yet.".to_string(), - )); - } + 1 => true, f => { trace!("unexpected compression flag"); let message = if let Direction::Response(status) = self.direction { @@ -191,24 +194,51 @@ impl Streaming { } } - if let State::ReadBody { len, .. 
} = &self.state { + if let State::ReadBody { len, compression } = &self.state { // if we haven't read enough of the message then return and keep // reading if self.buf.remaining() < *len || self.buf.len() < *len { return Ok(None); } - return match self - .decoder - .decode(&mut DecodeBuf::new(&mut self.buf, *len)) - { - Ok(Some(msg)) => { - self.state = State::ReadHeader; - Ok(Some(msg)) + let result = if *compression { + if let Err(err) = super::compression::decompress( + // TODO(david): handle missing self.encoding + self.encoding.unwrap(), + &mut self.buf, + &mut self.decompress_buf, + *len, + ) { + let message = if let Direction::Response(status) = self.direction { + format!( + "Error decompressing: {}, while receiving response with status: {}", + err, status + ) + } else { + format!("Error decompressing: {}, while sending request", err) + }; + return Err(Status::new(Code::Internal, message)); + } + let uncompressed_len = self.decompress_buf.len(); + self.decoder.decode(&mut DecodeBuf::new( + &mut self.decompress_buf, + uncompressed_len, + )) + } else { + match self + .decoder + .decode(&mut DecodeBuf::new(&mut self.buf, *len)) + { + Ok(Some(msg)) => { + self.state = State::ReadHeader; + Ok(Some(msg)) + } + Ok(None) => Ok(None), + Err(e) => Err(e), } - Ok(None) => Ok(None), - Err(e) => Err(e), }; + + return result; } Ok(None) diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 58f4f99da..946274d92 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -1,4 +1,7 @@ -use super::{EncodeBuf, Encoder}; +use super::{ + compression::{compress, Encoding}, + EncodeBuf, Encoder, +}; use crate::{Code, Status}; use bytes::{BufMut, Bytes, BytesMut}; use futures_core::{Stream, TryStream}; @@ -11,18 +14,19 @@ use std::{ task::{Context, Poll}, }; -const BUFFER_SIZE: usize = 8 * 1024; +pub(super) const BUFFER_SIZE: usize = 8 * 1024; pub(crate) fn encode_server( encoder: T, source: U, + encoding: Option, ) -> EncodeBody>> where T: Encoder + Send + Sync + 'static, T::Item: Send + Sync, U: Stream> + Send + Sync + 'static, { - let stream = encode(encoder, source).into_stream(); + let stream = encode(encoder, source, encoding).into_stream(); EncodeBody::new_server(stream) } @@ -35,17 +39,28 @@ where T::Item: Send + Sync, U: Stream + Send + Sync + 'static, { - let stream = encode(encoder, source.map(Ok)).into_stream(); + // TODO(david): get encoding as argument? + let stream = encode(encoder, source.map(Ok), None).into_stream(); EncodeBody::new_client(stream) } -fn encode(mut encoder: T, source: U) -> impl TryStream +fn encode( + mut encoder: T, + source: U, + encoding: Option, +) -> impl TryStream where T: Encoder, U: Stream>, { async_stream::stream! 
{ let mut buf = BytesMut::with_capacity(BUFFER_SIZE); + + let (compression_enabled, mut compression_buf) = match encoding { + Some(Encoding::Gzip) => (true, BytesMut::with_capacity(BUFFER_SIZE)), + None => (false, BytesMut::new()), + }; + futures_util::pin_mut!(source); loop { @@ -55,14 +70,25 @@ where unsafe { buf.advance_mut(5); } - encoder.encode(item, &mut EncodeBuf::new(&mut buf)).map_err(drop).unwrap(); + + if compression_enabled { + compression_buf.clear(); + encoder.encode(item, &mut EncodeBuf::new(&mut compression_buf)).map_err(drop).unwrap(); + let compressed_len = compression_buf.len(); + // TODO(david): handle error + compress(encoding.unwrap(), &mut compression_buf, &mut buf, compressed_len).expect("compression failed"); + } else { + encoder.encode(item, &mut EncodeBuf::new(&mut buf)).map_err(drop).unwrap(); + } // now that we know length, we can write the header let len = buf.len() - 5; assert!(len <= std::u32::MAX as usize); { let mut buf = &mut buf[..5]; - buf.put_u8(0); // byte must be 0, reserve doesn't auto-zero + + buf.put_u8(compression_enabled as u8); + buf.put_u32(len as u32); } diff --git a/tonic/src/codec/mod.rs b/tonic/src/codec/mod.rs index e100556c3..3d520dc9d 100644 --- a/tonic/src/codec/mod.rs +++ b/tonic/src/codec/mod.rs @@ -4,6 +4,7 @@ //! and a protobuf codec based on prost. mod buffer; +pub(crate) mod compression; mod decode; mod encode; #[cfg(feature = "prost")] diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index ddfa93e0e..28e85f4a5 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -119,7 +119,7 @@ mod tests { let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000); let source = futures_util::stream::iter(messages); - let body = encode_server(encoder, source); + let body = encode_server(encoder, source, None); futures_util::pin_mut!(body); @@ -216,6 +216,7 @@ mod tests { } } + #[allow(clippy::drop_ref)] fn poll_trailers( self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/tonic/src/metadata/map.rs b/tonic/src/metadata/map.rs index 461a9fb54..197976d22 100644 --- a/tonic/src/metadata/map.rs +++ b/tonic/src/metadata/map.rs @@ -200,12 +200,11 @@ pub(crate) const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout"; impl MetadataMap { // Headers reserved by the gRPC protocol. 
- pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 7] = [ + pub(crate) const GRPC_RESERVED_HEADERS: [&'static str; 6] = [ "te", "user-agent", "content-type", "grpc-message", - "grpc-encoding", "grpc-message-type", "grpc-status", ]; diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index e640ac8ed..e26a9db8b 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -1,6 +1,6 @@ use crate::{ body::BoxBody, - codec::{encode_server, Codec, Streaming}, + codec::{compression::Encoding, encode_server, Codec, Streaming}, server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService}, Code, Request, Status, }; @@ -43,13 +43,16 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { + let requested_encoding = Encoding::from_accept_encoding_header(req.headers()); + let request = match self.map_request_unary(req).await { Ok(r) => r, Err(status) => { return self - .map_response::>>>(Err( - status, - )); + .map_response::>>>( + Err(status), + requested_encoding, + ); } }; @@ -58,7 +61,7 @@ where .await .map(|r| r.map(|m| stream::once(future::ok(m)))); - self.map_response(response) + self.map_response(response, requested_encoding) } /// Handle a server side streaming request. @@ -73,16 +76,18 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { + // TODO(david): requested_encoding + let request = match self.map_request_unary(req).await { Ok(r) => r, Err(status) => { - return self.map_response::(Err(status)); + return self.map_response::(Err(status), None); } }; let response = service.call(request).await; - self.map_response(response) + self.map_response(response, None) } /// Handle a client side streaming gRPC request. @@ -96,12 +101,14 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send + 'static, { + // TODO(david): requested_encoding + let request = self.map_request_streaming(req); let response = service .call(request) .await .map(|r| r.map(|m| stream::once(future::ok(m)))); - self.map_response(response) + self.map_response(response, None) } /// Handle a bi-directional streaming gRPC request. 
@@ -116,9 +123,10 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { + // TODO(david): requested_encoding let request = self.map_request_streaming(req); let response = service.call(request).await; - self.map_response(response) + self.map_response(response, None) } async fn map_request_unary( @@ -162,26 +170,35 @@ where fn map_response( &mut self, response: Result, Status>, + requested_encoding: Option, ) -> http::Response where B: TryStream + Send + Sync + 'static, { - match response { - Ok(r) => { - let (mut parts, body) = r.into_http().into_parts(); + let response = match response { + Ok(r) => r, + Err(status) => return status.to_http(), + }; - // Set the content type - parts.headers.insert( - http::header::CONTENT_TYPE, - http::header::HeaderValue::from_static("application/grpc"), - ); + let (mut parts, body) = response.into_http().into_parts(); - let body = encode_server(self.codec.encoder(), body.into_stream()); + // Set the content type + parts.headers.insert( + http::header::CONTENT_TYPE, + http::header::HeaderValue::from_static("application/grpc"), + ); - http::Response::from_parts(parts, BoxBody::new(body)) - } - Err(status) => status.to_http(), + if let Some(encoding) = requested_encoding { + // Set the content encoding + parts.headers.insert( + crate::codec::compression::ENCODING_HEADER, + encoding.into_header_value(), + ); } + + let body = encode_server(self.codec.encoder(), body.into_stream(), requested_encoding); + + http::Response::from_parts(parts, BoxBody::new(body)) } } diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 204ab7bc9..6d5752b66 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -2,6 +2,7 @@ use super::super::service; use super::Channel; #[cfg(feature = "tls")] use super::ClientTlsConfig; +use crate::codec::compression::AcceptEncoding; #[cfg(feature = "tls")] use crate::transport::service::TlsConnector; use crate::transport::Error; @@ -39,6 +40,7 @@ pub struct Endpoint { pub(crate) http2_keep_alive_timeout: Option, pub(crate) http2_keep_alive_while_idle: Option, pub(crate) http2_adaptive_window: Option, + pub(crate) accept_encoding: AcceptEncoding, } impl Endpoint { @@ -240,6 +242,20 @@ impl Endpoint { } } + /// Enable `gzip` compression. + /// + /// This will tell the server that `gzip` compression is accepted. Messages compressed will be + /// automatically decompressed. + /// + /// Compression is not enabled by default. + // TODO(david): disabling compression on individual messages + pub fn gzip(self) -> Self { + Endpoint { + accept_encoding: self.accept_encoding.gzip(), + ..self + } + } + /// Create a channel from this config. 
pub async fn connect(&self) -> Result { let mut http = hyper::client::connect::HttpConnector::new(); @@ -329,6 +345,7 @@ impl From for Endpoint { http2_keep_alive_timeout: None, http2_keep_alive_while_idle: None, http2_adaptive_window: None, + accept_encoding: AcceptEncoding::default(), } } } diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index c2ecd8e39..d1f8d945e 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -10,7 +10,7 @@ pub use endpoint::Endpoint; pub use tls::ClientTlsConfig; use super::service::{Connection, DynamicServiceStream}; -use crate::body::BoxBody; +use crate::{body::BoxBody, codec::compression::AcceptEncoding}; use bytes::Bytes; use http::{ uri::{InvalidUri, Uri}, @@ -29,15 +29,16 @@ use tokio::{ sync::mpsc::{channel, Sender}, }; -use tower::balance::p2c::Balance; use tower::{ + balance::p2c::Balance, buffer::{self, Buffer}, discover::{Change, Discover}, - util::{BoxService, Either}, + util::BoxService, Service, }; +use tower_http::set_header::SetRequestHeader; -type Svc = Either, Response, crate::Error>>; +type Svc = BoxService, Response, crate::Error>; const DEFAULT_BUFFER_SIZE: usize = 1024; @@ -137,10 +138,13 @@ impl Channel { C::Future: Unpin + Send, C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, { - let buffer_size = endpoint.buffer_size.clone().unwrap_or(DEFAULT_BUFFER_SIZE); + let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); + let accept_encoding = endpoint.accept_encoding; let svc = Connection::lazy(connector, endpoint); - let svc = Buffer::new(Either::A(svc), buffer_size); + let svc = with_accept_encoding(svc, accept_encoding); + let svc = BoxService::new(svc); + let svc = Buffer::new(svc, buffer_size); Channel { svc } } @@ -152,12 +156,15 @@ impl Channel { C::Future: Unpin + Send, C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, { - let buffer_size = endpoint.buffer_size.clone().unwrap_or(DEFAULT_BUFFER_SIZE); + let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); + let accept_encoding = endpoint.accept_encoding; let svc = Connection::connect(connector, endpoint) .await .map_err(super::Error::from_source)?; - let svc = Buffer::new(Either::A(svc), buffer_size); + let svc = with_accept_encoding(svc, accept_encoding); + let svc = BoxService::new(svc); + let svc = Buffer::new(svc, buffer_size); Ok(Channel { svc }) } @@ -171,12 +178,24 @@ impl Channel { let svc = Balance::new(discover); let svc = BoxService::new(svc); - let svc = Buffer::new(Either::B(svc), buffer_size); + let svc = Buffer::new(svc, buffer_size); Channel { svc } } } +fn with_accept_encoding( + svc: S, + accept_encoding: AcceptEncoding, +) -> SetRequestHeader { + let header_value = accept_encoding.into_header_value(); + SetRequestHeader::overriding( + svc, + http::header::HeaderName::from_static(crate::codec::compression::ACCEPT_ENCODING_HEADER), + header_value, + ) +} + impl Service> for Channel { type Response = http::Response; type Error = super::Error; From 9ffb8d5c9454dc60a9b1c367b23b3e2662073d0c Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 25 Jun 2021 17:54:20 +0200 Subject: [PATCH 02/29] Support configuring compression on `Server` --- examples/src/blocking/client.rs | 2 +- tests/integration_tests/tests/compression.rs | 126 +++++++++++++-- tonic-build/src/client.rs | 8 +- tonic/Cargo.toml | 2 +- tonic/src/codec/compression.rs | 161 ++++++++++++++++--- tonic/src/codec/decode.rs | 7 +- 
tonic/src/codec/encode.rs | 13 +- tonic/src/transport/channel/endpoint.rs | 15 +- tonic/src/transport/channel/mod.rs | 6 +- tonic/src/transport/server/mod.rs | 23 ++- 10 files changed, 311 insertions(+), 52 deletions(-) diff --git a/examples/src/blocking/client.rs b/examples/src/blocking/client.rs index 12788e085..fe83348a9 100644 --- a/examples/src/blocking/client.rs +++ b/examples/src/blocking/client.rs @@ -27,7 +27,7 @@ impl BlockingClient { let rt = Builder::new_multi_thread().enable_all().build().unwrap(); let client = rt.block_on(GreeterClient::connect(dst))?; - Ok(Self { rt, client }) + Ok(Self { client, rt }) } pub fn say_hello( diff --git a/tests/integration_tests/tests/compression.rs b/tests/integration_tests/tests/compression.rs index f3a4d26cc..382c2c08a 100644 --- a/tests/integration_tests/tests/compression.rs +++ b/tests/integration_tests/tests/compression.rs @@ -1,9 +1,8 @@ use integration_tests::pb::{test_client, test_server, Input, Output}; -use std::net::SocketAddr; use tokio::net::TcpListener; use tonic::{ transport::{Channel, Server}, - Code, Request, Response, Status, + Request, Response, Status, }; use tower::Service; @@ -12,7 +11,9 @@ use tower::Service; // TODO(david): server streaming // TODO(david): bidirectional streaming -#[tokio::test] +// TODO(david): document that using a multi threaded tokio runtime is +// required (because of `block_in_place`) +#[tokio::test(flavor = "multi_thread")] async fn server_compressing_messages() { struct Svc; @@ -23,10 +24,10 @@ async fn server_compressing_messages() { } } - #[derive(Clone)] - struct Middleware(S); + #[derive(Clone, Copy)] + struct AssertCorrectAcceptEncoding(S); - impl Service> for Middleware + impl Service> for AssertCorrectAcceptEncoding where S: Service>, { @@ -42,10 +43,8 @@ async fn server_compressing_messages() { } fn call(&mut self, req: http::Request) -> Self::Future { - assert_eq!( - req.headers().get("grpc-accept-encoding").unwrap(), - "gzip,identity" - ); + println!("test middleware called"); + assert_eq!(req.headers().get("grpc-accept-encoding").unwrap(), "gzip"); self.0.call(req) } @@ -58,8 +57,8 @@ async fn server_compressing_messages() { tokio::spawn(async move { Server::builder() - .layer(tower::layer::layer_fn(Middleware)) - // .gzip() + .layer(tower::layer::layer_fn(AssertCorrectAcceptEncoding)) + .send_gzip() .add_service(svc) .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) .await @@ -67,7 +66,7 @@ async fn server_compressing_messages() { }); let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .gzip() + .accept_gzip() .connect() .await .unwrap(); @@ -78,3 +77,104 @@ async fn server_compressing_messages() { assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); } + +#[tokio::test(flavor = "multi_thread")] +async fn client_enabled_server_disabled() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, _req: Request) -> Result, Status> { + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + // no compression enable on the server so responses should not be compressed + Server::builder() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + 
.accept_gzip() + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + let res = client.unary_call(Request::new(Input {})).await.unwrap(); + + assert!(res.metadata().get("grpc-encoding").is_none()); +} + +#[tokio::test(flavor = "multi_thread")] +async fn client_disabled() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, _req: Request) -> Result, Status> { + Ok(Response::new(Output {})) + } + } + + #[derive(Clone, Copy)] + struct AssertCorrectAcceptEncoding(S); + + impl Service> for AssertCorrectAcceptEncoding + where + S: Service>, + { + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.0.poll_ready(cx) + } + + fn call(&mut self, req: http::Request) -> Self::Future { + assert!(req.headers().get("grpc-accept-encoding").is_none()); + + self.0.call(req) + } + } + + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + Server::builder() + .layer(tower::layer::layer_fn(AssertCorrectAcceptEncoding)) + .send_gzip() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + let res = client.unary_call(Request::new(Input {})).await.unwrap(); + + assert!(res.metadata().get("grpc-encoding").is_none()); +} diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index be7d4247b..d20fced9d 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -153,10 +153,10 @@ fn generate_unary( &mut self, request: impl tonic::IntoRequest<#request>, ) -> Result, tonic::Status> { - self.inner.ready().await.map_err(|e| { - tonic::Status::new(tonic::Code::Unknown, format!("Service was not ready: {}", e.into())) - })?; - let codec = #codec_name::default(); + self.inner.ready().await.map_err(|e| { + tonic::Status::new(tonic::Code::Unknown, format!("Service was not ready: {}", e.into())) + })?; + let codec = #codec_name::default(); let path = http::uri::PathAndQuery::from_static(#path); self.inner.unary(request.into_request(), path, codec).await } diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 1bf098d13..97e31c186 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -72,7 +72,7 @@ async-trait = { version = "0.1.13", optional = true } # transport h2 = { version = "0.3", optional = true } hyper = { version = "0.14.2", features = ["full"], optional = true } -tokio = { version = "1.0.1", features = ["net"], optional = true } +tokio = { version = "1.0.1", features = ["net", "rt-multi-thread"], optional = true } tokio-stream = "0.1" tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true } tower-http = { version = "0.1", features = ["set-header"], optional = true } diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 2c1fe0b69..34735ced6 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -1,27 +1,97 @@ use super::encode::BUFFER_SIZE; use bytes::{Buf, BufMut, BytesMut}; use flate2::read::{GzDecoder, GzEncoder}; +use std::fmt::Write; pub(crate) 
const ENCODING_HEADER: &str = "grpc-encoding"; pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding"; +/// Struct used to configure which encodings are enabled on a server or channel. #[derive(Debug, Default, Clone, Copy)] -pub(crate) struct AcceptEncoding { +pub(crate) struct EnabledEncodings { gzip: bool, } -impl AcceptEncoding { +impl EnabledEncodings { pub(crate) fn gzip(self) -> Self { - AcceptEncoding { gzip: true } + Self { gzip: true } } - pub(crate) fn into_header_value(self) -> http::HeaderValue { + pub(crate) fn into_accept_encoding_header_value(self) -> http::HeaderValue { if self.gzip { http::HeaderValue::from_static("gzip,identity") } else { http::HeaderValue::from_static("identity") } } + + /// Find the `grpc-accept-encoding` header and remove the encoding values that aren't enabled. + /// + /// For example a header value like `gzip,brotli,identity` where only `gzip` is enabled will + /// become `gzip`. + /// + /// This is used to remove disabled encodings from incoming requests in the server before they + /// each the actual `server::Grpc` service implementation. It is not possible to configure + /// `server::Grpc` so the configuration must be done at the `Server` level. + pub(crate) fn remove_disabled_encodings_from_accept_encoding(self, map: &mut http::HeaderMap) { + let accept_encoding = if let Some(accept_encoding) = map.remove(ACCEPT_ENCODING_HEADER) { + accept_encoding + } else { + return; + }; + + let accept_encoding_str = if let Ok(accept_encoding) = accept_encoding.to_str() { + accept_encoding + } else { + map.insert( + http::header::HeaderName::from_static(ACCEPT_ENCODING_HEADER), + accept_encoding, + ); + return; + }; + + // first check if we need to make changes to avoid allocating + let contains_disabled_encodings = + split_by_comma(accept_encoding_str).any(|encoding| match encoding { + "gzip" => !self.gzip, + _ => true, + }); + + if !contains_disabled_encodings { + // no changes necessary, put the original value back + map.insert( + http::header::HeaderName::from_static(ACCEPT_ENCODING_HEADER), + accept_encoding, + ); + return; + } + + // can be simplified when `Iterator::intersperse` is stable + let enabled_encodings = + split_by_comma(accept_encoding_str).filter_map(|encoding| match encoding { + "gzip" if self.gzip => Some("gzip"), + _ => None, + }); + + let mut new_value = String::new(); + let mut is_first = true; + + for encoding in enabled_encodings { + if is_first { + let _ = write!(new_value, "{}", encoding); + } else { + let _ = write!(new_value, ",{}", encoding); + }; + is_first = false; + } + + if !new_value.is_empty() { + map.insert( + http::header::HeaderName::from_static(ACCEPT_ENCODING_HEADER), + new_value.parse().unwrap(), + ); + } + } } #[derive(Clone, Copy, Debug)] @@ -34,14 +104,10 @@ impl Encoding { let header_value = map.get(ACCEPT_ENCODING_HEADER)?; let header_value_str = header_value.to_str().ok()?; - header_value_str - .trim() - .split(',') - .map(|value| value.trim()) - .find_map(|value| match value { - "gzip" => Some(Encoding::Gzip), - _ => None, - }) + split_by_comma(header_value_str).find_map(|value| match value { + "gzip" => Some(Encoding::Gzip), + _ => None, + }) } pub(crate) fn from_encoding_header(map: &http::HeaderMap) -> Option { @@ -61,6 +127,10 @@ impl Encoding { } } +fn split_by_comma(s: &str) -> impl Iterator { + s.trim().split(',').map(|s| s.trim()) +} + /// Compress `len` bytes from `in_buffer` into `out_buffer`. 
pub(crate) fn compress( encoding: Encoding, @@ -71,19 +141,16 @@ pub(crate) fn compress( let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE; out_buffer.reserve(capacity); - // compressor.compress(in_buffer, out_buffer, len)?; - match encoding { Encoding::Gzip => { let mut gzip_decoder = GzEncoder::new( &in_buffer[0..len], - // TODO(david): what should compression level be? + // FIXME: support customizing the compression level flate2::Compression::new(6), ); let mut out_writer = out_buffer.writer(); - // TODO(david): use spawn blocking here - std::io::copy(&mut gzip_decoder, &mut out_writer)?; + tokio::task::block_in_place(|| std::io::copy(&mut gzip_decoder, &mut out_writer))?; } } @@ -107,8 +174,7 @@ pub(crate) fn decompress( let mut gzip_decoder = GzDecoder::new(&in_buffer[0..len]); let mut out_writer = out_buffer.writer(); - // TODO(david): use spawn blocking here - std::io::copy(&mut gzip_decoder, &mut out_writer)?; + tokio::task::block_in_place(|| std::io::copy(&mut gzip_decoder, &mut out_writer))?; } } @@ -116,3 +182,60 @@ pub(crate) fn decompress( Ok(()) } + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use http::header::{HeaderMap, HeaderName}; + + #[test] + fn remove_disabled_encodings_empty_map() { + let mut map = HeaderMap::new(); + let encodings = EnabledEncodings { gzip: true }; + encodings.remove_disabled_encodings_from_accept_encoding(&mut map); + assert!(map.is_empty()); + } + + #[test] + fn remove_disabled_encodings_single_supported() { + let mut map = HeaderMap::new(); + map.insert( + HeaderName::from_static(ACCEPT_ENCODING_HEADER), + "gzip".parse().unwrap(), + ); + + let encodings = EnabledEncodings { gzip: true }; + encodings.remove_disabled_encodings_from_accept_encoding(&mut map); + + assert_eq!(&map[ACCEPT_ENCODING_HEADER], "gzip"); + } + + #[test] + fn remove_disabled_encodings_single_unsupported() { + let mut map = HeaderMap::new(); + map.insert( + HeaderName::from_static(ACCEPT_ENCODING_HEADER), + "gzip".parse().unwrap(), + ); + + let encodings = EnabledEncodings { gzip: false }; + encodings.remove_disabled_encodings_from_accept_encoding(&mut map); + + assert!(map.get(ACCEPT_ENCODING_HEADER).is_none()); + } + + #[test] + fn remove_disabled_encodings_multiple_supported() { + let mut map = HeaderMap::new(); + map.insert( + HeaderName::from_static(ACCEPT_ENCODING_HEADER), + "foo,gzip,identity".parse().unwrap(), + ); + + let encodings = EnabledEncodings { gzip: true }; + encodings.remove_disabled_encodings_from_accept_encoding(&mut map); + + assert_eq!(&map[ACCEPT_ENCODING_HEADER], "gzip"); + } +} diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index d19110fa6..831c19a37 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -1,4 +1,7 @@ -use super::{compression::Encoding, DecodeBuf, Decoder}; +use super::{ + compression::{decompress, Encoding}, + DecodeBuf, Decoder, +}; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use futures_core::Stream; @@ -202,7 +205,7 @@ impl Streaming { } let result = if *compression { - if let Err(err) = super::compression::decompress( + if let Err(err) = decompress( // TODO(david): handle missing self.encoding self.encoding.unwrap(), &mut self.buf, diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 946274d92..61bb82c7f 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -75,8 +75,17 @@ where compression_buf.clear(); encoder.encode(item, &mut EncodeBuf::new(&mut 
compression_buf)).map_err(drop).unwrap(); let compressed_len = compression_buf.len(); - // TODO(david): handle error - compress(encoding.unwrap(), &mut compression_buf, &mut buf, compressed_len).expect("compression failed"); + + let compress_result = compress( + encoding.unwrap(), + &mut compression_buf, + &mut buf, + compressed_len, + ); + + if let Err(err) = compress_result { + yield Err(Status::internal(format!("Error compressing: {}", err))) + } } else { encoder.encode(item, &mut EncodeBuf::new(&mut buf)).map_err(drop).unwrap(); } diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 6d5752b66..09bb7693e 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -2,7 +2,7 @@ use super::super::service; use super::Channel; #[cfg(feature = "tls")] use super::ClientTlsConfig; -use crate::codec::compression::AcceptEncoding; +use crate::codec::compression::EnabledEncodings; #[cfg(feature = "tls")] use crate::transport::service::TlsConnector; use crate::transport::Error; @@ -40,7 +40,7 @@ pub struct Endpoint { pub(crate) http2_keep_alive_timeout: Option, pub(crate) http2_keep_alive_while_idle: Option, pub(crate) http2_adaptive_window: Option, - pub(crate) accept_encoding: AcceptEncoding, + pub(crate) accept_encoding: EnabledEncodings, } impl Endpoint { @@ -242,14 +242,17 @@ impl Endpoint { } } - /// Enable `gzip` compression. + /// Enable `gzip` compressed responses. /// - /// This will tell the server that `gzip` compression is accepted. Messages compressed will be + /// This will tell the server that `gzip` compression is accepted. Messages will be /// automatically decompressed. /// + /// This does not compress messages sent by the client. + /// /// Compression is not enabled by default. 
// TODO(david): disabling compression on individual messages - pub fn gzip(self) -> Self { + // TODO(david): sending compressed messages + pub fn accept_gzip(self) -> Self { Endpoint { accept_encoding: self.accept_encoding.gzip(), ..self @@ -345,7 +348,7 @@ impl From for Endpoint { http2_keep_alive_timeout: None, http2_keep_alive_while_idle: None, http2_adaptive_window: None, - accept_encoding: AcceptEncoding::default(), + accept_encoding: EnabledEncodings::default(), } } } diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index d1f8d945e..e660f8bd5 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -10,7 +10,7 @@ pub use endpoint::Endpoint; pub use tls::ClientTlsConfig; use super::service::{Connection, DynamicServiceStream}; -use crate::{body::BoxBody, codec::compression::AcceptEncoding}; +use crate::{body::BoxBody, codec::compression::EnabledEncodings}; use bytes::Bytes; use http::{ uri::{InvalidUri, Uri}, @@ -186,9 +186,9 @@ impl Channel { fn with_accept_encoding( svc: S, - accept_encoding: AcceptEncoding, + accept_encoding: EnabledEncodings, ) -> SetRequestHeader { - let header_value = accept_encoding.into_header_value(); + let header_value = accept_encoding.into_accept_encoding_header_value(); SetRequestHeader::overriding( svc, http::header::HeaderName::from_static(crate::codec::compression::ACCEPT_ENCODING_HEADER), diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 7ccbc589d..24c6affbc 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -27,7 +27,7 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, Or, Routes, ServerIo}; -use crate::body::BoxBody; +use crate::{body::BoxBody, codec::compression::EnabledEncodings}; use bytes::Bytes; use futures_core::Stream; use futures_util::{ @@ -85,6 +85,7 @@ pub struct Server { max_frame_size: Option, accept_http1: bool, layer: L, + encodings: EnabledEncodings, } /// A stack based `Service` router. @@ -321,6 +322,14 @@ impl Server { } } + /// Compress outgoing messages with `gzip` if supported by the client. + pub fn send_gzip(self) -> Self { + Server { + encodings: self.encodings.gzip(), + ..self + } + } + /// Create a router with the `S` typed service as the first service. /// /// This will clone the `Server` builder and create a router that will @@ -446,6 +455,7 @@ impl Server { http2_keepalive_timeout: self.http2_keepalive_timeout, max_frame_size: self.max_frame_size, accept_http1: self.accept_http1, + encodings: EnabledEncodings::default(), } } @@ -476,6 +486,7 @@ impl Server { let timeout = self.timeout; let max_frame_size = self.max_frame_size; let http2_only = !self.accept_http1; + let encodings = self.encodings; let http2_keepalive_interval = self.http2_keepalive_interval; let http2_keepalive_timeout = self @@ -492,6 +503,7 @@ impl Server { concurrency_limit, timeout, trace_interceptor, + encodings, _io: PhantomData, }; @@ -757,6 +769,7 @@ impl fmt::Debug for Server { struct Svc { inner: S, trace_interceptor: Option, + encodings: EnabledEncodings, } impl Service> for Svc @@ -789,6 +802,11 @@ where tracing::Span::none() }; + // remove disabled disablings from `grpc-accept-encoding` so the inner service doesn't even + // seen them. 
+ self.encodings + .remove_disabled_encodings_from_accept_encoding(req.headers_mut()); + SvcFuture { inner: self.inner.call(req), span, @@ -833,6 +851,7 @@ struct MakeSvc { timeout: Option, inner: S, trace_interceptor: Option, + encodings: EnabledEncodings, _io: PhantomData IO>, } @@ -860,6 +879,7 @@ where let concurrency_limit = self.concurrency_limit; let timeout = self.timeout; let trace_interceptor = self.trace_interceptor.clone(); + let encodings = self.encodings; let svc = ServiceBuilder::new() .layer_fn(RecoverError::new) @@ -895,6 +915,7 @@ where .service(Svc { inner: svc, trace_interceptor, + encodings, }); future::ready(Ok(svc)) From a807ddc9163e91de3cf1ad596a5c6c71a7fd0f06 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Fri, 25 Jun 2021 17:56:25 +0200 Subject: [PATCH 03/29] Minor clean up --- tests/integration_tests/tests/compression.rs | 41 ++++++-------------- 1 file changed, 11 insertions(+), 30 deletions(-) diff --git a/tests/integration_tests/tests/compression.rs b/tests/integration_tests/tests/compression.rs index 382c2c08a..f97e2987a 100644 --- a/tests/integration_tests/tests/compression.rs +++ b/tests/integration_tests/tests/compression.rs @@ -11,19 +11,21 @@ use tower::Service; // TODO(david): server streaming // TODO(david): bidirectional streaming +// TODO(david): somehow verify that compression is actually happening + +struct Svc; + +#[tonic::async_trait] +impl test_server::Test for Svc { + async fn unary_call(&self, _req: Request) -> Result, Status> { + Ok(Response::new(Output {})) + } +} + // TODO(david): document that using a multi threaded tokio runtime is // required (because of `block_in_place`) #[tokio::test(flavor = "multi_thread")] async fn server_compressing_messages() { - struct Svc; - - #[tonic::async_trait] - impl test_server::Test for Svc { - async fn unary_call(&self, _req: Request) -> Result, Status> { - Ok(Response::new(Output {})) - } - } - #[derive(Clone, Copy)] struct AssertCorrectAcceptEncoding(S); @@ -43,9 +45,7 @@ async fn server_compressing_messages() { } fn call(&mut self, req: http::Request) -> Self::Future { - println!("test middleware called"); assert_eq!(req.headers().get("grpc-accept-encoding").unwrap(), "gzip"); - self.0.call(req) } } @@ -80,15 +80,6 @@ async fn server_compressing_messages() { #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_disabled() { - struct Svc; - - #[tonic::async_trait] - impl test_server::Test for Svc { - async fn unary_call(&self, _req: Request) -> Result, Status> { - Ok(Response::new(Output {})) - } - } - let svc = test_server::TestServer::new(Svc); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -118,15 +109,6 @@ async fn client_enabled_server_disabled() { #[tokio::test(flavor = "multi_thread")] async fn client_disabled() { - struct Svc; - - #[tonic::async_trait] - impl test_server::Test for Svc { - async fn unary_call(&self, _req: Request) -> Result, Status> { - Ok(Response::new(Output {})) - } - } - #[derive(Clone, Copy)] struct AssertCorrectAcceptEncoding(S); @@ -147,7 +129,6 @@ async fn client_disabled() { fn call(&mut self, req: http::Request) -> Self::Future { assert!(req.headers().get("grpc-accept-encoding").is_none()); - self.0.call(req) } } From 46a12f612c1d39327d9f0e154b215d00f17a2d1c Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sat, 26 Jun 2021 11:52:37 +0200 Subject: [PATCH 04/29] Test that compression is actually happening --- Cargo.toml | 1 + tests/compression/Cargo.toml | 25 ++++ tests/compression/build.rs | 3 + 
tests/compression/proto/test.proto | 14 ++ .../compression.rs => compression/src/lib.rs} | 125 ++++++++++++++---- tests/compression/src/util.rs | 58 ++++++++ 6 files changed, 197 insertions(+), 29 deletions(-) create mode 100644 tests/compression/Cargo.toml create mode 100644 tests/compression/build.rs create mode 100644 tests/compression/proto/test.proto rename tests/{integration_tests/tests/compression.rs => compression/src/lib.rs} (51%) create mode 100644 tests/compression/src/util.rs diff --git a/Cargo.toml b/Cargo.toml index 05b7f6f4a..9fd8d2f99 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ members = [ "tests/integration_tests", "tests/stream_conflict", "tests/root-crate-path", + "tests/compression", "tonic-web/tests/integration" ] diff --git a/tests/compression/Cargo.toml b/tests/compression/Cargo.toml new file mode 100644 index 000000000..22f0985eb --- /dev/null +++ b/tests/compression/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "compression" +version = "0.1.0" +authors = ["Lucio Franco "] +edition = "2018" +publish = false +license = "MIT" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tonic = { path = "../../tonic" } +prost = "0.7" +tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] } +tower = { version = "0.4", features = [] } +http-body = "0.4" +http = "0.2" +tokio-stream = { version = "0.1.5", features = ["net"] } +tower-http = { version = "0.1", features = ["map-response-body"] } +bytes = "1" +futures = "0.3" +pin-project = "1.0" + +[build-dependencies] +tonic-build = { path = "../../tonic-build" } diff --git a/tests/compression/build.rs b/tests/compression/build.rs new file mode 100644 index 000000000..a091e9483 --- /dev/null +++ b/tests/compression/build.rs @@ -0,0 +1,3 @@ +fn main() { + tonic_build::compile_protos("proto/test.proto").unwrap(); +} diff --git a/tests/compression/proto/test.proto b/tests/compression/proto/test.proto new file mode 100644 index 000000000..e9df5aee5 --- /dev/null +++ b/tests/compression/proto/test.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package test; + +service Test { + rpc UnaryCall(Input) returns (Output); +} + +message Input {} + +message Output { + // include a bunch of data so there actually is something to compress + bytes data = 1; +} diff --git a/tests/integration_tests/tests/compression.rs b/tests/compression/src/lib.rs similarity index 51% rename from tests/integration_tests/tests/compression.rs rename to tests/compression/src/lib.rs index f97e2987a..e2d5656e7 100644 --- a/tests/integration_tests/tests/compression.rs +++ b/tests/compression/src/lib.rs @@ -1,31 +1,44 @@ -use integration_tests::pb::{test_client, test_server, Input, Output}; +#![allow(unused_imports)] + +tonic::include_proto!("test"); + +use std::sync::{ + atomic::{AtomicUsize, Ordering::Relaxed}, + Arc, +}; use tokio::net::TcpListener; use tonic::{ transport::{Channel, Server}, Request, Response, Status, }; -use tower::Service; +use tower::{layer::layer_fn, Service, ServiceBuilder}; +use tower_http::map_response_body::MapResponseBodyLayer; + +mod util; // TODO(david): client copmressing messages // TODO(david): client streaming // TODO(david): server streaming // TODO(david): bidirectional streaming -// TODO(david): somehow verify that compression is actually happening - struct Svc; +const UNCOMPRESSED_MIN_BODY_SIZE: usize = 1024; + #[tonic::async_trait] impl test_server::Test for Svc { async fn unary_call(&self, _req: Request) -> Result, Status> { - 
Ok(Response::new(Output {})) + let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE]; + Ok(Response::new(Output { + data: data.to_vec(), + })) } } // TODO(david): document that using a multi threaded tokio runtime is // required (because of `block_in_place`) #[tokio::test(flavor = "multi_thread")] -async fn server_compressing_messages() { +async fn client_enabled_server_enabled() { #[derive(Clone, Copy)] struct AssertCorrectAcceptEncoding(S); @@ -55,14 +68,29 @@ async fn server_compressing_messages() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - tokio::spawn(async move { - Server::builder() - .layer(tower::layer::layer_fn(AssertCorrectAcceptEncoding)) - .send_gzip() - .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) - .await - .unwrap(); + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer(layer_fn(AssertCorrectAcceptEncoding)) + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .send_gzip() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } }); let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) @@ -76,6 +104,9 @@ async fn server_compressing_messages() { let res = client.unary_call(Request::new(Input {})).await.unwrap(); assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] @@ -85,13 +116,28 @@ async fn client_enabled_server_disabled() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - tokio::spawn(async move { - // no compression enable on the server so responses should not be compressed - Server::builder() - .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) - .await - .unwrap(); + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + // no compression enable on the server so responses should not be compressed + .layer( + ServiceBuilder::new() + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } }); let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) @@ -105,6 +151,9 @@ async fn client_enabled_server_disabled() { let res = client.unary_call(Request::new(Input {})).await.unwrap(); assert!(res.metadata().get("grpc-encoding").is_none()); + + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] @@ -138,14 +187,29 @@ async fn client_disabled() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - tokio::spawn(async move { - Server::builder() - .layer(tower::layer::layer_fn(AssertCorrectAcceptEncoding)) - .send_gzip() - .add_service(svc) - 
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) - .await - .unwrap(); + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer(layer_fn(AssertCorrectAcceptEncoding)) + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .send_gzip() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } }); let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) @@ -158,4 +222,7 @@ async fn client_disabled() { let res = client.unary_call(Request::new(Input {})).await.unwrap(); assert!(res.metadata().get("grpc-encoding").is_none()); + + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs new file mode 100644 index 000000000..17767fc1f --- /dev/null +++ b/tests/compression/src/util.rs @@ -0,0 +1,58 @@ +use bytes::Bytes; +use futures::ready; +use http_body::Body; +use pin_project::pin_project; +use std::{ + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering::Relaxed}, + Arc, + }, + task::{Context, Poll}, +}; + +/// A body that tracks how many bytes passes through it +#[pin_project] +pub struct CountBytesBody { + #[pin] + pub inner: B, + pub counter: Arc, +} + +impl Body for CountBytesBody +where + B: Body, +{ + type Data = B::Data; + type Error = B::Error; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let this = self.project(); + let counter: Arc = this.counter.clone(); + match ready!(this.inner.poll_data(cx)) { + Some(Ok(chunk)) => { + counter.fetch_add(chunk.len(), Relaxed); + Poll::Ready(Some(Ok(chunk))) + } + x => Poll::Ready(x), + } + } + + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + self.project().inner.poll_trailers(cx) + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + fn size_hint(&self) -> http_body::SizeHint { + self.inner.size_hint() + } +} From 3ba28588277c63b38c0b1c372039ea2b81afadd9 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sat, 26 Jun 2021 11:57:34 +0200 Subject: [PATCH 05/29] Clean up some todos --- tonic/src/codec/compression.rs | 1 + tonic/src/server/grpc.rs | 18 +++++++++--------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 34735ced6..7514e8f07 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -100,6 +100,7 @@ pub(crate) enum Encoding { } impl Encoding { + /// Based on the `grpc-accept-encoding` header, pick an encoding to use. 
pub(crate) fn from_accept_encoding_header(map: &http::HeaderMap) -> Option { let header_value = map.get(ACCEPT_ENCODING_HEADER)?; let header_value_str = header_value.to_str().ok()?; diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index e26a9db8b..09c40bc15 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -43,7 +43,7 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { - let requested_encoding = Encoding::from_accept_encoding_header(req.headers()); + let encoding = Encoding::from_accept_encoding_header(req.headers()); let request = match self.map_request_unary(req).await { Ok(r) => r, @@ -51,7 +51,7 @@ where return self .map_response::>>>( Err(status), - requested_encoding, + encoding, ); } }; @@ -61,7 +61,7 @@ where .await .map(|r| r.map(|m| stream::once(future::ok(m)))); - self.map_response(response, requested_encoding) + self.map_response(response, encoding) } /// Handle a server side streaming request. @@ -76,7 +76,7 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { - // TODO(david): requested_encoding + // TODO(david): encoding let request = match self.map_request_unary(req).await { Ok(r) => r, @@ -101,7 +101,7 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send + 'static, { - // TODO(david): requested_encoding + // TODO(david): encoding let request = self.map_request_streaming(req); let response = service @@ -123,7 +123,7 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { - // TODO(david): requested_encoding + // TODO(david): encoding let request = self.map_request_streaming(req); let response = service.call(request).await; self.map_response(response, None) @@ -170,7 +170,7 @@ where fn map_response( &mut self, response: Result, Status>, - requested_encoding: Option, + encoding: Option, ) -> http::Response where B: TryStream + Send + Sync + 'static, @@ -188,7 +188,7 @@ where http::header::HeaderValue::from_static("application/grpc"), ); - if let Some(encoding) = requested_encoding { + if let Some(encoding) = encoding { // Set the content encoding parts.headers.insert( crate::codec::compression::ENCODING_HEADER, @@ -196,7 +196,7 @@ where ); } - let body = encode_server(self.codec.encoder(), body.into_stream(), requested_encoding); + let body = encode_server(self.codec.encoder(), body.into_stream(), encoding); http::Response::from_parts(parts, BoxBody::new(body)) } From aae6015cbb65b9b2166954900608e8e451db9e7e Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sun, 27 Jun 2021 23:31:59 +0200 Subject: [PATCH 06/29] channels compressing requests --- tests/compression/Cargo.toml | 3 +- tests/compression/proto/test.proto | 9 +- tests/compression/src/compressing_request.rs | 74 +++++++ tests/compression/src/compressing_response.rs | 193 +++++++++++++++++ tests/compression/src/lib.rs | 205 +----------------- tonic/src/codec/compression.rs | 13 +- tonic/src/status.rs | 11 + tonic/src/transport/channel/endpoint.rs | 16 +- tonic/src/transport/channel/mod.rs | 103 ++++++++- tonic/src/transport/server/mod.rs | 26 ++- 10 files changed, 436 insertions(+), 217 deletions(-) create mode 100644 tests/compression/src/compressing_request.rs create mode 100644 tests/compression/src/compressing_response.rs diff --git a/tests/compression/Cargo.toml b/tests/compression/Cargo.toml index 22f0985eb..eb3516cc9 100644 --- a/tests/compression/Cargo.toml +++ b/tests/compression/Cargo.toml @@ -16,10 +16,11 @@ tower = { version = "0.4", features = [] } http-body = "0.4" http = "0.2" tokio-stream = { version = "0.1.5", 
features = ["net"] } -tower-http = { version = "0.1", features = ["map-response-body"] } +tower-http = { version = "0.1", features = ["map-response-body", "map-request-body"] } bytes = "1" futures = "0.3" pin-project = "1.0" +hyper = "0.14" [build-dependencies] tonic-build = { path = "../../tonic-build" } diff --git a/tests/compression/proto/test.proto b/tests/compression/proto/test.proto index e9df5aee5..824026c43 100644 --- a/tests/compression/proto/test.proto +++ b/tests/compression/proto/test.proto @@ -2,13 +2,14 @@ syntax = "proto3"; package test; +import "google/protobuf/empty.proto"; + service Test { - rpc UnaryCall(Input) returns (Output); + rpc CompressOutput(google.protobuf.Empty) returns (SomeData); + rpc CompressInput(SomeData) returns (google.protobuf.Empty); } -message Input {} - -message Output { +message SomeData { // include a bunch of data so there actually is something to compress bytes data = 1; } diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs new file mode 100644 index 000000000..98379bd9e --- /dev/null +++ b/tests/compression/src/compressing_request.rs @@ -0,0 +1,74 @@ +use super::*; +use http_body::Body as _; + +#[tokio::test(flavor = "multi_thread")] +async fn client_enabled_server_enabled() { + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + let measure_request_body_size_layer = { + let bytes_sent_counter = bytes_sent_counter.clone(); + MapRequestBodyLayer::new(move |mut body: hyper::Body| { + let (mut tx, new_body) = hyper::Body::channel(); + + let bytes_sent_counter = bytes_sent_counter.clone(); + tokio::spawn(async move { + while let Some(chunk) = body.data().await { + let chunk = chunk.unwrap(); + bytes_sent_counter.fetch_add(chunk.len(), Relaxed); + tx.send_data(chunk).await.unwrap(); + } + + if let Some(trailers) = body.trailers().await.unwrap() { + tx.send_trailers(trailers).await.unwrap(); + } + }); + + new_body + }) + }; + + tokio::spawn(async move { + Server::builder() + .layer( + ServiceBuilder::new() + // TODO(david): require request to have `grpc-encoding: gzip` + .layer( + ServiceBuilder::new() + .layer(measure_request_body_size_layer) + .into_inner(), + ) + .into_inner(), + ) + .accept_gzip() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .send_gzip() + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + client + .compress_input(SomeData { + data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), + }) + .await + .unwrap(); + + let bytes_sent = bytes_sent_counter.load(Relaxed); + dbg!(&bytes_sent); + assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); +} + +// TODO(david): send_gzip on channel, but disabling compression of a message diff --git a/tests/compression/src/compressing_response.rs b/tests/compression/src/compressing_response.rs new file mode 100644 index 000000000..7ee564562 --- /dev/null +++ b/tests/compression/src/compressing_response.rs @@ -0,0 +1,193 @@ +use super::*; + +// TODO(david): document that using a multi threaded tokio runtime is +// required (because of `block_in_place`) +#[tokio::test(flavor = "multi_thread")] +async fn client_enabled_server_enabled() { + #[derive(Clone, Copy)] + struct 
AssertCorrectAcceptEncoding(S); + + impl Service> for AssertCorrectAcceptEncoding + where + S: Service>, + { + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.0.poll_ready(cx) + } + + fn call(&mut self, req: http::Request) -> Self::Future { + assert_eq!(req.headers().get("grpc-accept-encoding").unwrap(), "gzip"); + self.0.call(req) + } + } + + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer(layer_fn(AssertCorrectAcceptEncoding)) + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .send_gzip() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .accept_gzip() + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + let res = client.compress_output(()).await.unwrap(); + + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); +} + +#[tokio::test(flavor = "multi_thread")] +async fn client_enabled_server_disabled() { + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + // no compression enable on the server so responses should not be compressed + .layer( + ServiceBuilder::new() + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .accept_gzip() + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + let res = client.compress_output(()).await.unwrap(); + + assert!(res.metadata().get("grpc-encoding").is_none()); + + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); +} + +#[tokio::test(flavor = "multi_thread")] +async fn client_disabled() { + #[derive(Clone, Copy)] + struct AssertCorrectAcceptEncoding(S); + + impl Service> for AssertCorrectAcceptEncoding + where + S: Service>, + { + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.0.poll_ready(cx) + } + + fn call(&mut self, req: http::Request) -> Self::Future { + assert!(req.headers().get("grpc-accept-encoding").is_none()); + self.0.call(req) + } + } + + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 
+ let addr = listener.local_addr().unwrap(); + + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer(layer_fn(AssertCorrectAcceptEncoding)) + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .send_gzip() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + let res = client.compress_output(()).await.unwrap(); + + assert!(res.metadata().get("grpc-encoding").is_none()); + + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); +} diff --git a/tests/compression/src/lib.rs b/tests/compression/src/lib.rs index e2d5656e7..16c10eb11 100644 --- a/tests/compression/src/lib.rs +++ b/tests/compression/src/lib.rs @@ -1,7 +1,5 @@ #![allow(unused_imports)] -tonic::include_proto!("test"); - use std::sync::{ atomic::{AtomicUsize, Ordering::Relaxed}, Arc, @@ -12,10 +10,14 @@ use tonic::{ Request, Response, Status, }; use tower::{layer::layer_fn, Service, ServiceBuilder}; -use tower_http::map_response_body::MapResponseBodyLayer; +use tower_http::{map_request_body::MapRequestBodyLayer, map_response_body::MapResponseBodyLayer}; +mod compressing_request; +mod compressing_response; mod util; +tonic::include_proto!("test"); + // TODO(david): client copmressing messages // TODO(david): client streaming // TODO(david): server streaming @@ -27,202 +29,15 @@ const UNCOMPRESSED_MIN_BODY_SIZE: usize = 1024; #[tonic::async_trait] impl test_server::Test for Svc { - async fn unary_call(&self, _req: Request) -> Result, Status> { + async fn compress_output(&self, _req: Request<()>) -> Result, Status> { let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE]; - Ok(Response::new(Output { + Ok(Response::new(SomeData { data: data.to_vec(), })) } -} - -// TODO(david): document that using a multi threaded tokio runtime is -// required (because of `block_in_place`) -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_enabled() { - #[derive(Clone, Copy)] - struct AssertCorrectAcceptEncoding(S); - - impl Service> for AssertCorrectAcceptEncoding - where - S: Service>, - { - type Response = S::Response; - type Error = S::Error; - type Future = S::Future; - - fn poll_ready( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.0.poll_ready(cx) - } - - fn call(&mut self, req: http::Request) -> Self::Future { - assert_eq!(req.headers().get("grpc-accept-encoding").unwrap(), "gzip"); - self.0.call(req) - } - } - - let svc = test_server::TestServer::new(Svc); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); - - tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); - async move { - Server::builder() - .layer( - ServiceBuilder::new() - .layer(layer_fn(AssertCorrectAcceptEncoding)) - .layer(MapResponseBodyLayer::new(move |body| { - util::CountBytesBody { - inner: body, - counter: bytes_sent_counter.clone(), - } - })) - .into_inner(), - ) - .send_gzip() - .add_service(svc) - 
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) - .await - .unwrap(); - } - }); - - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .accept_gzip() - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel); - - let res = client.unary_call(Request::new(Input {})).await.unwrap(); - - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); - - let bytes_sent = bytes_sent_counter.load(Relaxed); - assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); -} - -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_disabled() { - let svc = test_server::TestServer::new(Svc); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); - - tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); - async move { - Server::builder() - // no compression enable on the server so responses should not be compressed - .layer( - ServiceBuilder::new() - .layer(MapResponseBodyLayer::new(move |body| { - util::CountBytesBody { - inner: body, - counter: bytes_sent_counter.clone(), - } - })) - .into_inner(), - ) - .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) - .await - .unwrap(); - } - }); - - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .accept_gzip() - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel); - - let res = client.unary_call(Request::new(Input {})).await.unwrap(); - - assert!(res.metadata().get("grpc-encoding").is_none()); - - let bytes_sent = bytes_sent_counter.load(Relaxed); - assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); -} - -#[tokio::test(flavor = "multi_thread")] -async fn client_disabled() { - #[derive(Clone, Copy)] - struct AssertCorrectAcceptEncoding(S); - - impl Service> for AssertCorrectAcceptEncoding - where - S: Service>, - { - type Response = S::Response; - type Error = S::Error; - type Future = S::Future; - - fn poll_ready( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.0.poll_ready(cx) - } - - fn call(&mut self, req: http::Request) -> Self::Future { - assert!(req.headers().get("grpc-accept-encoding").is_none()); - self.0.call(req) - } + async fn compress_input(&self, req: Request) -> Result, Status> { + assert_eq!(req.into_inner().data.len(), UNCOMPRESSED_MIN_BODY_SIZE); + Ok(Response::new(())) } - - let svc = test_server::TestServer::new(Svc); - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); - - tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); - async move { - Server::builder() - .layer( - ServiceBuilder::new() - .layer(layer_fn(AssertCorrectAcceptEncoding)) - .layer(MapResponseBodyLayer::new(move |body| { - util::CountBytesBody { - inner: body, - counter: bytes_sent_counter.clone(), - } - })) - .into_inner(), - ) - .send_gzip() - .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) - .await - .unwrap(); - } - }); - - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel); - - let res = client.unary_call(Request::new(Input {})).await.unwrap(); - - 
assert!(res.metadata().get("grpc-encoding").is_none()); - - let bytes_sent = bytes_sent_counter.load(Relaxed); - assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 7514e8f07..e40782404 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -133,19 +133,22 @@ fn split_by_comma(s: &str) -> impl Iterator { } /// Compress `len` bytes from `in_buffer` into `out_buffer`. -pub(crate) fn compress( +pub(crate) fn compress( encoding: Encoding, - in_buffer: &mut BytesMut, + in_buffer: &mut B, out_buffer: &mut BytesMut, len: usize, -) -> Result<(), std::io::Error> { +) -> Result<(), std::io::Error> +where + B: AsRef<[u8]> + bytes::Buf, +{ let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE; out_buffer.reserve(capacity); match encoding { Encoding::Gzip => { let mut gzip_decoder = GzEncoder::new( - &in_buffer[0..len], + &in_buffer.as_ref()[0..len], // FIXME: support customizing the compression level flate2::Compression::new(6), ); @@ -155,6 +158,8 @@ pub(crate) fn compress( } } + // TODO(david): is this necessary? test sending multiple requests and + // responses on the same channel in_buffer.advance(len); Ok(()) diff --git a/tonic/src/status.rs b/tonic/src/status.rs index e5bb0dc56..e9f79bfef 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -334,6 +334,17 @@ impl Status { Err(err) } + /// Set the source error of the status + pub(crate) fn with_source(self, source: T) -> Self + where + T: Into>, + { + Self { + source: Some(source.into()), + ..self + } + } + // FIXME: bubble this into `transport` and expose generic http2 reasons. #[cfg(feature = "transport")] fn from_h2_error(err: &h2::Error) -> Status { diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 09bb7693e..8b076b7eb 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -2,7 +2,7 @@ use super::super::service; use super::Channel; #[cfg(feature = "tls")] use super::ClientTlsConfig; -use crate::codec::compression::EnabledEncodings; +use crate::codec::compression::{EnabledEncodings, Encoding}; #[cfg(feature = "tls")] use crate::transport::service::TlsConnector; use crate::transport::Error; @@ -41,6 +41,7 @@ pub struct Endpoint { pub(crate) http2_keep_alive_while_idle: Option, pub(crate) http2_adaptive_window: Option, pub(crate) accept_encoding: EnabledEncodings, + pub(crate) send_encoding: Option, } impl Endpoint { @@ -251,7 +252,6 @@ impl Endpoint { /// /// Compression is not enabled by default. // TODO(david): disabling compression on individual messages - // TODO(david): sending compressed messages pub fn accept_gzip(self) -> Self { Endpoint { accept_encoding: self.accept_encoding.gzip(), @@ -259,6 +259,17 @@ impl Endpoint { } } + /// Compress requests with `gzip`. + /// + /// This requires the server to accept `gzip` compressed requests otherwise it might + /// respond with an error. + pub fn send_gzip(self) -> Self { + Endpoint { + send_encoding: Some(Encoding::Gzip), + ..self + } + } + /// Create a channel from this config. 
pub async fn connect(&self) -> Result { let mut http = hyper::client::connect::HttpConnector::new(); @@ -349,6 +360,7 @@ impl From for Endpoint { http2_keep_alive_while_idle: None, http2_adaptive_window: None, accept_encoding: EnabledEncodings::default(), + send_encoding: None, } } } diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index e660f8bd5..83b6d26d0 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -10,13 +10,18 @@ pub use endpoint::Endpoint; pub use tls::ClientTlsConfig; use super::service::{Connection, DynamicServiceStream}; -use crate::{body::BoxBody, codec::compression::EnabledEncodings}; -use bytes::Bytes; +use crate::{ + body::BoxBody, + codec::compression::{EnabledEncodings, Encoding}, +}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; use http::{ uri::{InvalidUri, Uri}, Request, Response, }; +use http_body::Body as _; use hyper::client::connect::Connection as HyperConnection; +use pin_project::pin_project; use std::{ fmt, future::Future, @@ -68,6 +73,8 @@ const DEFAULT_BUFFER_SIZE: usize = 1024; #[derive(Clone)] pub struct Channel { svc: Buffer>, + /// The encoding that request bodies will be compressed with. + send_encoding: Option, } /// A future that resolves to an HTTP response. @@ -140,13 +147,14 @@ impl Channel { { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); let accept_encoding = endpoint.accept_encoding; + let send_encoding = endpoint.send_encoding; let svc = Connection::lazy(connector, endpoint); let svc = with_accept_encoding(svc, accept_encoding); let svc = BoxService::new(svc); let svc = Buffer::new(svc, buffer_size); - Channel { svc } + Channel { svc, send_encoding } } pub(crate) async fn connect(connector: C, endpoint: Endpoint) -> Result @@ -158,6 +166,7 @@ impl Channel { { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); let accept_encoding = endpoint.accept_encoding; + let send_encoding = endpoint.send_encoding; let svc = Connection::connect(connector, endpoint) .await @@ -166,7 +175,7 @@ impl Channel { let svc = BoxService::new(svc); let svc = Buffer::new(svc, buffer_size); - Ok(Channel { svc }) + Ok(Channel { svc, send_encoding }) } pub(crate) fn balance(discover: D, buffer_size: usize) -> Self @@ -180,7 +189,10 @@ impl Channel { let svc = BoxService::new(svc); let svc = Buffer::new(svc, buffer_size); - Channel { svc } + Channel { + svc, + send_encoding: None, + } } } @@ -206,8 +218,26 @@ impl Service> for Channel { } fn call(&mut self, request: http::Request) -> Self::Future { + let (mut parts, body) = request.into_parts(); + + let new_body = if let Some(encoding) = self.send_encoding { + parts.headers.insert( + crate::codec::compression::ENCODING_HEADER, + encoding.into_header_value(), + ); + + CompressEachChunkBody { + inner: body, + encoding, + encoding_buf: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE), + } + .boxed() + } else { + body + }; + + let request = http::Request::from_parts(parts, new_body); let inner = Service::call(&mut self.svc, request); - ResponseFuture { inner } } } @@ -233,3 +263,64 @@ impl fmt::Debug for ResponseFuture { f.debug_struct("ResponseFuture").finish() } } + +/// A `http_body::Body` that compresses each chunk with a given encoding. 
+#[pin_project] +struct CompressEachChunkBody { + #[pin] + inner: B, + encoding: Encoding, + encoding_buf: BytesMut, +} + +impl http_body::Body for CompressEachChunkBody +where + B: http_body::Body, +{ + type Data = Bytes; + type Error = crate::Status; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let this = self.project(); + match futures_util::ready!(this.inner.poll_data(cx)) { + Some(Ok(mut chunk)) => { + let len = chunk.len(); + + this.encoding_buf.clear(); + + if let Err(err) = crate::codec::compression::compress( + *this.encoding, + &mut chunk, + this.encoding_buf, + len, + ) { + let status = + crate::Status::internal("Failed to compress body chunk").with_source(err); + return Poll::Ready(Some(Err(status))); + } + + let chunk = this.encoding_buf.clone().freeze(); + + Poll::Ready(Some(Ok(chunk))) + } + other => Poll::Ready(other), + } + } + + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + self.project().inner.poll_trailers(cx) + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + // we don't define `size_hint` because we compress each + // chunk and dunno the size +} diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 24c6affbc..5d42d0852 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -85,7 +85,8 @@ pub struct Server { max_frame_size: Option, accept_http1: bool, layer: L, - encodings: EnabledEncodings, + send_encodings: EnabledEncodings, + accept_encodings: EnabledEncodings, } /// A stack based `Service` router. @@ -325,7 +326,15 @@ impl Server { /// Compress outgoing messages with `gzip` if supported by the client. pub fn send_gzip(self) -> Self { Server { - encodings: self.encodings.gzip(), + send_encodings: self.send_encodings.gzip(), + ..self + } + } + + /// Accept requests compressed with `gzip`. + pub fn accept_gzip(self) -> Self { + Server { + accept_encodings: self.accept_encodings.gzip(), ..self } } @@ -455,7 +464,8 @@ impl Server { http2_keepalive_timeout: self.http2_keepalive_timeout, max_frame_size: self.max_frame_size, accept_http1: self.accept_http1, - encodings: EnabledEncodings::default(), + send_encodings: EnabledEncodings::default(), + accept_encodings: EnabledEncodings::default(), } } @@ -486,7 +496,7 @@ impl Server { let timeout = self.timeout; let max_frame_size = self.max_frame_size; let http2_only = !self.accept_http1; - let encodings = self.encodings; + let encodings = self.send_encodings; let http2_keepalive_interval = self.http2_keepalive_interval; let http2_keepalive_timeout = self @@ -788,6 +798,12 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { + if let Some(value) = req.headers().get("grpc-encoding") { + if value == "gzip" { + todo!() + } + } + let span = if let Some(trace_interceptor) = &self.trace_interceptor { let (parts, body) = req.into_parts(); let bodyless_request = Request::from_parts(parts, ()); @@ -802,7 +818,7 @@ where tracing::Span::none() }; - // remove disabled disablings from `grpc-accept-encoding` so the inner service doesn't even + // remove disabled encodings from `grpc-accept-encoding` so the inner service doesn't even // seen them. 
self.encodings .remove_disabled_encodings_from_accept_encoding(req.headers_mut()); From bd3ab36067efb08fed5d37ec79c9980771c3b738 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 28 Jun 2021 11:53:31 +0200 Subject: [PATCH 07/29] Move compression to be on the codecs --- tests/compression/src/compressing_request.rs | 6 +- tests/compression/src/compressing_response.rs | 17 ++- tonic-build/src/client.rs | 12 ++ tonic-build/src/server.rs | 47 +++++-- tonic/src/client/grpc.rs | 58 ++++++++- tonic/src/codec/compression.rs | 66 ++++++---- tonic/src/codec/decode.rs | 20 +-- tonic/src/codec/encode.rs | 18 +-- tonic/src/codec/mod.rs | 8 +- tonic/src/codec/prost.rs | 2 +- tonic/src/codegen.rs | 1 + tonic/src/server/grpc.rs | 81 +++++++++++- tonic/src/transport/channel/endpoint.rs | 32 ----- tonic/src/transport/channel/mod.rs | 116 +----------------- tonic/src/transport/server/mod.rs | 39 +----- 15 files changed, 268 insertions(+), 255 deletions(-) diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index 98379bd9e..636abbe34 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -3,7 +3,7 @@ use http_body::Body as _; #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_enabled() { - let svc = test_server::TestServer::new(Svc); + let svc = test_server::TestServer::new(Svc).accept_gzip(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -44,7 +44,6 @@ async fn client_enabled_server_enabled() { ) .into_inner(), ) - .accept_gzip() .add_service(svc) .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) .await @@ -52,12 +51,11 @@ async fn client_enabled_server_enabled() { }); let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .send_gzip() .connect() .await .unwrap(); - let mut client = test_client::TestClient::new(channel); + let mut client = test_client::TestClient::new(channel).send_gzip(); client .compress_input(SomeData { diff --git a/tests/compression/src/compressing_response.rs b/tests/compression/src/compressing_response.rs index 7ee564562..a6d68e130 100644 --- a/tests/compression/src/compressing_response.rs +++ b/tests/compression/src/compressing_response.rs @@ -23,12 +23,15 @@ async fn client_enabled_server_enabled() { } fn call(&mut self, req: http::Request) -> Self::Future { - assert_eq!(req.headers().get("grpc-accept-encoding").unwrap(), "gzip"); + assert_eq!( + req.headers().get("grpc-accept-encoding").unwrap(), + "gzip,identity" + ); self.0.call(req) } } - let svc = test_server::TestServer::new(Svc); + let svc = test_server::TestServer::new(Svc).send_gzip(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -50,7 +53,6 @@ async fn client_enabled_server_enabled() { })) .into_inner(), ) - .send_gzip() .add_service(svc) .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) .await @@ -59,12 +61,11 @@ async fn client_enabled_server_enabled() { }); let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .accept_gzip() .connect() .await .unwrap(); - let mut client = test_client::TestClient::new(channel); + let mut client = test_client::TestClient::new(channel).accept_gzip(); let res = client.compress_output(()).await.unwrap(); @@ -106,12 +107,11 @@ async fn client_enabled_server_disabled() { }); let channel = Channel::builder(format!("http://{}", 
addr).parse().unwrap()) - .accept_gzip() .connect() .await .unwrap(); - let mut client = test_client::TestClient::new(channel); + let mut client = test_client::TestClient::new(channel).accept_gzip(); let res = client.compress_output(()).await.unwrap(); @@ -147,7 +147,7 @@ async fn client_disabled() { } } - let svc = test_server::TestServer::new(Svc); + let svc = test_server::TestServer::new(Svc).send_gzip(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -169,7 +169,6 @@ async fn client_disabled() { })) .into_inner(), ) - .send_gzip() .add_service(svc) .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) .await diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index d20fced9d..984528010 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -59,6 +59,18 @@ pub fn generate( #service_ident::new(InterceptedService::new(inner, interceptor)) } + // TODO(david): docs + pub fn send_gzip(mut self) -> Self { + self.inner = self.inner.send_gzip(); + self + } + + // TODO(david): docs + pub fn accept_gzip(mut self) -> Self { + self.inner = self.inner.accept_gzip(); + self + } + #methods } diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index 517917540..ae0a976d7 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -48,6 +48,8 @@ pub fn generate( #[derive(Debug)] pub struct #server_service { inner: _Inner, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, } struct _Inner(Arc); @@ -56,7 +58,11 @@ pub fn generate( pub fn new(inner: T) -> Self { let inner = Arc::new(inner); let inner = _Inner(inner); - Self { inner } + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + } } pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService @@ -65,6 +71,18 @@ pub fn generate( { InterceptedService::new(Self::new(inner), interceptor) } + + // TODO(david): docs + pub fn accept_gzip(mut self) -> Self { + self.accept_compression_encodings.enable_gzip(); + self + } + + // TODO(david): docs + pub fn send_gzip(mut self) -> Self { + self.send_compression_encodings.enable_gzip(); + self + } } impl Service> for #server_service @@ -102,7 +120,11 @@ pub fn generate( impl Clone for #server_service { fn clone(&self) -> Self { let inner = self.inner.clone(); - Self { inner } + Self { + inner, + accept_compression_encodings: self.accept_compression_encodings, + send_compression_encodings: self.send_compression_encodings, + } } } @@ -335,13 +357,16 @@ fn generate_unary( } } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; let method = #service_ident(inner); let codec = #codec_name::default(); - let mut grpc = tonic::server::Grpc::new(codec); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config(accept_compression_encodings, send_compression_encodings); let res = grpc.unary(method, req).await; Ok(res) @@ -379,19 +404,21 @@ fn generate_server_streaming( let inner = self.0.clone(); let fut = async move { (*inner).#method_ident(request).await - }; Box::pin(fut) } } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; let inner = self.inner.clone(); let 
fut = async move { let inner = inner.0; let method = #service_ident(inner); let codec = #codec_name::default(); - let mut grpc = tonic::server::Grpc::new(codec); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config(accept_compression_encodings, send_compression_encodings); let res = grpc.server_streaming(method, req).await; Ok(res) @@ -432,13 +459,16 @@ fn generate_client_streaming( } } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; let method = #service_ident(inner); let codec = #codec_name::default(); - let mut grpc = tonic::server::Grpc::new(codec); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config(accept_compression_encodings, send_compression_encodings); let res = grpc.client_streaming(method, req).await; Ok(res) @@ -482,13 +512,16 @@ fn generate_streaming( } } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; let method = #service_ident(inner); let codec = #codec_name::default(); - let mut grpc = tonic::server::Grpc::new(codec); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config(accept_compression_encodings, send_compression_encodings); let res = grpc.streaming(method, req).await; Ok(res) diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 56d628ea4..814f37e3c 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -1,7 +1,10 @@ use crate::{ body::BoxBody, client::GrpcService, - codec::{compression::Encoding, encode_client, Codec, Streaming}, + codec::{ + compression::CompressionEncoding, encode_client, Codec, EnabledCompressionEncodings, + Streaming, + }, Code, Request, Response, Status, }; use futures_core::Stream; @@ -28,12 +31,30 @@ use std::fmt; /// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests pub struct Grpc { inner: T, + /// Which compression encodings does the client accept? + accept_compression_encodings: EnabledCompressionEncodings, + /// The compression encoding that will be applied to requests. + send_compression_encodings: Option, } impl Grpc { /// Creates a new gRPC client with the provided [`GrpcService`]. pub fn new(inner: T) -> Self { - Self { inner } + Self { + inner, + send_compression_encodings: None, + accept_compression_encodings: EnabledCompressionEncodings::default(), + } + } + + pub fn send_gzip(mut self) -> Self { + self.send_compression_encodings = Some(CompressionEncoding::Gzip); + self + } + + pub fn accept_gzip(mut self) -> Self { + self.accept_compression_encodings.enable_gzip(); + self } /// Check if the inner [`GrpcService`] is able to accept a new request. 
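An illustrative sketch (not part of the diff): the two builder-style methods added above are what the generated clients in this series delegate to, so compression is opted into per client rather than on the channel. Assuming the `test_client::TestClient` generated from the test proto in this PR and a placeholder address:

    let channel = Channel::builder("http://[::1]:50051".parse().unwrap())
        .connect()
        .await
        .unwrap();

    // `send_gzip` sets `grpc-encoding: gzip` and compresses request bodies;
    // `accept_gzip` advertises `grpc-accept-encoding: gzip,identity` so the
    // server may compress its responses.
    let mut client = test_client::TestClient::new(channel)
        .send_gzip()
        .accept_gzip();
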
@@ -145,7 +166,7 @@ impl Grpc { let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri"); let request = request - .map(|s| encode_client(codec.encoder(), s)) + .map(|s| encode_client(codec.encoder(), s, self.send_compression_encodings)) .map(BoxBody::new); let mut request = request.into_http(uri); @@ -160,13 +181,31 @@ impl Grpc { .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("application/grpc")); + if let Some(encoding) = self.send_compression_encodings { + request.headers_mut().insert( + crate::codec::compression::ENCODING_HEADER, + encoding.into_header_value(), + ); + } + + if let Some(header_value) = self + .accept_compression_encodings + .into_accept_encoding_header_value() + { + request.headers_mut().insert( + crate::codec::compression::ACCEPT_ENCODING_HEADER, + header_value, + ); + } + let response = self .inner .call(request) .await .map_err(|err| Status::from_error(err.into()))?; - let encoding = Encoding::from_encoding_header(response.headers()); + // TODO(david): server compressing with algorithm the client doesn't know + let encoding = CompressionEncoding::from_encoding_header(response.headers()); let status_code = response.status(); let trailers_only_status = Status::from_header_map(response.headers()); @@ -199,12 +238,21 @@ impl Clone for Grpc { fn clone(&self) -> Self { Self { inner: self.inner.clone(), + send_compression_encodings: self.send_compression_encodings, + accept_compression_encodings: self.accept_compression_encodings, } } } impl fmt::Debug for Grpc { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Grpc").field("inner", &self.inner).finish() + f.debug_struct("Grpc") + .field("inner", &self.inner) + .field("compression_encoding", &self.send_compression_encodings) + .field( + "accept_compression_encodings", + &self.accept_compression_encodings, + ) + .finish() } } diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index e40782404..e5b5a1af7 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -1,27 +1,32 @@ use super::encode::BUFFER_SIZE; use bytes::{Buf, BufMut, BytesMut}; use flate2::read::{GzDecoder, GzEncoder}; -use std::fmt::Write; +use std::fmt::{self, Write}; pub(crate) const ENCODING_HEADER: &str = "grpc-encoding"; pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding"; /// Struct used to configure which encodings are enabled on a server or channel. #[derive(Debug, Default, Clone, Copy)] -pub(crate) struct EnabledEncodings { - gzip: bool, +pub struct EnabledCompressionEncodings { + pub(crate) gzip: bool, } -impl EnabledEncodings { - pub(crate) fn gzip(self) -> Self { - Self { gzip: true } +impl EnabledCompressionEncodings { + pub(crate) fn gzip(self) -> bool { + self.gzip } - pub(crate) fn into_accept_encoding_header_value(self) -> http::HeaderValue { + /// Enable `gzip` compression. + pub fn enable_gzip(&mut self) { + self.gzip = true; + } + + pub(crate) fn into_accept_encoding_header_value(self) -> Option { if self.gzip { - http::HeaderValue::from_static("gzip,identity") + Some(http::HeaderValue::from_static("gzip,identity")) } else { - http::HeaderValue::from_static("identity") + None } } @@ -95,18 +100,23 @@ impl EnabledEncodings { } #[derive(Clone, Copy, Debug)] -pub(crate) enum Encoding { +#[non_exhaustive] +#[doc(hidden)] +pub enum CompressionEncoding { Gzip, } -impl Encoding { +impl CompressionEncoding { /// Based on the `grpc-accept-encoding` header, pick an encoding to use. 
- pub(crate) fn from_accept_encoding_header(map: &http::HeaderMap) -> Option { + pub(crate) fn from_accept_encoding_header( + map: &http::HeaderMap, + enabled_encodings: EnabledCompressionEncodings, + ) -> Option { let header_value = map.get(ACCEPT_ENCODING_HEADER)?; let header_value_str = header_value.to_str().ok()?; split_by_comma(header_value_str).find_map(|value| match value { - "gzip" => Some(Encoding::Gzip), + "gzip" if enabled_encodings.gzip() => Some(CompressionEncoding::Gzip), _ => None, }) } @@ -116,14 +126,22 @@ impl Encoding { let header_value_str = header_value.to_str().ok()?; match header_value_str { - "gzip" => Some(Encoding::Gzip), + "gzip" => Some(CompressionEncoding::Gzip), _ => None, } } pub(crate) fn into_header_value(self) -> http::HeaderValue { match self { - Encoding::Gzip => http::HeaderValue::from_static("gzip"), + CompressionEncoding::Gzip => http::HeaderValue::from_static("gzip"), + } + } +} + +impl fmt::Display for CompressionEncoding { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CompressionEncoding::Gzip => write!(f, "gzip"), } } } @@ -134,7 +152,7 @@ fn split_by_comma(s: &str) -> impl Iterator { /// Compress `len` bytes from `in_buffer` into `out_buffer`. pub(crate) fn compress( - encoding: Encoding, + encoding: CompressionEncoding, in_buffer: &mut B, out_buffer: &mut BytesMut, len: usize, @@ -146,7 +164,7 @@ where out_buffer.reserve(capacity); match encoding { - Encoding::Gzip => { + CompressionEncoding::Gzip => { let mut gzip_decoder = GzEncoder::new( &in_buffer.as_ref()[0..len], // FIXME: support customizing the compression level @@ -166,7 +184,7 @@ where } pub(crate) fn decompress( - encoding: Encoding, + encoding: CompressionEncoding, in_buffer: &mut BytesMut, out_buffer: &mut BytesMut, len: usize, @@ -176,7 +194,7 @@ pub(crate) fn decompress( out_buffer.reserve(capacity); match encoding { - Encoding::Gzip => { + CompressionEncoding::Gzip => { let mut gzip_decoder = GzDecoder::new(&in_buffer[0..len]); let mut out_writer = out_buffer.writer(); @@ -184,6 +202,8 @@ pub(crate) fn decompress( } } + // TODO(david): is this necessary? 
test sending multiple requests and + // responses on the same channel in_buffer.advance(len); Ok(()) @@ -198,7 +218,7 @@ mod tests { #[test] fn remove_disabled_encodings_empty_map() { let mut map = HeaderMap::new(); - let encodings = EnabledEncodings { gzip: true }; + let encodings = EnabledCompressionEncodings { gzip: true }; encodings.remove_disabled_encodings_from_accept_encoding(&mut map); assert!(map.is_empty()); } @@ -211,7 +231,7 @@ mod tests { "gzip".parse().unwrap(), ); - let encodings = EnabledEncodings { gzip: true }; + let encodings = EnabledCompressionEncodings { gzip: true }; encodings.remove_disabled_encodings_from_accept_encoding(&mut map); assert_eq!(&map[ACCEPT_ENCODING_HEADER], "gzip"); @@ -225,7 +245,7 @@ mod tests { "gzip".parse().unwrap(), ); - let encodings = EnabledEncodings { gzip: false }; + let encodings = EnabledCompressionEncodings { gzip: false }; encodings.remove_disabled_encodings_from_accept_encoding(&mut map); assert!(map.get(ACCEPT_ENCODING_HEADER).is_none()); @@ -239,7 +259,7 @@ mod tests { "foo,gzip,identity".parse().unwrap(), ); - let encodings = EnabledEncodings { gzip: true }; + let encodings = EnabledCompressionEncodings { gzip: true }; encodings.remove_disabled_encodings_from_accept_encoding(&mut map); assert_eq!(&map[ACCEPT_ENCODING_HEADER], "gzip"); diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 831c19a37..a34aa33b8 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -1,5 +1,5 @@ use super::{ - compression::{decompress, Encoding}, + compression::{decompress, CompressionEncoding}, DecodeBuf, Decoder, }; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; @@ -29,7 +29,7 @@ pub struct Streaming { buf: BytesMut, decompress_buf: BytesMut, trailers: Option, - encoding: Option, + encoding: Option, } impl Unpin for Streaming {} @@ -52,7 +52,7 @@ impl Streaming { decoder: D, body: B, status_code: StatusCode, - encoding: Option, + encoding: Option, ) -> Self where B: Body + Send + Sync + 'static, @@ -72,16 +72,21 @@ impl Streaming { } #[doc(hidden)] - pub fn new_request(decoder: D, body: B) -> Self + pub fn new_request(decoder: D, body: B, encoding: Option) -> Self where B: Body + Send + Sync + 'static, B::Error: Into, D: Decoder + Send + Sync + 'static, { - Self::new(decoder, body, Direction::Request, None) + Self::new(decoder, body, Direction::Request, encoding) } - fn new(decoder: D, body: B, direction: Direction, encoding: Option) -> Self + fn new( + decoder: D, + body: B, + direction: Direction, + encoding: Option, + ) -> Self where B: Body + Send + Sync + 'static, B::Error: Into, @@ -207,7 +212,8 @@ impl Streaming { let result = if *compression { if let Err(err) = decompress( // TODO(david): handle missing self.encoding - self.encoding.unwrap(), + self.encoding + .expect("message was compressed but compression not enabled on server"), &mut self.buf, &mut self.decompress_buf, *len, diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 61bb82c7f..6ffc6d7d2 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -1,5 +1,5 @@ use super::{ - compression::{compress, Encoding}, + compression::{compress, CompressionEncoding}, EncodeBuf, Encoder, }; use crate::{Code, Status}; @@ -19,35 +19,35 @@ pub(super) const BUFFER_SIZE: usize = 8 * 1024; pub(crate) fn encode_server( encoder: T, source: U, - encoding: Option, + compression_encoding: Option, ) -> EncodeBody>> where T: Encoder + Send + Sync + 'static, T::Item: Send + Sync, U: Stream> + Send + Sync + 'static, { 
- let stream = encode(encoder, source, encoding).into_stream(); + let stream = encode(encoder, source, compression_encoding).into_stream(); EncodeBody::new_server(stream) } pub(crate) fn encode_client( encoder: T, source: U, + compression_encoding: Option, ) -> EncodeBody>> where T: Encoder + Send + Sync + 'static, T::Item: Send + Sync, U: Stream + Send + Sync + 'static, { - // TODO(david): get encoding as argument? - let stream = encode(encoder, source.map(Ok), None).into_stream(); + let stream = encode(encoder, source.map(Ok), compression_encoding).into_stream(); EncodeBody::new_client(stream) } fn encode( mut encoder: T, source: U, - encoding: Option, + compression_encoding: Option, ) -> impl TryStream where T: Encoder, @@ -56,8 +56,8 @@ where async_stream::stream! { let mut buf = BytesMut::with_capacity(BUFFER_SIZE); - let (compression_enabled, mut compression_buf) = match encoding { - Some(Encoding::Gzip) => (true, BytesMut::with_capacity(BUFFER_SIZE)), + let (compression_enabled, mut compression_buf) = match compression_encoding { + Some(CompressionEncoding::Gzip) => (true, BytesMut::with_capacity(BUFFER_SIZE)), None => (false, BytesMut::new()), }; @@ -77,7 +77,7 @@ where let compressed_len = compression_buf.len(); let compress_result = compress( - encoding.unwrap(), + compression_encoding.unwrap(), &mut compression_buf, &mut buf, compressed_len, diff --git a/tonic/src/codec/mod.rs b/tonic/src/codec/mod.rs index 3d520dc9d..779e9cec2 100644 --- a/tonic/src/codec/mod.rs +++ b/tonic/src/codec/mod.rs @@ -10,15 +10,17 @@ mod encode; #[cfg(feature = "prost")] mod prost; +use crate::Status; use std::io; -pub use self::decode::Streaming; pub(crate) use self::encode::{encode_client, encode_server}; + +pub use self::buffer::{DecodeBuf, EncodeBuf}; +pub use self::compression::{CompressionEncoding, EnabledCompressionEncodings}; +pub use self::decode::Streaming; #[cfg(feature = "prost")] #[cfg_attr(docsrs, doc(cfg(feature = "prost")))] pub use self::prost::ProstCodec; -use crate::Status; -pub use buffer::{DecodeBuf, EncodeBuf}; /// Trait that knows how to encode and decode gRPC messages. 
pub trait Codec: Default { diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index 28e85f4a5..3ddb65c4e 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -100,7 +100,7 @@ mod tests { let body = body::MockBody::new(&buf[..], 10005, 0); - let mut stream = Streaming::new_request(decoder, body); + let mut stream = Streaming::new_request(decoder, body, None); let mut i = 0usize; while let Some(output_msg) = stream.message().await.unwrap() { diff --git a/tonic/src/codegen.rs b/tonic/src/codegen.rs index dd83f2c4a..321615c8b 100644 --- a/tonic/src/codegen.rs +++ b/tonic/src/codegen.rs @@ -10,6 +10,7 @@ pub use std::sync::Arc; pub use std::task::{Context, Poll}; pub use tower_service::Service; pub type StdError = Box; +pub use crate::codec::{CompressionEncoding, EnabledCompressionEncodings}; pub use crate::service::interceptor::InterceptedService; pub use http_body::Body; diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index 09c40bc15..3aeee1d35 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -1,6 +1,9 @@ use crate::{ body::BoxBody, - codec::{compression::Encoding, encode_server, Codec, Streaming}, + codec::{ + compression::{CompressionEncoding, EnabledCompressionEncodings}, + encode_server, Codec, Streaming, + }, server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService}, Code, Request, Status, }; @@ -20,6 +23,10 @@ use std::fmt; /// implements some [`Body`]. pub struct Grpc { codec: T, + /// Which compression encodings does the server accept for requests? + accept_compression_encodings: EnabledCompressionEncodings, + /// Which compression encodings might the server use for responses. + send_compression_encodings: EnabledCompressionEncodings, } impl Grpc @@ -29,7 +36,42 @@ where { /// Creates a new gRPC server with the provided [`Codec`]. pub fn new(codec: T) -> Self { - Self { codec } + Self { + codec, + accept_compression_encodings: EnabledCompressionEncodings::default(), + send_compression_encodings: EnabledCompressionEncodings::default(), + } + } + + pub fn accept_gzip(mut self) -> Self { + self.accept_compression_encodings.enable_gzip(); + self + } + + pub fn send_gzip(mut self) -> Self { + self.send_compression_encodings.enable_gzip(); + self + } + + #[doc(hidden)] + pub fn apply_compression_config( + self, + accept_encodings: EnabledCompressionEncodings, + send_encodings: EnabledCompressionEncodings, + ) -> Self { + let mut this = self; + + let EnabledCompressionEncodings { gzip: accept_gzip } = accept_encodings; + if accept_gzip { + this = this.accept_gzip(); + } + + let EnabledCompressionEncodings { gzip: send_gzip } = send_encodings; + if send_gzip { + this = this.send_gzip(); + } + + this } /// Handle a single unary gRPC request. 
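The server-side counterpart, again as an illustrative sketch rather than patch content: the flags enabled here are what `apply_compression_config` receives from the generated service code. Assuming the `test_server::TestServer` wrapper and `Svc` handler from this series' test crate, with `addr` standing in for whatever `SocketAddr` the server binds:

    // Accept gzip-compressed requests, and gzip responses for clients whose
    // `grpc-accept-encoding` includes gzip; both flags are forwarded to
    // `Grpc::apply_compression_config` by the generated server methods.
    let svc = test_server::TestServer::new(Svc)
        .accept_gzip()
        .send_gzip();

    Server::builder()
        .add_service(svc)
        .serve(addr)
        .await
        .unwrap();
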
@@ -43,7 +85,10 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { - let encoding = Encoding::from_accept_encoding_header(req.headers()); + let encoding = CompressionEncoding::from_accept_encoding_header( + req.headers(), + self.send_compression_encodings, + ); let request = match self.map_request_unary(req).await { Ok(r) => r, @@ -137,8 +182,29 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { + // TODO(david): probably should set this directly on `Grpc`, like the client + let request_compression_encoding = if let Some(request_compression_encoding) = + CompressionEncoding::from_encoding_header(request.headers()) + { + let encoding_supported = match request_compression_encoding { + CompressionEncoding::Gzip => self.accept_compression_encodings.gzip(), + }; + + if encoding_supported { + Some(request_compression_encoding) + } else { + return Err(Status::unimplemented(format!( + "Request is compressed with `{}` which the server doesn't support", + request_compression_encoding + ))); + } + } else { + None + }; + let (parts, body) = request.into_parts(); - let stream = Streaming::new_request(self.codec.decoder(), body); + let stream = + Streaming::new_request(self.codec.decoder(), body, request_compression_encoding); futures_util::pin_mut!(stream); @@ -164,13 +230,16 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { - Request::from_http(request.map(|body| Streaming::new_request(self.codec.decoder(), body))) + Request::from_http(request.map(|body| { + // TODO(david): get compression encoding from request and don't hard code `None` + Streaming::new_request(self.codec.decoder(), body, None) + })) } fn map_response( &mut self, response: Result, Status>, - encoding: Option, + encoding: Option, ) -> http::Response where B: TryStream + Send + Sync + 'static, diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 8b076b7eb..204ab7bc9 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -2,7 +2,6 @@ use super::super::service; use super::Channel; #[cfg(feature = "tls")] use super::ClientTlsConfig; -use crate::codec::compression::{EnabledEncodings, Encoding}; #[cfg(feature = "tls")] use crate::transport::service::TlsConnector; use crate::transport::Error; @@ -40,8 +39,6 @@ pub struct Endpoint { pub(crate) http2_keep_alive_timeout: Option, pub(crate) http2_keep_alive_while_idle: Option, pub(crate) http2_adaptive_window: Option, - pub(crate) accept_encoding: EnabledEncodings, - pub(crate) send_encoding: Option, } impl Endpoint { @@ -243,33 +240,6 @@ impl Endpoint { } } - /// Enable `gzip` compressed responses. - /// - /// This will tell the server that `gzip` compression is accepted. Messages will be - /// automatically decompressed. - /// - /// This does not compress messages sent by the client. - /// - /// Compression is not enabled by default. - // TODO(david): disabling compression on individual messages - pub fn accept_gzip(self) -> Self { - Endpoint { - accept_encoding: self.accept_encoding.gzip(), - ..self - } - } - - /// Compress requests with `gzip`. - /// - /// This requires the server to accept `gzip` compressed requests otherwise it might - /// respond with an error. - pub fn send_gzip(self) -> Self { - Endpoint { - send_encoding: Some(Encoding::Gzip), - ..self - } - } - /// Create a channel from this config. 
pub async fn connect(&self) -> Result { let mut http = hyper::client::connect::HttpConnector::new(); @@ -359,8 +329,6 @@ impl From for Endpoint { http2_keep_alive_timeout: None, http2_keep_alive_while_idle: None, http2_adaptive_window: None, - accept_encoding: EnabledEncodings::default(), - send_encoding: None, } } } diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index 83b6d26d0..2cfcb0dad 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -12,16 +12,14 @@ pub use tls::ClientTlsConfig; use super::service::{Connection, DynamicServiceStream}; use crate::{ body::BoxBody, - codec::compression::{EnabledEncodings, Encoding}, + codec::compression::{CompressionEncoding, EnabledCompressionEncodings}, }; -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::Bytes; use http::{ uri::{InvalidUri, Uri}, Request, Response, }; -use http_body::Body as _; use hyper::client::connect::Connection as HyperConnection; -use pin_project::pin_project; use std::{ fmt, future::Future, @@ -41,7 +39,6 @@ use tower::{ util::BoxService, Service, }; -use tower_http::set_header::SetRequestHeader; type Svc = BoxService, Response, crate::Error>; @@ -73,8 +70,6 @@ const DEFAULT_BUFFER_SIZE: usize = 1024; #[derive(Clone)] pub struct Channel { svc: Buffer>, - /// The encoding that request bodies will be compressed with. - send_encoding: Option, } /// A future that resolves to an HTTP response. @@ -146,15 +141,12 @@ impl Channel { C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); - let accept_encoding = endpoint.accept_encoding; - let send_encoding = endpoint.send_encoding; let svc = Connection::lazy(connector, endpoint); - let svc = with_accept_encoding(svc, accept_encoding); let svc = BoxService::new(svc); let svc = Buffer::new(svc, buffer_size); - Channel { svc, send_encoding } + Channel { svc } } pub(crate) async fn connect(connector: C, endpoint: Endpoint) -> Result @@ -165,17 +157,14 @@ impl Channel { C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); - let accept_encoding = endpoint.accept_encoding; - let send_encoding = endpoint.send_encoding; let svc = Connection::connect(connector, endpoint) .await .map_err(super::Error::from_source)?; - let svc = with_accept_encoding(svc, accept_encoding); let svc = BoxService::new(svc); let svc = Buffer::new(svc, buffer_size); - Ok(Channel { svc, send_encoding }) + Ok(Channel { svc }) } pub(crate) fn balance(discover: D, buffer_size: usize) -> Self @@ -189,25 +178,10 @@ impl Channel { let svc = BoxService::new(svc); let svc = Buffer::new(svc, buffer_size); - Channel { - svc, - send_encoding: None, - } + Channel { svc } } } -fn with_accept_encoding( - svc: S, - accept_encoding: EnabledEncodings, -) -> SetRequestHeader { - let header_value = accept_encoding.into_accept_encoding_header_value(); - SetRequestHeader::overriding( - svc, - http::header::HeaderName::from_static(crate::codec::compression::ACCEPT_ENCODING_HEADER), - header_value, - ) -} - impl Service> for Channel { type Response = http::Response; type Error = super::Error; @@ -218,25 +192,6 @@ impl Service> for Channel { } fn call(&mut self, request: http::Request) -> Self::Future { - let (mut parts, body) = request.into_parts(); - - let new_body = if let Some(encoding) = self.send_encoding { - parts.headers.insert( - 
crate::codec::compression::ENCODING_HEADER, - encoding.into_header_value(), - ); - - CompressEachChunkBody { - inner: body, - encoding, - encoding_buf: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE), - } - .boxed() - } else { - body - }; - - let request = http::Request::from_parts(parts, new_body); let inner = Service::call(&mut self.svc, request); ResponseFuture { inner } } @@ -263,64 +218,3 @@ impl fmt::Debug for ResponseFuture { f.debug_struct("ResponseFuture").finish() } } - -/// A `http_body::Body` that compresses each chunk with a given encoding. -#[pin_project] -struct CompressEachChunkBody { - #[pin] - inner: B, - encoding: Encoding, - encoding_buf: BytesMut, -} - -impl http_body::Body for CompressEachChunkBody -where - B: http_body::Body, -{ - type Data = Bytes; - type Error = crate::Status; - - fn poll_data( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - let this = self.project(); - match futures_util::ready!(this.inner.poll_data(cx)) { - Some(Ok(mut chunk)) => { - let len = chunk.len(); - - this.encoding_buf.clear(); - - if let Err(err) = crate::codec::compression::compress( - *this.encoding, - &mut chunk, - this.encoding_buf, - len, - ) { - let status = - crate::Status::internal("Failed to compress body chunk").with_source(err); - return Poll::Ready(Some(Err(status))); - } - - let chunk = this.encoding_buf.clone().freeze(); - - Poll::Ready(Some(Ok(chunk))) - } - other => Poll::Ready(other), - } - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - self.project().inner.poll_trailers(cx) - } - - fn is_end_stream(&self) -> bool { - self.inner.is_end_stream() - } - - // we don't define `size_hint` because we compress each - // chunk and dunno the size -} diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 5d42d0852..7ccbc589d 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -27,7 +27,7 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, Or, Routes, ServerIo}; -use crate::{body::BoxBody, codec::compression::EnabledEncodings}; +use crate::body::BoxBody; use bytes::Bytes; use futures_core::Stream; use futures_util::{ @@ -85,8 +85,6 @@ pub struct Server { max_frame_size: Option, accept_http1: bool, layer: L, - send_encodings: EnabledEncodings, - accept_encodings: EnabledEncodings, } /// A stack based `Service` router. @@ -323,22 +321,6 @@ impl Server { } } - /// Compress outgoing messages with `gzip` if supported by the client. - pub fn send_gzip(self) -> Self { - Server { - send_encodings: self.send_encodings.gzip(), - ..self - } - } - - /// Accept requests compressed with `gzip`. - pub fn accept_gzip(self) -> Self { - Server { - accept_encodings: self.accept_encodings.gzip(), - ..self - } - } - /// Create a router with the `S` typed service as the first service. 
/// /// This will clone the `Server` builder and create a router that will @@ -464,8 +446,6 @@ impl Server { http2_keepalive_timeout: self.http2_keepalive_timeout, max_frame_size: self.max_frame_size, accept_http1: self.accept_http1, - send_encodings: EnabledEncodings::default(), - accept_encodings: EnabledEncodings::default(), } } @@ -496,7 +476,6 @@ impl Server { let timeout = self.timeout; let max_frame_size = self.max_frame_size; let http2_only = !self.accept_http1; - let encodings = self.send_encodings; let http2_keepalive_interval = self.http2_keepalive_interval; let http2_keepalive_timeout = self @@ -513,7 +492,6 @@ impl Server { concurrency_limit, timeout, trace_interceptor, - encodings, _io: PhantomData, }; @@ -779,7 +757,6 @@ impl fmt::Debug for Server { struct Svc { inner: S, trace_interceptor: Option, - encodings: EnabledEncodings, } impl Service> for Svc @@ -798,12 +775,6 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - if let Some(value) = req.headers().get("grpc-encoding") { - if value == "gzip" { - todo!() - } - } - let span = if let Some(trace_interceptor) = &self.trace_interceptor { let (parts, body) = req.into_parts(); let bodyless_request = Request::from_parts(parts, ()); @@ -818,11 +789,6 @@ where tracing::Span::none() }; - // remove disabled encodings from `grpc-accept-encoding` so the inner service doesn't even - // seen them. - self.encodings - .remove_disabled_encodings_from_accept_encoding(req.headers_mut()); - SvcFuture { inner: self.inner.call(req), span, @@ -867,7 +833,6 @@ struct MakeSvc { timeout: Option, inner: S, trace_interceptor: Option, - encodings: EnabledEncodings, _io: PhantomData IO>, } @@ -895,7 +860,6 @@ where let concurrency_limit = self.concurrency_limit; let timeout = self.timeout; let trace_interceptor = self.trace_interceptor.clone(); - let encodings = self.encodings; let svc = ServiceBuilder::new() .layer_fn(RecoverError::new) @@ -931,7 +895,6 @@ where .service(Svc { inner: svc, trace_interceptor, - encodings, }); future::ready(Ok(svc)) From 50aa2e23977795b2e98aa18c66751fafa5621644 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 28 Jun 2021 12:00:58 +0200 Subject: [PATCH 08/29] Test sending compressed request to server that doesn't support it --- tests/compression/src/compressing_request.rs | 42 +++++++++++++++++++- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index 636abbe34..427c3ef8d 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -1,6 +1,8 @@ use super::*; use http_body::Body as _; +// TODO(david): send_gzip on channel, but disabling compression of a message + #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_enabled() { let svc = test_server::TestServer::new(Svc).accept_gzip(); @@ -32,13 +34,18 @@ async fn client_enabled_server_enabled() { }) }; + fn assert_right_encoding(req: http::Request) -> http::Request { + assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); + req + } + tokio::spawn(async move { Server::builder() .layer( ServiceBuilder::new() - // TODO(david): require request to have `grpc-encoding: gzip` .layer( ServiceBuilder::new() + .map_request(assert_right_encoding) .layer(measure_request_body_size_layer) .into_inner(), ) @@ -69,4 +76,35 @@ async fn client_enabled_server_enabled() { assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } -// TODO(david): send_gzip on channel, but disabling 
compression of a message +#[tokio::test(flavor = "multi_thread")] +async fn client_enabled_server_disabled() { + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel).send_gzip(); + + let status = client + .compress_input(SomeData { + data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), + }) + .await + .unwrap_err(); + + assert_eq!(status.code(), tonic::Code::Unimplemented); + assert_eq!(status.message(), "Request is compressed with `gzip` which the server doesn't support"); +} From 6cb6fe220da651cafb082d5b11606a07056b58e1 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 28 Jun 2021 12:09:24 +0200 Subject: [PATCH 09/29] Clean up a bit --- tests/compression/src/compressing_request.rs | 5 +- tests/compression/src/lib.rs | 1 - tonic/Cargo.toml | 2 - tonic/benches/decode.rs | 3 +- tonic/src/codec/compression.rs | 138 ++----------------- tonic/src/status.rs | 11 -- tonic/src/transport/channel/mod.rs | 5 +- 7 files changed, 15 insertions(+), 150 deletions(-) diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index 427c3ef8d..1afa6fa73 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -106,5 +106,8 @@ async fn client_enabled_server_disabled() { .unwrap_err(); assert_eq!(status.code(), tonic::Code::Unimplemented); - assert_eq!(status.message(), "Request is compressed with `gzip` which the server doesn't support"); + assert_eq!( + status.message(), + "Request is compressed with `gzip` which the server doesn't support" + ); } diff --git a/tests/compression/src/lib.rs b/tests/compression/src/lib.rs index 16c10eb11..2668f4ba8 100644 --- a/tests/compression/src/lib.rs +++ b/tests/compression/src/lib.rs @@ -18,7 +18,6 @@ mod util; tonic::include_proto!("test"); -// TODO(david): client copmressing messages // TODO(david): client streaming // TODO(david): server streaming // TODO(david): bidirectional streaming diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 97e31c186..252729b16 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -30,7 +30,6 @@ transport = [ "hyper", "tokio", "tower", - "tower-http", "tracing-futures", "tokio/macros", "tokio/time", @@ -75,7 +74,6 @@ hyper = { version = "0.14.2", features = ["full"], optional = true } tokio = { version = "1.0.1", features = ["net", "rt-multi-thread"], optional = true } tokio-stream = "0.1" tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true } -tower-http = { version = "0.1", features = ["set-header"], optional = true } tracing-futures = { version = "0.2", optional = true } # rustls diff --git a/tonic/benches/decode.rs b/tonic/benches/decode.rs index 41b249672..4750d2508 100644 --- a/tonic/benches/decode.rs +++ b/tonic/benches/decode.rs @@ -22,7 +22,8 @@ macro_rules! 
bench { b.iter(|| { rt.block_on(async { let decoder = MockDecoder::new($message_size); - let mut stream = Streaming::new_request(decoder, body.clone()); + // TODO(david): add benchmark with compression + let mut stream = Streaming::new_request(decoder, body.clone(), None); let mut count = 0; while let Some(msg) = stream.message().await.unwrap() { diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index e5b5a1af7..f61e823fc 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -1,7 +1,7 @@ use super::encode::BUFFER_SIZE; use bytes::{Buf, BufMut, BytesMut}; use flate2::read::{GzDecoder, GzEncoder}; -use std::fmt::{self, Write}; +use std::fmt; pub(crate) const ENCODING_HEADER: &str = "grpc-encoding"; pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding"; @@ -13,7 +13,8 @@ pub struct EnabledCompressionEncodings { } impl EnabledCompressionEncodings { - pub(crate) fn gzip(self) -> bool { + /// Check if `gzip` compression is enabled. + pub fn gzip(self) -> bool { self.gzip } @@ -23,86 +24,20 @@ impl EnabledCompressionEncodings { } pub(crate) fn into_accept_encoding_header_value(self) -> Option { - if self.gzip { + let Self { gzip } = self; + if gzip { Some(http::HeaderValue::from_static("gzip,identity")) } else { None } } - - /// Find the `grpc-accept-encoding` header and remove the encoding values that aren't enabled. - /// - /// For example a header value like `gzip,brotli,identity` where only `gzip` is enabled will - /// become `gzip`. - /// - /// This is used to remove disabled encodings from incoming requests in the server before they - /// each the actual `server::Grpc` service implementation. It is not possible to configure - /// `server::Grpc` so the configuration must be done at the `Server` level. - pub(crate) fn remove_disabled_encodings_from_accept_encoding(self, map: &mut http::HeaderMap) { - let accept_encoding = if let Some(accept_encoding) = map.remove(ACCEPT_ENCODING_HEADER) { - accept_encoding - } else { - return; - }; - - let accept_encoding_str = if let Ok(accept_encoding) = accept_encoding.to_str() { - accept_encoding - } else { - map.insert( - http::header::HeaderName::from_static(ACCEPT_ENCODING_HEADER), - accept_encoding, - ); - return; - }; - - // first check if we need to make changes to avoid allocating - let contains_disabled_encodings = - split_by_comma(accept_encoding_str).any(|encoding| match encoding { - "gzip" => !self.gzip, - _ => true, - }); - - if !contains_disabled_encodings { - // no changes necessary, put the original value back - map.insert( - http::header::HeaderName::from_static(ACCEPT_ENCODING_HEADER), - accept_encoding, - ); - return; - } - - // can be simplified when `Iterator::intersperse` is stable - let enabled_encodings = - split_by_comma(accept_encoding_str).filter_map(|encoding| match encoding { - "gzip" if self.gzip => Some("gzip"), - _ => None, - }); - - let mut new_value = String::new(); - let mut is_first = true; - - for encoding in enabled_encodings { - if is_first { - let _ = write!(new_value, "{}", encoding); - } else { - let _ = write!(new_value, ",{}", encoding); - }; - is_first = false; - } - - if !new_value.is_empty() { - map.insert( - http::header::HeaderName::from_static(ACCEPT_ENCODING_HEADER), - new_value.parse().unwrap(), - ); - } - } } -#[derive(Clone, Copy, Debug)] +/// The compression encodings Tonic supports. 
+#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[non_exhaustive] -#[doc(hidden)] pub enum CompressionEncoding { + #[allow(missing_docs)] Gzip, } @@ -208,60 +143,3 @@ pub(crate) fn decompress( Ok(()) } - -#[cfg(test)] -mod tests { - #[allow(unused_imports)] - use super::*; - use http::header::{HeaderMap, HeaderName}; - - #[test] - fn remove_disabled_encodings_empty_map() { - let mut map = HeaderMap::new(); - let encodings = EnabledCompressionEncodings { gzip: true }; - encodings.remove_disabled_encodings_from_accept_encoding(&mut map); - assert!(map.is_empty()); - } - - #[test] - fn remove_disabled_encodings_single_supported() { - let mut map = HeaderMap::new(); - map.insert( - HeaderName::from_static(ACCEPT_ENCODING_HEADER), - "gzip".parse().unwrap(), - ); - - let encodings = EnabledCompressionEncodings { gzip: true }; - encodings.remove_disabled_encodings_from_accept_encoding(&mut map); - - assert_eq!(&map[ACCEPT_ENCODING_HEADER], "gzip"); - } - - #[test] - fn remove_disabled_encodings_single_unsupported() { - let mut map = HeaderMap::new(); - map.insert( - HeaderName::from_static(ACCEPT_ENCODING_HEADER), - "gzip".parse().unwrap(), - ); - - let encodings = EnabledCompressionEncodings { gzip: false }; - encodings.remove_disabled_encodings_from_accept_encoding(&mut map); - - assert!(map.get(ACCEPT_ENCODING_HEADER).is_none()); - } - - #[test] - fn remove_disabled_encodings_multiple_supported() { - let mut map = HeaderMap::new(); - map.insert( - HeaderName::from_static(ACCEPT_ENCODING_HEADER), - "foo,gzip,identity".parse().unwrap(), - ); - - let encodings = EnabledCompressionEncodings { gzip: true }; - encodings.remove_disabled_encodings_from_accept_encoding(&mut map); - - assert_eq!(&map[ACCEPT_ENCODING_HEADER], "gzip"); - } -} diff --git a/tonic/src/status.rs b/tonic/src/status.rs index e9f79bfef..e5bb0dc56 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -334,17 +334,6 @@ impl Status { Err(err) } - /// Set the source error of the status - pub(crate) fn with_source(self, source: T) -> Self - where - T: Into>, - { - Self { - source: Some(source.into()), - ..self - } - } - // FIXME: bubble this into `transport` and expose generic http2 reasons. 
#[cfg(feature = "transport")] fn from_h2_error(err: &h2::Error) -> Status { diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index 2cfcb0dad..dcf77bcfa 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -10,10 +10,7 @@ pub use endpoint::Endpoint; pub use tls::ClientTlsConfig; use super::service::{Connection, DynamicServiceStream}; -use crate::{ - body::BoxBody, - codec::compression::{CompressionEncoding, EnabledCompressionEncodings}, -}; +use crate::body::BoxBody; use bytes::Bytes; use http::{ uri::{InvalidUri, Uri}, From 6b1946d83981f2220036494e17713541c303aee3 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 28 Jun 2021 15:42:59 +0200 Subject: [PATCH 10/29] Compress server streams --- tests/compression/proto/test.proto | 5 +- tests/compression/src/compressing_request.rs | 22 +-- tests/compression/src/compressing_response.rs | 16 +- tests/compression/src/lib.rs | 28 ++- tests/compression/src/server_stream.rs | 162 ++++++++++++++++++ tonic/src/codec/compression.rs | 35 ++-- tonic/src/codec/decode.rs | 32 ++-- tonic/src/codec/encode.rs | 40 ++--- tonic/src/codec/mod.rs | 7 + tonic/src/codec/prost.rs | 6 +- tonic/src/server/grpc.rs | 9 +- 11 files changed, 274 insertions(+), 88 deletions(-) create mode 100644 tests/compression/src/server_stream.rs diff --git a/tests/compression/proto/test.proto b/tests/compression/proto/test.proto index 824026c43..05a8e7366 100644 --- a/tests/compression/proto/test.proto +++ b/tests/compression/proto/test.proto @@ -5,8 +5,9 @@ package test; import "google/protobuf/empty.proto"; service Test { - rpc CompressOutput(google.protobuf.Empty) returns (SomeData); - rpc CompressInput(SomeData) returns (google.protobuf.Empty); + rpc CompressOutputUnary(google.protobuf.Empty) returns (SomeData); + rpc CompressInputUnary(SomeData) returns (google.protobuf.Empty); + rpc CompressOutputStream(google.protobuf.Empty) returns (stream SomeData); } message SomeData { diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index 1afa6fa73..705c93e01 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -64,16 +64,16 @@ async fn client_enabled_server_enabled() { let mut client = test_client::TestClient::new(channel).send_gzip(); - client - .compress_input(SomeData { - data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), - }) - .await - .unwrap(); - - let bytes_sent = bytes_sent_counter.load(Relaxed); - dbg!(&bytes_sent); - assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); + for _ in 0..3 { + client + .compress_input_unary(SomeData { + data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), + }) + .await + .unwrap(); + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(dbg!(bytes_sent) < UNCOMPRESSED_MIN_BODY_SIZE); + } } #[tokio::test(flavor = "multi_thread")] @@ -99,7 +99,7 @@ async fn client_enabled_server_disabled() { let mut client = test_client::TestClient::new(channel).send_gzip(); let status = client - .compress_input(SomeData { + .compress_input_unary(SomeData { data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), }) .await diff --git a/tests/compression/src/compressing_response.rs b/tests/compression/src/compressing_response.rs index a6d68e130..3b22fefad 100644 --- a/tests/compression/src/compressing_response.rs +++ b/tests/compression/src/compressing_response.rs @@ -67,12 +67,12 @@ async fn client_enabled_server_enabled() { let mut client = 
test_client::TestClient::new(channel).accept_gzip(); - let res = client.compress_output(()).await.unwrap(); - - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); - - let bytes_sent = bytes_sent_counter.load(Relaxed); - assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); + for _ in 0..3 { + let res = client.compress_output_unary(()).await.unwrap(); + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); + } } #[tokio::test(flavor = "multi_thread")] @@ -113,7 +113,7 @@ async fn client_enabled_server_disabled() { let mut client = test_client::TestClient::new(channel).accept_gzip(); - let res = client.compress_output(()).await.unwrap(); + let res = client.compress_output_unary(()).await.unwrap(); assert!(res.metadata().get("grpc-encoding").is_none()); @@ -183,7 +183,7 @@ async fn client_disabled() { let mut client = test_client::TestClient::new(channel); - let res = client.compress_output(()).await.unwrap(); + let res = client.compress_output_unary(()).await.unwrap(); assert!(res.metadata().get("grpc-encoding").is_none()); diff --git a/tests/compression/src/lib.rs b/tests/compression/src/lib.rs index 2668f4ba8..5e6a073fc 100644 --- a/tests/compression/src/lib.rs +++ b/tests/compression/src/lib.rs @@ -1,8 +1,12 @@ #![allow(unused_imports)] -use std::sync::{ - atomic::{AtomicUsize, Ordering::Relaxed}, - Arc, +use futures::{Stream, StreamExt}; +use std::{ + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering::Relaxed}, + Arc, + }, }; use tokio::net::TcpListener; use tonic::{ @@ -14,12 +18,12 @@ use tower_http::{map_request_body::MapRequestBodyLayer, map_response_body::MapRe mod compressing_request; mod compressing_response; +mod server_stream; mod util; tonic::include_proto!("test"); // TODO(david): client streaming -// TODO(david): server streaming // TODO(david): bidirectional streaming struct Svc; @@ -28,15 +32,27 @@ const UNCOMPRESSED_MIN_BODY_SIZE: usize = 1024; #[tonic::async_trait] impl test_server::Test for Svc { - async fn compress_output(&self, _req: Request<()>) -> Result, Status> { + async fn compress_output_unary(&self, _req: Request<()>) -> Result, Status> { let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE]; Ok(Response::new(SomeData { data: data.to_vec(), })) } - async fn compress_input(&self, req: Request) -> Result, Status> { + async fn compress_input_unary(&self, req: Request) -> Result, Status> { assert_eq!(req.into_inner().data.len(), UNCOMPRESSED_MIN_BODY_SIZE); Ok(Response::new(())) } + + type CompressOutputStreamStream = + Pin> + Send + Sync + 'static>>; + + async fn compress_output_stream( + &self, + _req: Request<()>, + ) -> Result, Status> { + let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); + let stream = futures::stream::repeat(SomeData { data }).map(Ok::<_, Status>); + Ok(Response::new(Box::pin(stream))) + } } diff --git a/tests/compression/src/server_stream.rs b/tests/compression/src/server_stream.rs new file mode 100644 index 000000000..3bbada610 --- /dev/null +++ b/tests/compression/src/server_stream.rs @@ -0,0 +1,162 @@ +use super::*; +use tonic::Streaming; + +#[tokio::test(flavor = "multi_thread")] +async fn client_enabled_server_enabled() { + let svc = test_server::TestServer::new(Svc).send_gzip(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let bytes_sent_counter = 
bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel).accept_gzip(); + + let res = client.compress_output_stream(()).await.unwrap(); + + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + + let mut stream: Streaming = res.into_inner(); + + stream + .next() + .await + .expect("stream empty") + .expect("item was error"); + assert!(dbg!(bytes_sent_counter.load(Relaxed)) < UNCOMPRESSED_MIN_BODY_SIZE); + + stream + .next() + .await + .expect("stream empty") + .expect("item was error"); + assert!(dbg!(bytes_sent_counter.load(Relaxed)) < UNCOMPRESSED_MIN_BODY_SIZE); +} + +#[tokio::test(flavor = "multi_thread")] +async fn client_disabled_server_enabled() { + let svc = test_server::TestServer::new(Svc).send_gzip(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + let res = client.compress_output_stream(()).await.unwrap(); + + assert!(res.metadata().get("grpc-encoding").is_none()); + + let mut stream: Streaming = res.into_inner(); + + stream + .next() + .await + .expect("stream empty") + .expect("item was error"); + assert!(dbg!(bytes_sent_counter.load(Relaxed)) > UNCOMPRESSED_MIN_BODY_SIZE); +} + +#[tokio::test(flavor = "multi_thread")] +async fn client_enabled_server_disabled() { + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel).accept_gzip(); + + let res = client.compress_output_stream(()).await.unwrap(); + + assert!(res.metadata().get("grpc-encoding").is_none()); + + let mut stream: Streaming = res.into_inner(); + + stream + .next() + .await + .expect("stream empty") + .expect("item was 
error"); + assert!(dbg!(bytes_sent_counter.load(Relaxed)) > UNCOMPRESSED_MIN_BODY_SIZE); +} diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index f61e823fc..f8a4e5415 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -86,52 +86,49 @@ fn split_by_comma(s: &str) -> impl Iterator { } /// Compress `len` bytes from `in_buffer` into `out_buffer`. -pub(crate) fn compress( +pub(crate) fn compress( encoding: CompressionEncoding, - in_buffer: &mut B, - out_buffer: &mut BytesMut, + uncompressed_buf: &mut BytesMut, + out_buf: &mut BytesMut, len: usize, -) -> Result<(), std::io::Error> -where - B: AsRef<[u8]> + bytes::Buf, -{ +) -> Result<(), std::io::Error> { let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE; - out_buffer.reserve(capacity); + out_buf.reserve(capacity); match encoding { CompressionEncoding::Gzip => { - let mut gzip_decoder = GzEncoder::new( - &in_buffer.as_ref()[0..len], + let mut gzip_encoder = GzEncoder::new( + &uncompressed_buf[0..len], // FIXME: support customizing the compression level flate2::Compression::new(6), ); - let mut out_writer = out_buffer.writer(); + let mut out_writer = out_buf.writer(); - tokio::task::block_in_place(|| std::io::copy(&mut gzip_decoder, &mut out_writer))?; + tokio::task::block_in_place(|| std::io::copy(&mut gzip_encoder, &mut out_writer))?; } } // TODO(david): is this necessary? test sending multiple requests and // responses on the same channel - in_buffer.advance(len); + uncompressed_buf.advance(len); Ok(()) } pub(crate) fn decompress( encoding: CompressionEncoding, - in_buffer: &mut BytesMut, - out_buffer: &mut BytesMut, + compressed_buf: &mut BytesMut, + out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error> { let estimate_decompressed_len = len * 2; let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE; - out_buffer.reserve(capacity); + out_buf.reserve(capacity); match encoding { CompressionEncoding::Gzip => { - let mut gzip_decoder = GzDecoder::new(&in_buffer[0..len]); - let mut out_writer = out_buffer.writer(); + let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]); + let mut out_writer = out_buf.writer(); tokio::task::block_in_place(|| std::io::copy(&mut gzip_decoder, &mut out_writer))?; } @@ -139,7 +136,7 @@ pub(crate) fn decompress( // TODO(david): is this necessary? 
test sending multiple requests and // responses on the same channel - in_buffer.advance(len); + compressed_buf.advance(len); Ok(()) } diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index a34aa33b8..3ffed94a8 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -1,6 +1,6 @@ use super::{ compression::{decompress, CompressionEncoding}, - DecodeBuf, Decoder, + DecodeBuf, Decoder, HEADER_SIZE, }; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; @@ -173,7 +173,7 @@ impl Streaming { fn decode_chunk(&mut self) -> Result, Status> { if let State::ReadHeader = self.state { - if self.buf.remaining() < 5 { + if self.buf.remaining() < HEADER_SIZE { return Ok(None); } @@ -209,7 +209,9 @@ impl Streaming { return Ok(None); } - let result = if *compression { + let decoding_result = if *compression { + self.decompress_buf.clear(); + if let Err(err) = decompress( // TODO(david): handle missing self.encoding self.encoding @@ -228,26 +230,24 @@ impl Streaming { }; return Err(Status::new(Code::Internal, message)); } - let uncompressed_len = self.decompress_buf.len(); + let decompressed_len = self.decompress_buf.len(); self.decoder.decode(&mut DecodeBuf::new( &mut self.decompress_buf, - uncompressed_len, + decompressed_len, )) } else { - match self - .decoder + self.decoder .decode(&mut DecodeBuf::new(&mut self.buf, *len)) - { - Ok(Some(msg)) => { - self.state = State::ReadHeader; - Ok(Some(msg)) - } - Ok(None) => Ok(None), - Err(e) => Err(e), - } }; - return result; + return match decoding_result { + Ok(Some(msg)) => { + self.state = State::ReadHeader; + Ok(Some(msg)) + } + Ok(None) => Ok(None), + Err(e) => Err(e), + }; } Ok(None) diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 6ffc6d7d2..08319045a 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -1,6 +1,6 @@ use super::{ compression::{compress, CompressionEncoding}, - EncodeBuf, Encoder, + EncodeBuf, Encoder, HEADER_SIZE, }; use crate::{Code, Status}; use bytes::{BufMut, Bytes, BytesMut}; @@ -56,7 +56,7 @@ where async_stream::stream! 
{ let mut buf = BytesMut::with_capacity(BUFFER_SIZE); - let (compression_enabled, mut compression_buf) = match compression_encoding { + let (compression_enabled, mut uncompression_buf) = match compression_encoding { Some(CompressionEncoding::Gzip) => (true, BytesMut::with_capacity(BUFFER_SIZE)), None => (false, BytesMut::new()), }; @@ -66,42 +66,40 @@ where loop { match source.next().await { Some(Ok(item)) => { - buf.reserve(5); + buf.reserve(HEADER_SIZE); unsafe { - buf.advance_mut(5); + buf.advance_mut(HEADER_SIZE); } if compression_enabled { - compression_buf.clear(); - encoder.encode(item, &mut EncodeBuf::new(&mut compression_buf)).map_err(drop).unwrap(); - let compressed_len = compression_buf.len(); + uncompression_buf.clear(); - let compress_result = compress( + encoder.encode(item, &mut EncodeBuf::new(&mut uncompression_buf)) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; + + let uncompressed_len = uncompression_buf.len(); + + compress( compression_encoding.unwrap(), - &mut compression_buf, + &mut uncompression_buf, &mut buf, - compressed_len, - ); - - if let Err(err) = compress_result { - yield Err(Status::internal(format!("Error compressing: {}", err))) - } + uncompressed_len, + ).map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; } else { - encoder.encode(item, &mut EncodeBuf::new(&mut buf)).map_err(drop).unwrap(); + encoder.encode(item, &mut EncodeBuf::new(&mut buf)) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; } // now that we know length, we can write the header - let len = buf.len() - 5; + let len = buf.len() - HEADER_SIZE; assert!(len <= std::u32::MAX as usize); { - let mut buf = &mut buf[..5]; - + let mut buf = &mut buf[..HEADER_SIZE]; buf.put_u8(compression_enabled as u8); - buf.put_u32(len as u32); } - yield Ok(buf.split_to(len + 5).freeze()); + yield Ok(buf.split_to(len + HEADER_SIZE).freeze()); }, Some(Err(status)) => yield Err(status), None => break, diff --git a/tonic/src/codec/mod.rs b/tonic/src/codec/mod.rs index 779e9cec2..58b40324e 100644 --- a/tonic/src/codec/mod.rs +++ b/tonic/src/codec/mod.rs @@ -22,6 +22,13 @@ pub use self::decode::Streaming; #[cfg_attr(docsrs, doc(cfg(feature = "prost")))] pub use self::prost::ProstCodec; +// 5 bytes +const HEADER_SIZE: usize = + // compression flag + std::mem::size_of::() + + // data length + std::mem::size_of::(); + /// Trait that knows how to encode and decode gRPC messages. pub trait Codec: Default { /// The encodable message. 
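
[Editor's note] The `HEADER_SIZE` constant introduced above is the standard 5-byte gRPC message prefix that `encode.rs` writes and `decode.rs` consumes: a one-byte compressed flag followed by a big-endian `u32` payload length. The following is a minimal sketch of that framing using the same `bytes` API as the patch; the `frame` helper and its payload are illustrative only and not part of the patch:

use bytes::{BufMut, BytesMut};

// 1-byte compressed flag + 4-byte big-endian message length (5 bytes total).
const HEADER_SIZE: usize = std::mem::size_of::<u8>() + std::mem::size_of::<u32>();

// Hypothetical helper: prefix an (optionally compressed) payload with the gRPC header.
fn frame(payload: &[u8], compressed: bool) -> BytesMut {
    let mut buf = BytesMut::with_capacity(HEADER_SIZE + payload.len());
    buf.put_u8(compressed as u8);       // compressed-flag, matching `compression_enabled as u8`
    buf.put_u32(payload.len() as u32);  // length prefix; `put_u32` writes big-endian
    buf.put_slice(payload);             // the (possibly compressed) message bytes
    buf
}

Under this framing a 1024-byte uncompressed message is sent as 1029 bytes of gRPC payload before HTTP/2 framing, which is why the tests above compare byte counters against UNCOMPRESSED_MIN_BODY_SIZE.
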
diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index 3ddb65c4e..027b4e6ba 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -77,7 +77,9 @@ fn from_decode_error(error: prost1::DecodeError) -> crate::Status { #[cfg(test)] mod tests { - use crate::codec::{encode_server, DecodeBuf, Decoder, EncodeBuf, Encoder, Streaming}; + use crate::codec::{ + encode_server, DecodeBuf, Decoder, EncodeBuf, Encoder, Streaming, HEADER_SIZE, + }; use crate::Status; use bytes::{Buf, BufMut, BytesMut}; use http_body::Body; @@ -92,7 +94,7 @@ mod tests { let mut buf = BytesMut::new(); - buf.reserve(msg.len() + 5); + buf.reserve(msg.len() + HEADER_SIZE); buf.put_u8(0); buf.put_u32(msg.len() as u32); diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index 3aeee1d35..7d6fa69ff 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -121,18 +121,21 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { - // TODO(david): encoding + let encoding = CompressionEncoding::from_accept_encoding_header( + req.headers(), + self.send_compression_encodings, + ); let request = match self.map_request_unary(req).await { Ok(r) => r, Err(status) => { - return self.map_response::(Err(status), None); + return self.map_response::(Err(status), encoding); } }; let response = service.call(request).await; - self.map_response(response, None) + self.map_response(response, encoding) } /// Handle a client side streaming gRPC request. From 82c2297a2c4fea315d31d9a58261bf1e69861eb0 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 28 Jun 2021 16:41:24 +0200 Subject: [PATCH 11/29] Compress client streams --- tests/compression/proto/test.proto | 3 +- tests/compression/src/client_stream.rs | 133 +++++++++++++++++++ tests/compression/src/compressing_request.rs | 57 +++----- tests/compression/src/lib.rs | 22 ++- tests/compression/src/server_stream.rs | 6 +- tests/compression/src/util.rs | 25 ++++ tonic/src/server/grpc.rs | 88 +++++++----- 7 files changed, 252 insertions(+), 82 deletions(-) create mode 100644 tests/compression/src/client_stream.rs diff --git a/tests/compression/proto/test.proto b/tests/compression/proto/test.proto index 05a8e7366..05893cd90 100644 --- a/tests/compression/proto/test.proto +++ b/tests/compression/proto/test.proto @@ -7,7 +7,8 @@ import "google/protobuf/empty.proto"; service Test { rpc CompressOutputUnary(google.protobuf.Empty) returns (SomeData); rpc CompressInputUnary(SomeData) returns (google.protobuf.Empty); - rpc CompressOutputStream(google.protobuf.Empty) returns (stream SomeData); + rpc CompressOutputServerStream(google.protobuf.Empty) returns (stream SomeData); + rpc CompressInputClientStream(stream SomeData) returns (google.protobuf.Empty); } message SomeData { diff --git a/tests/compression/src/client_stream.rs b/tests/compression/src/client_stream.rs new file mode 100644 index 000000000..0ae370834 --- /dev/null +++ b/tests/compression/src/client_stream.rs @@ -0,0 +1,133 @@ +use super::*; +use http_body::Body as _; + +#[tokio::test(flavor = "multi_thread")] +async fn client_enabled_server_enabled() { + let svc = test_server::TestServer::new(Svc).accept_gzip(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + fn assert_right_encoding(req: http::Request) -> http::Request { + assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); + req + } + + tokio::spawn({ + let bytes_sent_counter = 
bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .map_request(assert_right_encoding) + .layer(measure_request_body_size_layer(bytes_sent_counter.clone())) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel).send_gzip(); + + let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); + let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); + let req = Request::new(Box::pin(stream)); + + client.compress_input_client_stream(req).await.unwrap(); + + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(dbg!(bytes_sent) < UNCOMPRESSED_MIN_BODY_SIZE); +} + +#[tokio::test(flavor = "multi_thread")] +async fn client_disabled_server_enabled() { + let svc = test_server::TestServer::new(Svc).accept_gzip(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + fn assert_right_encoding(req: http::Request) -> http::Request { + assert!(req.headers().get("grpc-encoding").is_none()); + req + } + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .map_request(assert_right_encoding) + .layer(measure_request_body_size_layer(bytes_sent_counter.clone())) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); + let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); + let req = Request::new(Box::pin(stream)); + + client.compress_input_client_stream(req).await.unwrap(); + + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(dbg!(bytes_sent) > UNCOMPRESSED_MIN_BODY_SIZE); +} + +#[tokio::test(flavor = "multi_thread")] +async fn client_enabled_server_disabled() { + let svc = test_server::TestServer::new(Svc); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel).send_gzip(); + + let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); + let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); + let req = Request::new(Box::pin(stream)); + + let status = client.compress_input_client_stream(req).await.unwrap_err(); + + assert_eq!(status.code(), tonic::Code::Unimplemented); + assert_eq!( + status.message(), + "Request is compressed with `gzip` which the server doesn't support" + ); +} diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index 705c93e01..8a4ac98f6 100644 --- 
a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -12,49 +12,30 @@ async fn client_enabled_server_enabled() { let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); - let measure_request_body_size_layer = { - let bytes_sent_counter = bytes_sent_counter.clone(); - MapRequestBodyLayer::new(move |mut body: hyper::Body| { - let (mut tx, new_body) = hyper::Body::channel(); - - let bytes_sent_counter = bytes_sent_counter.clone(); - tokio::spawn(async move { - while let Some(chunk) = body.data().await { - let chunk = chunk.unwrap(); - bytes_sent_counter.fetch_add(chunk.len(), Relaxed); - tx.send_data(chunk).await.unwrap(); - } - - if let Some(trailers) = body.trailers().await.unwrap() { - tx.send_trailers(trailers).await.unwrap(); - } - }); - - new_body - }) - }; - fn assert_right_encoding(req: http::Request) -> http::Request { assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); req } - tokio::spawn(async move { - Server::builder() - .layer( - ServiceBuilder::new() - .layer( - ServiceBuilder::new() - .map_request(assert_right_encoding) - .layer(measure_request_body_size_layer) - .into_inner(), - ) - .into_inner(), - ) - .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) - .await - .unwrap(); + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer( + ServiceBuilder::new() + .map_request(assert_right_encoding) + .layer(measure_request_body_size_layer(bytes_sent_counter)) + .into_inner(), + ) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } }); let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) diff --git a/tests/compression/src/lib.rs b/tests/compression/src/lib.rs index 5e6a073fc..18e0fe7b5 100644 --- a/tests/compression/src/lib.rs +++ b/tests/compression/src/lib.rs @@ -1,5 +1,6 @@ #![allow(unused_imports)] +use self::util::*; use futures::{Stream, StreamExt}; use std::{ pin::Pin, @@ -11,11 +12,12 @@ use std::{ use tokio::net::TcpListener; use tonic::{ transport::{Channel, Server}, - Request, Response, Status, + Request, Response, Status, Streaming, }; use tower::{layer::layer_fn, Service, ServiceBuilder}; use tower_http::{map_request_body::MapRequestBodyLayer, map_response_body::MapResponseBodyLayer}; +mod client_stream; mod compressing_request; mod compressing_response; mod server_stream; @@ -23,7 +25,6 @@ mod util; tonic::include_proto!("test"); -// TODO(david): client streaming // TODO(david): bidirectional streaming struct Svc; @@ -44,15 +45,26 @@ impl test_server::Test for Svc { Ok(Response::new(())) } - type CompressOutputStreamStream = + type CompressOutputServerStreamStream = Pin> + Send + Sync + 'static>>; - async fn compress_output_stream( + async fn compress_output_server_stream( &self, _req: Request<()>, - ) -> Result, Status> { + ) -> Result, Status> { let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); let stream = futures::stream::repeat(SomeData { data }).map(Ok::<_, Status>); Ok(Response::new(Box::pin(stream))) } + + async fn compress_input_client_stream( + &self, + req: Request>, + ) -> Result, Status> { + let mut stream = req.into_inner(); + while let Some(item) = stream.next().await { + item.unwrap(); + } + Ok(Response::new(())) + } } diff --git a/tests/compression/src/server_stream.rs b/tests/compression/src/server_stream.rs index 
3bbada610..de1f6e52e 100644 --- a/tests/compression/src/server_stream.rs +++ b/tests/compression/src/server_stream.rs @@ -38,7 +38,7 @@ async fn client_enabled_server_enabled() { let mut client = test_client::TestClient::new(channel).accept_gzip(); - let res = client.compress_output_stream(()).await.unwrap(); + let res = client.compress_output_server_stream(()).await.unwrap(); assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); @@ -96,7 +96,7 @@ async fn client_disabled_server_enabled() { let mut client = test_client::TestClient::new(channel); - let res = client.compress_output_stream(()).await.unwrap(); + let res = client.compress_output_server_stream(()).await.unwrap(); assert!(res.metadata().get("grpc-encoding").is_none()); @@ -147,7 +147,7 @@ async fn client_enabled_server_disabled() { let mut client = test_client::TestClient::new(channel).accept_gzip(); - let res = client.compress_output_stream(()).await.unwrap(); + let res = client.compress_output_server_stream(()).await.unwrap(); assert!(res.metadata().get("grpc-encoding").is_none()); diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs index 17767fc1f..f16d210b7 100644 --- a/tests/compression/src/util.rs +++ b/tests/compression/src/util.rs @@ -10,6 +10,7 @@ use std::{ }, task::{Context, Poll}, }; +use tower_http::map_request_body::MapRequestBodyLayer; /// A body that tracks how many bytes passes through it #[pin_project] @@ -56,3 +57,27 @@ where self.inner.size_hint() } } + +#[allow(dead_code)] +pub fn measure_request_body_size_layer( + bytes_sent_counter: Arc, +) -> MapRequestBodyLayer hyper::Body + Clone> { + MapRequestBodyLayer::new(move |mut body: hyper::Body| { + let (mut tx, new_body) = hyper::Body::channel(); + + let bytes_sent_counter = bytes_sent_counter.clone(); + tokio::spawn(async move { + while let Some(chunk) = body.data().await { + let chunk = chunk.unwrap(); + bytes_sent_counter.fetch_add(chunk.len(), Relaxed); + tx.send_data(chunk).await.unwrap(); + } + + if let Some(trailers) = body.trailers().await.unwrap() { + tx.send_trailers(trailers).await.unwrap(); + } + }); + + new_body + }) +} diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index 7d6fa69ff..3493208d4 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -12,6 +12,15 @@ use futures_util::{future, stream, TryStreamExt}; use http_body::Body; use std::fmt; +macro_rules! t { + ($result:expr) => { + match $result { + Ok(value) => value, + Err(status) => return status.to_http(), + } + }; +} + /// A gRPC Server handler. /// /// This will wrap some inner [`Codec`] and provide utilities to handle @@ -85,7 +94,7 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { - let encoding = CompressionEncoding::from_accept_encoding_header( + let accept_encoding = CompressionEncoding::from_accept_encoding_header( req.headers(), self.send_compression_encodings, ); @@ -96,7 +105,7 @@ where return self .map_response::>>>( Err(status), - encoding, + accept_encoding, ); } }; @@ -106,7 +115,7 @@ where .await .map(|r| r.map(|m| stream::once(future::ok(m)))); - self.map_response(response, encoding) + self.map_response(response, accept_encoding) } /// Handle a server side streaming request. 
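
[Editor's note] For context on the `accept_encoding` rename in these hunks: the response encoding is negotiated from the client's `grpc-accept-encoding` header against the encodings the server is willing to send, while the request path (`request_encoding_if_supported` below) rejects a compressed request with `Unimplemented` when its `grpc-encoding` isn't enabled. A rough sketch of the response-side check, assuming gzip is the only supported encoding; the helper name is illustrative and not the patch's API:

use http::HeaderMap;

// Hypothetical stand-in for `CompressionEncoding::from_accept_encoding_header`
// restricted to gzip: compress the response only if the client advertised gzip
// and the server has gzip enabled for sending.
fn response_should_be_gzipped(headers: &HeaderMap, server_send_gzip: bool) -> bool {
    server_send_gzip
        && headers
            .get("grpc-accept-encoding")
            .and_then(|value| value.to_str().ok())
            .map(|value| value.split(',').any(|enc| enc.trim() == "gzip"))
            .unwrap_or(false)
}
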
@@ -121,7 +130,7 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { - let encoding = CompressionEncoding::from_accept_encoding_header( + let accept_encoding = CompressionEncoding::from_accept_encoding_header( req.headers(), self.send_compression_encodings, ); @@ -129,13 +138,13 @@ where let request = match self.map_request_unary(req).await { Ok(r) => r, Err(status) => { - return self.map_response::(Err(status), encoding); + return self.map_response::(Err(status), accept_encoding); } }; let response = service.call(request).await; - self.map_response(response, encoding) + self.map_response(response, accept_encoding) } /// Handle a client side streaming gRPC request. @@ -149,14 +158,18 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send + 'static, { - // TODO(david): encoding + let accept_encoding = CompressionEncoding::from_accept_encoding_header( + req.headers(), + self.accept_compression_encodings, + ); + + let request = t!(self.map_request_streaming(req)); - let request = self.map_request_streaming(req); let response = service .call(request) .await .map(|r| r.map(|m| stream::once(future::ok(m)))); - self.map_response(response, None) + self.map_response(response, accept_encoding) } /// Handle a bi-directional streaming gRPC request. @@ -171,8 +184,7 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { - // TODO(david): encoding - let request = self.map_request_streaming(req); + let request = t!(self.map_request_streaming(req)); let response = service.call(request).await; self.map_response(response, None) } @@ -185,25 +197,7 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { - // TODO(david): probably should set this directly on `Grpc`, like the client - let request_compression_encoding = if let Some(request_compression_encoding) = - CompressionEncoding::from_encoding_header(request.headers()) - { - let encoding_supported = match request_compression_encoding { - CompressionEncoding::Gzip => self.accept_compression_encodings.gzip(), - }; - - if encoding_supported { - Some(request_compression_encoding) - } else { - return Err(Status::unimplemented(format!( - "Request is compressed with `{}` which the server doesn't support", - request_compression_encoding - ))); - } - } else { - None - }; + let request_compression_encoding = self.request_encoding_if_supported(&request)?; let (parts, body) = request.into_parts(); let stream = @@ -228,15 +222,15 @@ where fn map_request_streaming( &mut self, request: http::Request, - ) -> Request> + ) -> Result>, Status> where B: Body + Send + Sync + 'static, B::Error: Into + Send, { - Request::from_http(request.map(|body| { - // TODO(david): get compression encoding from request and don't hard code `None` - Streaming::new_request(self.codec.decoder(), body, None) - })) + let encoding = self.request_encoding_if_supported(&request)?; + let request = + request.map(|body| Streaming::new_request(self.codec.decoder(), body, encoding)); + Ok(Request::from_http(request)) } fn map_response( @@ -272,6 +266,30 @@ where http::Response::from_parts(parts, BoxBody::new(body)) } + + fn request_encoding_if_supported( + &self, + request: &http::Request, + ) -> Result, Status> { + if let Some(request_compression_encoding) = + CompressionEncoding::from_encoding_header(request.headers()) + { + let encoding_supported = match request_compression_encoding { + CompressionEncoding::Gzip => self.accept_compression_encodings.gzip(), + }; + + if encoding_supported { + Ok(Some(request_compression_encoding)) + } else { + 
Err(Status::unimplemented(format!( + "Request is compressed with `{}` which the server doesn't support", + request_compression_encoding + ))) + } + } else { + Ok(None) + } + } } impl fmt::Debug for Grpc { From ab10667763db3bc8c64ab009804134573698236d Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 28 Jun 2021 17:05:29 +0200 Subject: [PATCH 12/29] Bidirectional streaming compression --- tests/compression/proto/test.proto | 1 + tests/compression/src/bidirectional_stream.rs | 76 +++++++++++++++++++ tests/compression/src/lib.rs | 21 ++++- tonic-build/src/client.rs | 17 +---- tonic/src/client/grpc.rs | 4 +- tonic/src/server/grpc.rs | 14 +++- 6 files changed, 109 insertions(+), 24 deletions(-) create mode 100644 tests/compression/src/bidirectional_stream.rs diff --git a/tests/compression/proto/test.proto b/tests/compression/proto/test.proto index 05893cd90..820188803 100644 --- a/tests/compression/proto/test.proto +++ b/tests/compression/proto/test.proto @@ -9,6 +9,7 @@ service Test { rpc CompressInputUnary(SomeData) returns (google.protobuf.Empty); rpc CompressOutputServerStream(google.protobuf.Empty) returns (stream SomeData); rpc CompressInputClientStream(stream SomeData) returns (google.protobuf.Empty); + rpc CompressInputOutputBidirectionalStream(stream SomeData) returns (stream SomeData); } message SomeData { diff --git a/tests/compression/src/bidirectional_stream.rs b/tests/compression/src/bidirectional_stream.rs new file mode 100644 index 000000000..03f4e5d28 --- /dev/null +++ b/tests/compression/src/bidirectional_stream.rs @@ -0,0 +1,76 @@ +use super::*; + +#[tokio::test(flavor = "multi_thread")] +async fn client_enabled_server_enabled() { + let svc = test_server::TestServer::new(Svc).accept_gzip().send_gzip(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + fn assert_right_encoding(req: http::Request) -> http::Request { + assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); + req + } + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .map_request(assert_right_encoding) + .layer(measure_request_body_size_layer(bytes_sent_counter.clone())) + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel) + .send_gzip() + .accept_gzip(); + + let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); + let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); + let req = Request::new(Box::pin(stream)); + + let res = client + .compress_input_output_bidirectional_stream(req) + .await + .unwrap(); + + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + + let mut stream: Streaming = res.into_inner(); + + stream + .next() + .await + .expect("stream empty") + .expect("item was error"); + + stream + .next() + .await + .expect("stream empty") + .expect("item was error"); + + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); +} diff --git 
a/tests/compression/src/lib.rs b/tests/compression/src/lib.rs index 18e0fe7b5..4ff5861ad 100644 --- a/tests/compression/src/lib.rs +++ b/tests/compression/src/lib.rs @@ -17,6 +17,7 @@ use tonic::{ use tower::{layer::layer_fn, Service, ServiceBuilder}; use tower_http::{map_request_body::MapRequestBodyLayer, map_response_body::MapResponseBodyLayer}; +mod bidirectional_stream; mod client_stream; mod compressing_request; mod compressing_response; @@ -25,8 +26,7 @@ mod util; tonic::include_proto!("test"); -// TODO(david): bidirectional streaming - +#[derive(Debug)] struct Svc; const UNCOMPRESSED_MIN_BODY_SIZE: usize = 1024; @@ -67,4 +67,21 @@ impl test_server::Test for Svc { } Ok(Response::new(())) } + + type CompressInputOutputBidirectionalStreamStream = + Pin> + Send + Sync + 'static>>; + + async fn compress_input_output_bidirectional_stream( + &self, + req: Request>, + ) -> Result, Status> { + let mut stream = req.into_inner(); + while let Some(item) = stream.next().await { + item.unwrap(); + } + + let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); + let stream = futures::stream::repeat(SomeData { data }).map(Ok::<_, Status>); + Ok(Response::new(Box::pin(stream))) + } } diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index 984528010..6297bb888 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -20,8 +20,6 @@ pub fn generate( let connect = generate_connect(&service_ident); let service_doc = generate_doc_comments(service.comment()); - let struct_debug = format!("{} {{{{ ... }}}}", &service_ident); - quote! { /// Generated client implementations. pub mod #client_mod { @@ -29,6 +27,7 @@ pub fn generate( use tonic::codegen::*; #service_doc + #[derive(Debug, Clone)] pub struct #service_ident { inner: tonic::client::Grpc, } @@ -73,20 +72,6 @@ pub fn generate( #methods } - - impl Clone for #service_ident { - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), - } - } - } - - impl std::fmt::Debug for #service_ident { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, #struct_debug) - } - } } } } diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 814f37e3c..46b5d5e90 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -244,10 +244,10 @@ impl Clone for Grpc { } } -impl fmt::Debug for Grpc { +impl fmt::Debug for Grpc { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Grpc") - .field("inner", &self.inner) + .field("inner", &format_args!("{}", std::any::type_name::())) .field("compression_encoding", &self.send_compression_encodings) .field( "accept_compression_encodings", diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index 3493208d4..ea4aba985 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -184,9 +184,15 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { + let accept_encoding = CompressionEncoding::from_accept_encoding_header( + req.headers(), + self.accept_compression_encodings, + ); + let request = t!(self.map_request_streaming(req)); + let response = service.call(request).await; - self.map_response(response, None) + self.map_response(response, accept_encoding) } async fn map_request_unary( @@ -236,7 +242,7 @@ where fn map_response( &mut self, response: Result, Status>, - encoding: Option, + accept_encoding: Option, ) -> http::Response where B: TryStream + Send + Sync + 'static, @@ -254,7 +260,7 @@ where http::header::HeaderValue::from_static("application/grpc"), ); - if let Some(encoding) 
= encoding { + if let Some(encoding) = accept_encoding { // Set the content encoding parts.headers.insert( crate::codec::compression::ENCODING_HEADER, @@ -262,7 +268,7 @@ where ); } - let body = encode_server(self.codec.encoder(), body.into_stream(), encoding); + let body = encode_server(self.codec.encoder(), body.into_stream(), accept_encoding); http::Response::from_parts(parts, BoxBody::new(body)) } From 91ff0f768f22343b8ba9b4c268eaa6c1a068b181 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 28 Jun 2021 17:30:40 +0200 Subject: [PATCH 13/29] Handle receiving unsupported encoding --- tests/compression/src/client_stream.rs | 2 +- tests/compression/src/compressing_request.rs | 4 +- tests/compression/src/compressing_response.rs | 44 ++++++++++++++++- tests/compression/src/lib.rs | 4 ++ tonic/benches/decode.rs | 1 - tonic/src/client/grpc.rs | 6 ++- tonic/src/codec/compression.rs | 47 +++++++++++++------ tonic/src/codec/decode.rs | 6 +-- tonic/src/server/grpc.rs | 22 ++------- 9 files changed, 92 insertions(+), 44 deletions(-) diff --git a/tests/compression/src/client_stream.rs b/tests/compression/src/client_stream.rs index 0ae370834..bceed0ee9 100644 --- a/tests/compression/src/client_stream.rs +++ b/tests/compression/src/client_stream.rs @@ -128,6 +128,6 @@ async fn client_enabled_server_disabled() { assert_eq!(status.code(), tonic::Code::Unimplemented); assert_eq!( status.message(), - "Request is compressed with `gzip` which the server doesn't support" + "Content is compressed with `gzip` which isn't supported" ); } diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index 8a4ac98f6..212d86bd0 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -1,8 +1,6 @@ use super::*; use http_body::Body as _; -// TODO(david): send_gzip on channel, but disabling compression of a message - #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_enabled() { let svc = test_server::TestServer::new(Svc).accept_gzip(); @@ -89,6 +87,6 @@ async fn client_enabled_server_disabled() { assert_eq!(status.code(), tonic::Code::Unimplemented); assert_eq!( status.message(), - "Request is compressed with `gzip` which the server doesn't support" + "Content is compressed with `gzip` which isn't supported" ); } diff --git a/tests/compression/src/compressing_response.rs b/tests/compression/src/compressing_response.rs index 3b22fefad..81d2d034c 100644 --- a/tests/compression/src/compressing_response.rs +++ b/tests/compression/src/compressing_response.rs @@ -1,7 +1,5 @@ use super::*; -// TODO(david): document that using a multi threaded tokio runtime is -// required (because of `block_in_place`) #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_enabled() { #[derive(Clone, Copy)] @@ -190,3 +188,45 @@ async fn client_disabled() { let bytes_sent = bytes_sent_counter.load(Relaxed); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } + +#[tokio::test(flavor = "multi_thread")] +async fn server_replying_with_unsupported_encoding() { + let svc = test_server::TestServer::new(Svc).send_gzip(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + fn add_weird_content_encoding(mut response: http::Response) -> http::Response { + response + .headers_mut() + .insert("grpc-encoding", "br".parse().unwrap()); + response + } + + tokio::spawn(async move { + Server::builder() + .layer( + ServiceBuilder::new() + 
.map_response(add_weird_content_encoding) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel).accept_gzip(); + let status: Status = client.compress_output_unary(()).await.unwrap_err(); + + assert_eq!(status.code(), tonic::Code::Unimplemented); + assert_eq!( + status.message(), + "Content is compressed with `br` which isn't supported" + ); +} diff --git a/tests/compression/src/lib.rs b/tests/compression/src/lib.rs index 4ff5861ad..7501bb5fd 100644 --- a/tests/compression/src/lib.rs +++ b/tests/compression/src/lib.rs @@ -1,5 +1,9 @@ #![allow(unused_imports)] +// TODO(david): document that using a multi threaded tokio runtime is +// required (because of `block_in_place`) +// TODO(david): send_gzip on channel, but disabling compression of a message + use self::util::*; use futures::{Stream, StreamExt}; use std::{ diff --git a/tonic/benches/decode.rs b/tonic/benches/decode.rs index 4750d2508..96f5b498d 100644 --- a/tonic/benches/decode.rs +++ b/tonic/benches/decode.rs @@ -22,7 +22,6 @@ macro_rules! bench { b.iter(|| { rt.block_on(async { let decoder = MockDecoder::new($message_size); - // TODO(david): add benchmark with compression let mut stream = Streaming::new_request(decoder, body.clone(), None); let mut count = 0; diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 46b5d5e90..2de56f554 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -204,8 +204,10 @@ impl Grpc { .await .map_err(|err| Status::from_error(err.into()))?; - // TODO(david): server compressing with algorithm the client doesn't know - let encoding = CompressionEncoding::from_encoding_header(response.headers()); + let encoding = CompressionEncoding::from_encoding_header( + response.headers(), + self.accept_compression_encodings, + )?; let status_code = response.status(); let trailers_only_status = Status::from_header_map(response.headers()); diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index f8a4e5415..c717290a0 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -1,3 +1,5 @@ +use crate::Status; + use super::encode::BUFFER_SIZE; use bytes::{Buf, BufMut, BytesMut}; use flate2::read::{GzDecoder, GzEncoder}; @@ -50,19 +52,39 @@ impl CompressionEncoding { let header_value = map.get(ACCEPT_ENCODING_HEADER)?; let header_value_str = header_value.to_str().ok()?; + let EnabledCompressionEncodings { gzip } = enabled_encodings; + split_by_comma(header_value_str).find_map(|value| match value { - "gzip" if enabled_encodings.gzip() => Some(CompressionEncoding::Gzip), + "gzip" if gzip => Some(CompressionEncoding::Gzip), _ => None, }) } - pub(crate) fn from_encoding_header(map: &http::HeaderMap) -> Option { - let header_value = map.get(ENCODING_HEADER)?; - let header_value_str = header_value.to_str().ok()?; + /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported. 
+ pub(crate) fn from_encoding_header( + map: &http::HeaderMap, + enabled_encodings: EnabledCompressionEncodings, + ) -> Result, Status> { + let header_value = if let Some(value) = map.get(ENCODING_HEADER) { + value + } else { + return Ok(None); + }; + + let header_value_str = if let Ok(value) = header_value.to_str() { + value + } else { + return Ok(None); + }; + + let EnabledCompressionEncodings { gzip } = enabled_encodings; match header_value_str { - "gzip" => Some(CompressionEncoding::Gzip), - _ => None, + "gzip" if gzip => Ok(Some(CompressionEncoding::Gzip)), + other => Err(Status::unimplemented(format!( + "Content is compressed with `{}` which isn't supported", + other + ))), } } @@ -85,10 +107,10 @@ fn split_by_comma(s: &str) -> impl Iterator { s.trim().split(',').map(|s| s.trim()) } -/// Compress `len` bytes from `in_buffer` into `out_buffer`. +/// Compress `len` bytes from `decompressed_buf` into `out_buf`. pub(crate) fn compress( encoding: CompressionEncoding, - uncompressed_buf: &mut BytesMut, + decompressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error> { @@ -98,7 +120,7 @@ pub(crate) fn compress( match encoding { CompressionEncoding::Gzip => { let mut gzip_encoder = GzEncoder::new( - &uncompressed_buf[0..len], + &decompressed_buf[0..len], // FIXME: support customizing the compression level flate2::Compression::new(6), ); @@ -108,13 +130,12 @@ pub(crate) fn compress( } } - // TODO(david): is this necessary? test sending multiple requests and - // responses on the same channel - uncompressed_buf.advance(len); + decompressed_buf.advance(len); Ok(()) } +/// Decompress `len` bytes from `compressed_buf` into `out_buf`. pub(crate) fn decompress( encoding: CompressionEncoding, compressed_buf: &mut BytesMut, @@ -134,8 +155,6 @@ pub(crate) fn decompress( } } - // TODO(david): is this necessary? test sending multiple requests and - // responses on the same channel compressed_buf.advance(len); Ok(()) diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 3ffed94a8..40e4748fb 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -213,9 +213,9 @@ impl Streaming { self.decompress_buf.clear(); if let Err(err) = decompress( - // TODO(david): handle missing self.encoding - self.encoding - .expect("message was compressed but compression not enabled on server"), + self.encoding.unwrap_or_else(|| { + unreachable!("message was compressed but `Streaming.encoding` was `None`. This is a bug in Tonic. 
Please file an issue") + }), &mut self.buf, &mut self.decompress_buf, *len, diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index ea4aba985..160c3a994 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -277,24 +277,10 @@ where &self, request: &http::Request, ) -> Result, Status> { - if let Some(request_compression_encoding) = - CompressionEncoding::from_encoding_header(request.headers()) - { - let encoding_supported = match request_compression_encoding { - CompressionEncoding::Gzip => self.accept_compression_encodings.gzip(), - }; - - if encoding_supported { - Ok(Some(request_compression_encoding)) - } else { - Err(Status::unimplemented(format!( - "Request is compressed with `{}` which the server doesn't support", - request_compression_encoding - ))) - } - } else { - Ok(None) - } + CompressionEncoding::from_encoding_header( + request.headers(), + self.accept_compression_encodings, + ) } } From 1d2367319a70c0610d1d5925060741f867f67ec5 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 28 Jun 2021 17:47:32 +0200 Subject: [PATCH 14/29] Clean up --- examples/src/blocking/client.rs | 2 +- tests/compression/Cargo.toml | 2 -- tonic/Cargo.toml | 2 +- tonic/src/client/grpc.rs | 4 ++-- tonic/src/codec/compression.rs | 3 +-- tonic/src/transport/channel/mod.rs | 19 +++++++++---------- 6 files changed, 14 insertions(+), 18 deletions(-) diff --git a/examples/src/blocking/client.rs b/examples/src/blocking/client.rs index fe83348a9..12788e085 100644 --- a/examples/src/blocking/client.rs +++ b/examples/src/blocking/client.rs @@ -27,7 +27,7 @@ impl BlockingClient { let rt = Builder::new_multi_thread().enable_all().build().unwrap(); let client = rt.block_on(GreeterClient::connect(dst))?; - Ok(Self { client, rt }) + Ok(Self { rt, client }) } pub fn say_hello( diff --git a/tests/compression/Cargo.toml b/tests/compression/Cargo.toml index eb3516cc9..215e3d456 100644 --- a/tests/compression/Cargo.toml +++ b/tests/compression/Cargo.toml @@ -6,8 +6,6 @@ edition = "2018" publish = false license = "MIT" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] tonic = { path = "../../tonic" } prost = "0.7" diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 252729b16..c7b4f18d6 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -71,7 +71,7 @@ async-trait = { version = "0.1.13", optional = true } # transport h2 = { version = "0.3", optional = true } hyper = { version = "0.14.2", features = ["full"], optional = true } -tokio = { version = "1.0.1", features = ["net", "rt-multi-thread"], optional = true } +tokio = { version = "1.0.1", features = ["net"], optional = true } tokio-stream = "0.1" tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true } tracing-futures = { version = "0.2", optional = true } diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 2de56f554..ee1fa7f93 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -246,10 +246,10 @@ impl Clone for Grpc { } } -impl fmt::Debug for Grpc { +impl fmt::Debug for Grpc { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Grpc") - .field("inner", &format_args!("{}", std::any::type_name::())) + .field("inner", &self.inner) .field("compression_encoding", &self.send_compression_encodings) .field( "accept_compression_encodings", diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 
c717290a0..0cd326275 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -1,6 +1,5 @@ -use crate::Status; - use super::encode::BUFFER_SIZE; +use crate::Status; use bytes::{Buf, BufMut, BytesMut}; use flate2::read::{GzDecoder, GzEncoder}; use std::fmt; diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index dcf77bcfa..c2ecd8e39 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -29,15 +29,15 @@ use tokio::{ sync::mpsc::{channel, Sender}, }; +use tower::balance::p2c::Balance; use tower::{ - balance::p2c::Balance, buffer::{self, Buffer}, discover::{Change, Discover}, - util::BoxService, + util::{BoxService, Either}, Service, }; -type Svc = BoxService, Response, crate::Error>; +type Svc = Either, Response, crate::Error>>; const DEFAULT_BUFFER_SIZE: usize = 1024; @@ -137,11 +137,10 @@ impl Channel { C::Future: Unpin + Send, C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, { - let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); + let buffer_size = endpoint.buffer_size.clone().unwrap_or(DEFAULT_BUFFER_SIZE); let svc = Connection::lazy(connector, endpoint); - let svc = BoxService::new(svc); - let svc = Buffer::new(svc, buffer_size); + let svc = Buffer::new(Either::A(svc), buffer_size); Channel { svc } } @@ -153,13 +152,12 @@ impl Channel { C::Future: Unpin + Send, C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, { - let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); + let buffer_size = endpoint.buffer_size.clone().unwrap_or(DEFAULT_BUFFER_SIZE); let svc = Connection::connect(connector, endpoint) .await .map_err(super::Error::from_source)?; - let svc = BoxService::new(svc); - let svc = Buffer::new(svc, buffer_size); + let svc = Buffer::new(Either::A(svc), buffer_size); Ok(Channel { svc }) } @@ -173,7 +171,7 @@ impl Channel { let svc = Balance::new(discover); let svc = BoxService::new(svc); - let svc = Buffer::new(svc, buffer_size); + let svc = Buffer::new(Either::B(svc), buffer_size); Channel { svc } } @@ -190,6 +188,7 @@ impl Service> for Channel { fn call(&mut self, request: http::Request) -> Self::Future { let inner = Service::call(&mut self.svc, request); + ResponseFuture { inner } } } From cc6b91fc6b07fa97603228ca35d27596c62ca8bc Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 28 Jun 2021 18:00:01 +0200 Subject: [PATCH 15/29] Add note to future self --- tests/compression/src/compressing_request.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index 212d86bd0..bd58a0209 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -89,4 +89,9 @@ async fn client_enabled_server_disabled() { status.message(), "Content is compressed with `gzip` which isn't supported" ); + + // TODO(david): include header with which encodings are supported as per the spec: + // + // > The server will then include a grpc-accept-encoding response header which specifies the + // algorithms that the server accepts. 
} From cf81479d9aac485c16bf8aa7a6e42fea70d799b1 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 29 Jun 2021 14:27:22 +0200 Subject: [PATCH 16/29] Support disabling compression for individual responses --- tests/compression/proto/test.proto | 1 + tests/compression/src/bidirectional_stream.rs | 4 +- tests/compression/src/client_stream.rs | 52 +++++- tests/compression/src/compressing_request.rs | 4 +- tests/compression/src/compressing_response.rs | 164 +++++++++++++++++- tests/compression/src/lib.rs | 49 +++++- tests/compression/src/server_stream.rs | 6 +- tonic-build/src/client.rs | 2 +- tonic/src/codec/compression.rs | 17 ++ tonic/src/codec/encode.rs | 22 ++- tonic/src/codec/prost.rs | 8 +- tonic/src/response.rs | 10 +- tonic/src/server/grpc.rs | 74 ++++++-- 13 files changed, 375 insertions(+), 38 deletions(-) diff --git a/tests/compression/proto/test.proto b/tests/compression/proto/test.proto index 820188803..325471b2f 100644 --- a/tests/compression/proto/test.proto +++ b/tests/compression/proto/test.proto @@ -9,6 +9,7 @@ service Test { rpc CompressInputUnary(SomeData) returns (google.protobuf.Empty); rpc CompressOutputServerStream(google.protobuf.Empty) returns (stream SomeData); rpc CompressInputClientStream(stream SomeData) returns (google.protobuf.Empty); + rpc CompressOutputClientStream(stream SomeData) returns (SomeData); rpc CompressInputOutputBidirectionalStream(stream SomeData) returns (stream SomeData); } diff --git a/tests/compression/src/bidirectional_stream.rs b/tests/compression/src/bidirectional_stream.rs index 03f4e5d28..ca5397790 100644 --- a/tests/compression/src/bidirectional_stream.rs +++ b/tests/compression/src/bidirectional_stream.rs @@ -2,7 +2,9 @@ use super::*; #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_enabled() { - let svc = test_server::TestServer::new(Svc).accept_gzip().send_gzip(); + let svc = test_server::TestServer::new(Svc::default()) + .accept_gzip() + .send_gzip(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); diff --git a/tests/compression/src/client_stream.rs b/tests/compression/src/client_stream.rs index bceed0ee9..8851fb773 100644 --- a/tests/compression/src/client_stream.rs +++ b/tests/compression/src/client_stream.rs @@ -3,7 +3,7 @@ use http_body::Body as _; #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_enabled() { - let svc = test_server::TestServer::new(Svc).accept_gzip(); + let svc = test_server::TestServer::new(Svc::default()).accept_gzip(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -51,7 +51,7 @@ async fn client_enabled_server_enabled() { #[tokio::test(flavor = "multi_thread")] async fn client_disabled_server_enabled() { - let svc = test_server::TestServer::new(Svc).accept_gzip(); + let svc = test_server::TestServer::new(Svc::default()).accept_gzip(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -99,7 +99,7 @@ async fn client_disabled_server_enabled() { #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_disabled() { - let svc = test_server::TestServer::new(Svc); + let svc = test_server::TestServer::new(Svc::default()); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -131,3 +131,49 @@ async fn client_enabled_server_disabled() { "Content is compressed with `gzip` which isn't supported" ); } + +#[tokio::test(flavor = 
"multi_thread")] +async fn compressing_response_from_client_stream() { + let svc = test_server::TestServer::new(Svc::default()).send_gzip(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel).accept_gzip(); + + let stream = futures::stream::iter(vec![]); + let req = Request::new(Box::pin(stream)); + + let res = client.compress_output_client_stream(req).await.unwrap(); + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); +} diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index bd58a0209..318c2861a 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -3,7 +3,7 @@ use http_body::Body as _; #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_enabled() { - let svc = test_server::TestServer::new(Svc).accept_gzip(); + let svc = test_server::TestServer::new(Svc::default()).accept_gzip(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -57,7 +57,7 @@ async fn client_enabled_server_enabled() { #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_disabled() { - let svc = test_server::TestServer::new(Svc); + let svc = test_server::TestServer::new(Svc::default()); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); diff --git a/tests/compression/src/compressing_response.rs b/tests/compression/src/compressing_response.rs index 81d2d034c..eb3c73b91 100644 --- a/tests/compression/src/compressing_response.rs +++ b/tests/compression/src/compressing_response.rs @@ -29,7 +29,7 @@ async fn client_enabled_server_enabled() { } } - let svc = test_server::TestServer::new(Svc).send_gzip(); + let svc = test_server::TestServer::new(Svc::default()).send_gzip(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -75,7 +75,7 @@ async fn client_enabled_server_enabled() { #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_disabled() { - let svc = test_server::TestServer::new(Svc); + let svc = test_server::TestServer::new(Svc::default()); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -145,7 +145,7 @@ async fn client_disabled() { } } - let svc = test_server::TestServer::new(Svc).send_gzip(); + let svc = test_server::TestServer::new(Svc::default()).send_gzip(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -191,7 +191,7 @@ async fn client_disabled() { #[tokio::test(flavor = "multi_thread")] async fn 
server_replying_with_unsupported_encoding() { - let svc = test_server::TestServer::new(Svc).send_gzip(); + let svc = test_server::TestServer::new(Svc::default()).send_gzip(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -230,3 +230,159 @@ async fn server_replying_with_unsupported_encoding() { "Content is compressed with `br` which isn't supported" ); } + +#[tokio::test(flavor = "multi_thread")] +async fn disabling_compression_on_single_response() { + let svc = test_server::TestServer::new(Svc { + disable_compressing_on_response: true, + }) + .send_gzip(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel).accept_gzip(); + + let res = client.compress_output_unary(()).await.unwrap(); + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); +} + +#[tokio::test(flavor = "multi_thread")] +async fn disabling_compression_on_response_but_keeping_compression_on_stream() { + let svc = test_server::TestServer::new(Svc { + disable_compressing_on_response: true, + }) + .send_gzip(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel).accept_gzip(); + + let res = client.compress_output_server_stream(()).await.unwrap(); + + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + + let mut stream: Streaming = res.into_inner(); + + stream + .next() + .await + .expect("stream empty") + .expect("item was error"); + assert!(dbg!(bytes_sent_counter.load(Relaxed)) < UNCOMPRESSED_MIN_BODY_SIZE); + + stream + .next() + .await + .expect("stream empty") + .expect("item was error"); + assert!(dbg!(bytes_sent_counter.load(Relaxed)) < UNCOMPRESSED_MIN_BODY_SIZE); +} + +#[tokio::test(flavor = "multi_thread")] +async fn disabling_compression_on_response_from_client_stream() { + let svc = test_server::TestServer::new(Svc { + disable_compressing_on_response: true, + }) + .send_gzip(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + 
let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let bytes_sent_counter = bytes_sent_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: bytes_sent_counter.clone(), + } + })) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + } + }); + + let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel).accept_gzip(); + + let stream = futures::stream::iter(vec![]); + let req = Request::new(Box::pin(stream)); + + let res = client.compress_output_client_stream(req).await.unwrap(); + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + let bytes_sent = bytes_sent_counter.load(Relaxed); + assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); +} diff --git a/tests/compression/src/lib.rs b/tests/compression/src/lib.rs index 7501bb5fd..4a174ddb5 100644 --- a/tests/compression/src/lib.rs +++ b/tests/compression/src/lib.rs @@ -31,17 +31,38 @@ mod util; tonic::include_proto!("test"); #[derive(Debug)] -struct Svc; +struct Svc { + disable_compressing_on_response: bool, +} + +impl Default for Svc { + fn default() -> Self { + Self { + disable_compressing_on_response: false, + } + } +} const UNCOMPRESSED_MIN_BODY_SIZE: usize = 1024; +impl Svc { + fn prepare_response(&self, mut res: Response) -> Response { + if self.disable_compressing_on_response { + res.disable_compression(); + } + + res + } +} + #[tonic::async_trait] impl test_server::Test for Svc { async fn compress_output_unary(&self, _req: Request<()>) -> Result, Status> { let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE]; - Ok(Response::new(SomeData { + + Ok(self.prepare_response(Response::new(SomeData { data: data.to_vec(), - })) + }))) } async fn compress_input_unary(&self, req: Request) -> Result, Status> { @@ -58,7 +79,7 @@ impl test_server::Test for Svc { ) -> Result, Status> { let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); let stream = futures::stream::repeat(SomeData { data }).map(Ok::<_, Status>); - Ok(Response::new(Box::pin(stream))) + Ok(self.prepare_response(Response::new(Box::pin(stream)))) } async fn compress_input_client_stream( @@ -69,7 +90,23 @@ impl test_server::Test for Svc { while let Some(item) = stream.next().await { item.unwrap(); } - Ok(Response::new(())) + Ok(self.prepare_response(Response::new(()))) + } + + async fn compress_output_client_stream( + &self, + req: Request>, + ) -> Result, Status> { + let mut stream = req.into_inner(); + while let Some(item) = stream.next().await { + item.unwrap(); + } + + let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE]; + + Ok(self.prepare_response(Response::new(SomeData { + data: data.to_vec(), + }))) } type CompressInputOutputBidirectionalStreamStream = @@ -86,6 +123,6 @@ impl test_server::Test for Svc { let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); let stream = futures::stream::repeat(SomeData { data }).map(Ok::<_, Status>); - Ok(Response::new(Box::pin(stream))) + Ok(self.prepare_response(Response::new(Box::pin(stream)))) } } diff --git a/tests/compression/src/server_stream.rs b/tests/compression/src/server_stream.rs index de1f6e52e..b63f9d9eb 100644 --- a/tests/compression/src/server_stream.rs +++ b/tests/compression/src/server_stream.rs @@ -3,7 +3,7 @@ use tonic::Streaming; #[tokio::test(flavor = 
"multi_thread")] async fn client_enabled_server_enabled() { - let svc = test_server::TestServer::new(Svc).send_gzip(); + let svc = test_server::TestServer::new(Svc::default()).send_gzip(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -61,7 +61,7 @@ async fn client_enabled_server_enabled() { #[tokio::test(flavor = "multi_thread")] async fn client_disabled_server_enabled() { - let svc = test_server::TestServer::new(Svc).send_gzip(); + let svc = test_server::TestServer::new(Svc::default()).send_gzip(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -112,7 +112,7 @@ async fn client_disabled_server_enabled() { #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_disabled() { - let svc = test_server::TestServer::new(Svc); + let svc = test_server::TestServer::new(Svc::default()); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index 6297bb888..309141452 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -201,7 +201,7 @@ fn generate_client_streaming( pub async fn #ident( &mut self, request: impl tonic::IntoStreamingRequest - ) -> Result, tonic::Status> { + ) -> Result, tonic::Status> where T: std::fmt::Debug { self.inner.ready().await.map_err(|e| { tonic::Status::new(tonic::Code::Unknown, format!("Service was not ready: {}", e.into())) })?; diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 0cd326275..1880dbb15 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -158,3 +158,20 @@ pub(crate) fn decompress( Ok(()) } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum SingleMessageCompressionOverride { + /// Inherit whatever compression is already configured. If the stream is compressed this + /// message will also be configured. + /// + /// This is the default. + Inherit, + /// Don't compress this message, even if compression is enabled on the stream. 
+ Disable, +} + +impl Default for SingleMessageCompressionOverride { + fn default() -> Self { + Self::Inherit + } +} diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 08319045a..f6fc09541 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -1,5 +1,5 @@ use super::{ - compression::{compress, CompressionEncoding}, + compression::{compress, CompressionEncoding, SingleMessageCompressionOverride}, EncodeBuf, Encoder, HEADER_SIZE, }; use crate::{Code, Status}; @@ -20,13 +20,14 @@ pub(crate) fn encode_server( encoder: T, source: U, compression_encoding: Option, + compression_override: SingleMessageCompressionOverride, ) -> EncodeBody>> where T: Encoder + Send + Sync + 'static, T::Item: Send + Sync, U: Stream> + Send + Sync + 'static, { - let stream = encode(encoder, source, compression_encoding).into_stream(); + let stream = encode(encoder, source, compression_encoding, compression_override).into_stream(); EncodeBody::new_server(stream) } @@ -40,7 +41,13 @@ where T::Item: Send + Sync, U: Stream + Send + Sync + 'static, { - let stream = encode(encoder, source.map(Ok), compression_encoding).into_stream(); + let stream = encode( + encoder, + source.map(Ok), + compression_encoding, + SingleMessageCompressionOverride::default(), + ) + .into_stream(); EncodeBody::new_client(stream) } @@ -48,6 +55,7 @@ fn encode( mut encoder: T, source: U, compression_encoding: Option, + compression_override: SingleMessageCompressionOverride, ) -> impl TryStream where T: Encoder, @@ -56,11 +64,13 @@ where async_stream::stream! { let mut buf = BytesMut::with_capacity(BUFFER_SIZE); - let (compression_enabled, mut uncompression_buf) = match compression_encoding { + let (compression_enabled_for_stream, mut uncompression_buf) = match compression_encoding { Some(CompressionEncoding::Gzip) => (true, BytesMut::with_capacity(BUFFER_SIZE)), None => (false, BytesMut::new()), }; + let compress_item = compression_enabled_for_stream && compression_override == SingleMessageCompressionOverride::Inherit; + futures_util::pin_mut!(source); loop { @@ -71,7 +81,7 @@ where buf.advance_mut(HEADER_SIZE); } - if compression_enabled { + if compress_item { uncompression_buf.clear(); encoder.encode(item, &mut EncodeBuf::new(&mut uncompression_buf)) @@ -95,7 +105,7 @@ where assert!(len <= std::u32::MAX as usize); { let mut buf = &mut buf[..HEADER_SIZE]; - buf.put_u8(compression_enabled as u8); + buf.put_u8(compress_item as u8); buf.put_u32(len as u32); } diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index 027b4e6ba..0db3ee06d 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -77,6 +77,7 @@ fn from_decode_error(error: prost1::DecodeError) -> crate::Status { #[cfg(test)] mod tests { + use crate::codec::compression::SingleMessageCompressionOverride; use crate::codec::{ encode_server, DecodeBuf, Decoder, EncodeBuf, Encoder, Streaming, HEADER_SIZE, }; @@ -121,7 +122,12 @@ mod tests { let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000); let source = futures_util::stream::iter(messages); - let body = encode_server(encoder, source, None); + let body = encode_server( + encoder, + source, + None, + SingleMessageCompressionOverride::default(), + ); futures_util::pin_mut!(body); diff --git a/tonic/src/response.rs b/tonic/src/response.rs index 87f59b4e4..d7f369a99 100644 --- a/tonic/src/response.rs +++ b/tonic/src/response.rs @@ -1,4 +1,6 @@ -use crate::{metadata::MetadataMap, Extensions}; +use crate::{ + 
codec::compression::SingleMessageCompressionOverride, metadata::MetadataMap, Extensions, +}; /// A gRPC response and metadata from an RPC call. #[derive(Debug)] @@ -107,6 +109,12 @@ impl Response { pub fn extensions_mut(&mut self) -> &mut Extensions { &mut self.extensions } + + // TODO(david): docs + pub fn disable_compression(&mut self) { + self.extensions_mut() + .insert(SingleMessageCompressionOverride::Disable); + } } #[cfg(test)] diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index 160c3a994..bed861009 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -1,7 +1,9 @@ use crate::{ body::BoxBody, codec::{ - compression::{CompressionEncoding, EnabledCompressionEncodings}, + compression::{ + CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride, + }, encode_server, Codec, Streaming, }, server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService}, @@ -106,6 +108,7 @@ where .map_response::>>>( Err(status), accept_encoding, + SingleMessageCompressionOverride::default(), ); } }; @@ -115,7 +118,9 @@ where .await .map(|r| r.map(|m| stream::once(future::ok(m)))); - self.map_response(response, accept_encoding) + let compression_override = compression_override_from_response(&response); + + self.map_response(response, accept_encoding, compression_override) } /// Handle a server side streaming request. @@ -138,13 +143,23 @@ where let request = match self.map_request_unary(req).await { Ok(r) => r, Err(status) => { - return self.map_response::(Err(status), accept_encoding); + return self.map_response::( + Err(status), + accept_encoding, + SingleMessageCompressionOverride::default(), + ); } }; let response = service.call(request).await; - self.map_response(response, accept_encoding) + self.map_response( + response, + accept_encoding, + // disabling compression of individual stream items must be done on + // the items themselves + SingleMessageCompressionOverride::default(), + ) } /// Handle a client side streaming gRPC request. @@ -157,10 +172,11 @@ where S: ClientStreamingService, B: Body + Send + Sync + 'static, B::Error: Into + Send + 'static, + T: std::fmt::Debug, { let accept_encoding = CompressionEncoding::from_accept_encoding_header( req.headers(), - self.accept_compression_encodings, + self.send_compression_encodings, ); let request = t!(self.map_request_streaming(req)); @@ -169,7 +185,10 @@ where .call(request) .await .map(|r| r.map(|m| stream::once(future::ok(m)))); - self.map_response(response, accept_encoding) + + let compression_override = compression_override_from_response(&response); + + self.map_response(response, accept_encoding, compression_override) } /// Handle a bi-directional streaming gRPC request. 
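
For the per-message override that map_response now receives, the escape hatch on the response side is the disable_compression method this patch adds to tonic::Response. A minimal sketch of how a handler might use it, assuming SomeData from the test proto above is in scope (uncompressed_reply is a hypothetical helper, not an API introduced by this change):

    use tonic::Response;

    // Sends this particular reply uncompressed even when the server was built
    // with `.send_gzip()`; under the hood this inserts
    // `SingleMessageCompressionOverride::Disable` into the response extensions.
    // Per the patch, streamed response items are not affected and keep
    // following the server's compression settings.
    fn uncompressed_reply(data: Vec<u8>) -> Response<SomeData> {
        let mut res = Response::new(SomeData { data });
        res.disable_compression();
        res
    }
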
@@ -186,13 +205,18 @@ where { let accept_encoding = CompressionEncoding::from_accept_encoding_header( req.headers(), - self.accept_compression_encodings, + self.send_compression_encodings, ); let request = t!(self.map_request_streaming(req)); let response = service.call(request).await; - self.map_response(response, accept_encoding) + + self.map_response( + response, + accept_encoding, + SingleMessageCompressionOverride::default(), + ) } async fn map_request_unary( @@ -243,6 +267,7 @@ where &mut self, response: Result, Status>, accept_encoding: Option, + compression_override: SingleMessageCompressionOverride, ) -> http::Response where B: TryStream + Send + Sync + 'static, @@ -268,7 +293,12 @@ where ); } - let body = encode_server(self.codec.encoder(), body.into_stream(), accept_encoding); + let body = encode_server( + self.codec.encoder(), + body.into_stream(), + accept_encoding, + compression_override, + ); http::Response::from_parts(parts, BoxBody::new(body)) } @@ -286,6 +316,30 @@ where impl fmt::Debug for Grpc { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Grpc").field("codec", &self.codec).finish() + f.debug_struct("Grpc") + .field("codec", &self.codec) + .field( + "accept_compression_encodings", + &self.accept_compression_encodings, + ) + .field( + "send_compression_encodings", + &self.send_compression_encodings, + ) + .finish() } } + +fn compression_override_from_response( + res: &Result, E>, +) -> SingleMessageCompressionOverride { + res.as_ref() + .ok() + .and_then(|response| { + response + .extensions() + .get::() + .copied() + }) + .unwrap_or_default() +} From 42766978f524e5e41e8f7b6020f184ec9e43460c Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 29 Jun 2021 14:58:31 +0200 Subject: [PATCH 17/29] Add docs --- tonic/Cargo.toml | 2 +- tonic/src/client/grpc.rs | 50 +++++++++++++++++++++++++++++++++++++++ tonic/src/response.rs | 8 ++++++- tonic/src/server/grpc.rs | 51 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 109 insertions(+), 2 deletions(-) diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index c7b4f18d6..252729b16 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -71,7 +71,7 @@ async-trait = { version = "0.1.13", optional = true } # transport h2 = { version = "0.3", optional = true } hyper = { version = "0.14.2", features = ["full"], optional = true } -tokio = { version = "1.0.1", features = ["net"], optional = true } +tokio = { version = "1.0.1", features = ["net", "rt-multi-thread"], optional = true } tokio-stream = "0.1" tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true } tracing-futures = { version = "0.2", optional = true } diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index ee1fa7f93..a79d5d94f 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -47,11 +47,61 @@ impl Grpc { } } + /// Compress requests with `gzip`. + /// + /// Requires the server to accept `gzip` otherwise it might return an error. 
+ /// + /// # Example + /// + /// The most common way of using this is through a client generated by tonic-build: + /// + /// ```rust + /// use tonic::transport::Channel; + /// # struct TestClient(T); + /// # impl TestClient { + /// # fn new(channel: T) -> Self { Self(channel) } + /// # fn send_gzip(self) -> Self { self } + /// # } + /// + /// # async { + /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) + /// .connect() + /// .await + /// .unwrap(); + /// + /// let client = TestClient::new(channel).send_gzip(); + /// # }; + /// ``` pub fn send_gzip(mut self) -> Self { self.send_compression_encodings = Some(CompressionEncoding::Gzip); self } + /// Enable accepting `gzip` compressed responses. + /// + /// Requires the server to also support sending compressed responses. + /// + /// # Example + /// + /// The most common way of using this is through a client generated by tonic-build: + /// + /// ```rust + /// use tonic::transport::Channel; + /// # struct TestClient(T); + /// # impl TestClient { + /// # fn new(channel: T) -> Self { Self(channel) } + /// # fn accept_gzip(self) -> Self { self } + /// # } + /// + /// # async { + /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) + /// .connect() + /// .await + /// .unwrap(); + /// + /// let client = TestClient::new(channel).accept_gzip(); + /// # }; + /// ``` pub fn accept_gzip(mut self) -> Self { self.accept_compression_encodings.enable_gzip(); self diff --git a/tonic/src/response.rs b/tonic/src/response.rs index d7f369a99..b6cf4f23a 100644 --- a/tonic/src/response.rs +++ b/tonic/src/response.rs @@ -110,7 +110,13 @@ impl Response { &mut self.extensions } - // TODO(david): docs + /// Disable compression of the response body. + /// + /// This disables compression of this response's body, even if compression is enabled on the + /// server. + /// + /// **Note** this only has effect on responses to unary requests. Response streams will still + /// be compressed according to the configuration of the server. pub fn disable_compression(&mut self) { self.extensions_mut() .insert(SingleMessageCompressionOverride::Disable); diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index bed861009..b472b300b 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -54,11 +54,62 @@ where } } + /// Enable accepting `gzip` compressed requests. + /// + /// If a request with an unsupported encoding is received the server will respond with + /// [`Code::UnUnimplemented`](crate::Code). + /// + /// # Example + /// + /// The most common way of using this is through a server generated by tonic-build: + /// + /// ```rust + /// # struct Svc; + /// # struct ExampleServer(T); + /// # impl ExampleServer { + /// # fn new(svc: T) -> Self { Self(svc) } + /// # fn accept_gzip(self) -> Self { self } + /// # } + /// # #[tonic::async_trait] + /// # trait Example {} + /// + /// #[tonic::async_trait] + /// impl Example for Svc { + /// // ... + /// } + /// + /// let service = ExampleServer::new(Svc).accept_gzip(); + /// ``` pub fn accept_gzip(mut self) -> Self { self.accept_compression_encodings.enable_gzip(); self } + /// Enable sending `gzip` compressed responses. + /// + /// Requires the client to also support receiving compressed responses. 
+ /// + /// # Example + /// + /// The most common way of using this is through a server generated by tonic-build: + /// + /// ```rust + /// # struct Svc; + /// # struct ExampleServer(T); + /// # impl ExampleServer { + /// # fn new(svc: T) -> Self { Self(svc) } + /// # fn send_gzip(self) -> Self { self } + /// # } + /// # #[tonic::async_trait] + /// # trait Example {} + /// + /// #[tonic::async_trait] + /// impl Example for Svc { + /// // ... + /// } + /// + /// let service = ExampleServer::new(Svc).send_gzip(); + /// ``` pub fn send_gzip(mut self) -> Self { self.send_compression_encodings.enable_gzip(); self From 5db1d6a1f69d1e9a717b021c808e743a3e88ca9f Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 29 Jun 2021 15:04:44 +0200 Subject: [PATCH 18/29] Add compression examples --- examples/Cargo.toml | 8 ++++++ examples/src/compression/client.rs | 27 ++++++++++++++++++++ examples/src/compression/server.rs | 40 ++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+) create mode 100644 examples/src/compression/client.rs create mode 100644 examples/src/compression/server.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 083f21174..855f05f30 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -150,6 +150,14 @@ path = "src/hyper_warp_multiplex/client.rs" name = "hyper-warp-multiplex-server" path = "src/hyper_warp_multiplex/server.rs" +[[bin]] +name = "compression-server" +path = "src/compression/server.rs" + +[[bin]] +name = "compression-client" +path = "src/compression/client.rs" + [dependencies] tonic = { path = "../tonic", features = ["tls"] } prost = "0.7" diff --git a/examples/src/compression/client.rs b/examples/src/compression/client.rs new file mode 100644 index 000000000..77ffeebe9 --- /dev/null +++ b/examples/src/compression/client.rs @@ -0,0 +1,27 @@ +use hello_world::greeter_client::GreeterClient; +use hello_world::HelloRequest; +use tonic::transport::Channel; + +pub mod hello_world { + tonic::include_proto!("helloworld"); +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let channel = Channel::builder("http://[::1]:50051".parse().unwrap()) + .connect() + .await + .unwrap(); + + let mut client = GreeterClient::new(channel).send_gzip().accept_gzip(); + + let request = tonic::Request::new(HelloRequest { + name: "Tonic".into(), + }); + + let response = client.say_hello(request).await?; + + dbg!(response); + + Ok(()) +} diff --git a/examples/src/compression/server.rs b/examples/src/compression/server.rs new file mode 100644 index 000000000..36f5081fd --- /dev/null +++ b/examples/src/compression/server.rs @@ -0,0 +1,40 @@ +use tonic::{transport::Server, Request, Response, Status}; + +use hello_world::greeter_server::{Greeter, GreeterServer}; +use hello_world::{HelloReply, HelloRequest}; + +pub mod hello_world { + tonic::include_proto!("helloworld"); +} + +#[derive(Default)] +pub struct MyGreeter {} + +#[tonic::async_trait] +impl Greeter for MyGreeter { + async fn say_hello( + &self, + request: Request, + ) -> Result, Status> { + println!("Got a request from {:?}", request.remote_addr()); + + let reply = hello_world::HelloReply { + message: format!("Hello {}!", request.into_inner().name), + }; + Ok(Response::new(reply)) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let addr = "[::1]:50051".parse().unwrap(); + let greeter = MyGreeter::default(); + + println!("GreeterServer listening on {}", addr); + + let service = GreeterServer::new(greeter).send_gzip().accept_gzip(); + + 
Server::builder().add_service(service).serve(addr).await?; + + Ok(()) +} From 755967d9b896ed92eef6428137b66b749002b291 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 29 Jun 2021 16:12:59 +0200 Subject: [PATCH 19/29] Disable compression behind feature flag --- tests/compression/Cargo.toml | 4 +- tonic-build/Cargo.toml | 1 + tonic-build/src/server.rs | 42 +++++++++---- tonic/Cargo.toml | 3 +- tonic/src/client/grpc.rs | 100 ++++++++++++++++++++--------- tonic/src/codec/decode.rs | 114 ++++++++++++++++++++++----------- tonic/src/codec/encode.rs | 61 ++++++++++++------ tonic/src/codec/mod.rs | 3 + tonic/src/codegen.rs | 1 + tonic/src/lib.rs | 5 ++ tonic/src/response.rs | 8 +-- tonic/src/server/grpc.rs | 119 ++++++++++++++++++++++++++++------- 12 files changed, 335 insertions(+), 126 deletions(-) diff --git a/tests/compression/Cargo.toml b/tests/compression/Cargo.toml index 215e3d456..6646d39c6 100644 --- a/tests/compression/Cargo.toml +++ b/tests/compression/Cargo.toml @@ -7,7 +7,7 @@ publish = false license = "MIT" [dependencies] -tonic = { path = "../../tonic" } +tonic = { path = "../../tonic", features = ["compression"] } prost = "0.7" tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] } tower = { version = "0.4", features = [] } @@ -21,4 +21,4 @@ pin-project = "1.0" hyper = "0.14" [build-dependencies] -tonic-build = { path = "../../tonic-build" } +tonic-build = { path = "../../tonic-build", features = ["compression"] } diff --git a/tonic-build/Cargo.toml b/tonic-build/Cargo.toml index 08af62c2a..529cd8875 100644 --- a/tonic-build/Cargo.toml +++ b/tonic-build/Cargo.toml @@ -26,6 +26,7 @@ default = ["transport", "rustfmt", "prost"] rustfmt = [] transport = [] prost = ["prost-build"] +compression = [] [package.metadata.docs.rs] all-features = true diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index ae0a976d7..1004ee701 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -36,6 +36,32 @@ pub fn generate( ); let transport = generate_transport(&server_service, &server_trait, &path); + let compression_enabled = cfg!(feature = "compression"); + + let compression_config_ty = if compression_enabled { + quote! { EnabledCompressionEncodings } + } else { + quote! { () } + }; + + let configure_compression_methods = if compression_enabled { + quote! { + // TODO(david): docs + pub fn accept_gzip(mut self) -> Self { + self.accept_compression_encodings.enable_gzip(); + self + } + + // TODO(david): docs + pub fn send_gzip(mut self) -> Self { + self.send_compression_encodings.enable_gzip(); + self + } + } + } else { + quote! {} + }; + quote! { /// Generated server implementations. 
pub mod #server_mod { @@ -48,8 +74,8 @@ pub fn generate( #[derive(Debug)] pub struct #server_service { inner: _Inner, - accept_compression_encodings: EnabledCompressionEncodings, - send_compression_encodings: EnabledCompressionEncodings, + accept_compression_encodings: #compression_config_ty, + send_compression_encodings: #compression_config_ty, } struct _Inner(Arc); @@ -72,17 +98,7 @@ pub fn generate( InterceptedService::new(Self::new(inner), interceptor) } - // TODO(david): docs - pub fn accept_gzip(mut self) -> Self { - self.accept_compression_encodings.enable_gzip(); - self - } - - // TODO(david): docs - pub fn send_gzip(mut self) -> Self { - self.send_compression_encodings.enable_gzip(); - self - } + #configure_compression_methods } impl Service> for #server_service diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 252729b16..50b03f9ce 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -39,6 +39,7 @@ tls-roots-common = ["tls"] tls-roots = ["tls-roots-common", "rustls-native-certs"] tls-webpki-roots = ["tls-roots-common", "webpki-roots"] prost = ["prost1", "prost-derive"] +compression = ["tokio/rt-multi-thread"] # [[bench]] # name = "bench_main" @@ -71,7 +72,7 @@ async-trait = { version = "0.1.13", optional = true } # transport h2 = { version = "0.3", optional = true } hyper = { version = "0.14.2", features = ["full"], optional = true } -tokio = { version = "1.0.1", features = ["net", "rt-multi-thread"], optional = true } +tokio = { version = "1.0.1", features = ["net"], optional = true } tokio-stream = "0.1" tower = { version = "0.4.7", features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true } tracing-futures = { version = "0.2", optional = true } diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index a79d5d94f..46357de49 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -1,10 +1,9 @@ +#[cfg(feature = "compression")] +use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings}; use crate::{ body::BoxBody, client::GrpcService, - codec::{ - compression::CompressionEncoding, encode_client, Codec, EnabledCompressionEncodings, - Streaming, - }, + codec::{encode_client, Codec, Streaming}, Code, Request, Response, Status, }; use futures_core::Stream; @@ -31,8 +30,10 @@ use std::fmt; /// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests pub struct Grpc { inner: T, + #[cfg(feature = "compression")] /// Which compression encodings does the client accept? accept_compression_encodings: EnabledCompressionEncodings, + #[cfg(feature = "compression")] /// The compression encoding that will be applied to requests. 
send_compression_encodings: Option, } @@ -42,7 +43,9 @@ impl Grpc { pub fn new(inner: T) -> Self { Self { inner, + #[cfg(feature = "compression")] send_compression_encodings: None, + #[cfg(feature = "compression")] accept_compression_encodings: EnabledCompressionEncodings::default(), } } @@ -72,11 +75,19 @@ impl Grpc { /// let client = TestClient::new(channel).send_gzip(); /// # }; /// ``` + #[cfg(feature = "compression")] + #[cfg_attr(docsrs, doc(cfg(feature = "compression")))] pub fn send_gzip(mut self) -> Self { self.send_compression_encodings = Some(CompressionEncoding::Gzip); self } + #[doc(hidden)] + #[cfg(not(feature = "compression"))] + pub fn send_gzip(self) -> Self { + panic!("`send_gzip` called on a server but the `compression` feature is not enabled on tonic"); + } + /// Enable accepting `gzip` compressed responses. /// /// Requires the server to also support sending compressed responses. @@ -102,11 +113,19 @@ impl Grpc { /// let client = TestClient::new(channel).accept_gzip(); /// # }; /// ``` + #[cfg(feature = "compression")] + #[cfg_attr(docsrs, doc(cfg(feature = "compression")))] pub fn accept_gzip(mut self) -> Self { self.accept_compression_encodings.enable_gzip(); self } + #[doc(hidden)] + #[cfg(not(feature = "compression"))] + pub fn accept_gzip(self) -> Self { + panic!("`accept_gzip` called on a client but the `compression` feature is not enabled on tonic"); + } + /// Check if the inner [`GrpcService`] is able to accept a new request. /// /// This will call [`GrpcService::poll_ready`] until it returns ready or @@ -216,7 +235,14 @@ impl Grpc { let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri"); let request = request - .map(|s| encode_client(codec.encoder(), s, self.send_compression_encodings)) + .map(|s| { + encode_client( + codec.encoder(), + s, + #[cfg(feature = "compression")] + self.send_compression_encodings, + ) + }) .map(BoxBody::new); let mut request = request.into_http(uri); @@ -231,21 +257,24 @@ impl Grpc { .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("application/grpc")); - if let Some(encoding) = self.send_compression_encodings { - request.headers_mut().insert( - crate::codec::compression::ENCODING_HEADER, - encoding.into_header_value(), - ); - } - - if let Some(header_value) = self - .accept_compression_encodings - .into_accept_encoding_header_value() + #[cfg(feature = "compression")] { - request.headers_mut().insert( - crate::codec::compression::ACCEPT_ENCODING_HEADER, - header_value, - ); + if let Some(encoding) = self.send_compression_encodings { + request.headers_mut().insert( + crate::codec::compression::ENCODING_HEADER, + encoding.into_header_value(), + ); + } + + if let Some(header_value) = self + .accept_compression_encodings + .into_accept_encoding_header_value() + { + request.headers_mut().insert( + crate::codec::compression::ACCEPT_ENCODING_HEADER, + header_value, + ); + } } let response = self @@ -254,6 +283,7 @@ impl Grpc { .await .map_err(|err| Status::from_error(err.into()))?; + #[cfg(feature = "compression")] let encoding = CompressionEncoding::from_encoding_header( response.headers(), self.accept_compression_encodings, @@ -276,7 +306,13 @@ impl Grpc { let response = response.map(|body| { if expect_additional_trailers { - Streaming::new_response(codec.decoder(), body, status_code, encoding) + Streaming::new_response( + codec.decoder(), + body, + status_code, + #[cfg(feature = "compression")] + encoding, + ) } else { Streaming::new_empty(codec.decoder(), body) } @@ -290,7 +326,9 @@ impl Clone for Grpc 
{ fn clone(&self) -> Self { Self { inner: self.inner.clone(), + #[cfg(feature = "compression")] send_compression_encodings: self.send_compression_encodings, + #[cfg(feature = "compression")] accept_compression_encodings: self.accept_compression_encodings, } } @@ -298,13 +336,19 @@ impl Clone for Grpc { impl fmt::Debug for Grpc { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Grpc") - .field("inner", &self.inner) - .field("compression_encoding", &self.send_compression_encodings) - .field( - "accept_compression_encodings", - &self.accept_compression_encodings, - ) - .finish() + let mut f = f.debug_struct("Grpc"); + + f.field("inner", &self.inner); + + #[cfg(feature = "compression")] + f.field("compression_encoding", &self.send_compression_encodings); + + #[cfg(feature = "compression")] + f.field( + "accept_compression_encodings", + &self.accept_compression_encodings, + ); + + f.finish() } } diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 40e4748fb..0a8cdbf96 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -1,7 +1,6 @@ -use super::{ - compression::{decompress, CompressionEncoding}, - DecodeBuf, Decoder, HEADER_SIZE, -}; +#[cfg(feature = "compression")] +use super::compression::{decompress, CompressionEncoding}; +use super::{DecodeBuf, Decoder, HEADER_SIZE}; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use futures_core::Stream; @@ -27,8 +26,10 @@ pub struct Streaming { state: State, direction: Direction, buf: BytesMut, - decompress_buf: BytesMut, trailers: Option, + #[cfg(feature = "compression")] + decompress_buf: BytesMut, + #[cfg(feature = "compression")] encoding: Option, } @@ -52,14 +53,20 @@ impl Streaming { decoder: D, body: B, status_code: StatusCode, - encoding: Option, + #[cfg(feature = "compression")] encoding: Option, ) -> Self where B: Body + Send + Sync + 'static, B::Error: Into, D: Decoder + Send + Sync + 'static, { - Self::new(decoder, body, Direction::Response(status_code), encoding) + Self::new( + decoder, + body, + Direction::Response(status_code), + #[cfg(feature = "compression")] + encoding, + ) } pub(crate) fn new_empty(decoder: D, body: B) -> Self @@ -68,24 +75,40 @@ impl Streaming { B::Error: Into, D: Decoder + Send + Sync + 'static, { - Self::new(decoder, body, Direction::EmptyResponse, None) + Self::new( + decoder, + body, + Direction::EmptyResponse, + #[cfg(feature = "compression")] + None, + ) } #[doc(hidden)] - pub fn new_request(decoder: D, body: B, encoding: Option) -> Self + pub fn new_request( + decoder: D, + body: B, + #[cfg(feature = "compression")] encoding: Option, + ) -> Self where B: Body + Send + Sync + 'static, B::Error: Into, D: Decoder + Send + Sync + 'static, { - Self::new(decoder, body, Direction::Request, encoding) + Self::new( + decoder, + body, + Direction::Request, + #[cfg(feature = "compression")] + encoding, + ) } fn new( decoder: D, body: B, direction: Direction, - encoding: Option, + #[cfg(feature = "compression")] encoding: Option, ) -> Self where B: Body + Send + Sync + 'static, @@ -101,8 +124,10 @@ impl Streaming { state: State::ReadHeader, direction, buf: BytesMut::with_capacity(BUFFER_SIZE), - decompress_buf: BytesMut::new(), trailers: None, + #[cfg(feature = "compression")] + decompress_buf: BytesMut::new(), + #[cfg(feature = "compression")] encoding, } } @@ -179,7 +204,16 @@ impl Streaming { let is_compressed = match self.buf.get_u8() { 0 => false, - 1 => true, + 1 => { + if cfg!(feature = 
"compression") { + true + } else { + return Err(Status::new( + Code::Unimplemented, + "Message compressed, compression support not enabled.".to_string(), + )); + } + } f => { trace!("unexpected compression flag"); let message = if let Direction::Response(status) = self.direction { @@ -210,31 +244,37 @@ impl Streaming { } let decoding_result = if *compression { - self.decompress_buf.clear(); - - if let Err(err) = decompress( - self.encoding.unwrap_or_else(|| { - unreachable!("message was compressed but `Streaming.encoding` was `None`. This is a bug in Tonic. Please file an issue") - }), - &mut self.buf, - &mut self.decompress_buf, - *len, - ) { - let message = if let Direction::Response(status) = self.direction { - format!( - "Error decompressing: {}, while receiving response with status: {}", - err, status - ) - } else { - format!("Error decompressing: {}, while sending request", err) - }; - return Err(Status::new(Code::Internal, message)); + #[cfg(feature = "compression")] + { + self.decompress_buf.clear(); + + if let Err(err) = decompress( + self.encoding.unwrap_or_else(|| { + unreachable!("message was compressed but `Streaming.encoding` was `None`. This is a bug in Tonic. Please file an issue") + }), + &mut self.buf, + &mut self.decompress_buf, + *len, + ) { + let message = if let Direction::Response(status) = self.direction { + format!( + "Error decompressing: {}, while receiving response with status: {}", + err, status + ) + } else { + format!("Error decompressing: {}, while sending request", err) + }; + return Err(Status::new(Code::Internal, message)); + } + let decompressed_len = self.decompress_buf.len(); + self.decoder.decode(&mut DecodeBuf::new( + &mut self.decompress_buf, + decompressed_len, + )) } - let decompressed_len = self.decompress_buf.len(); - self.decoder.decode(&mut DecodeBuf::new( - &mut self.decompress_buf, - decompressed_len, - )) + + #[cfg(not(feature = "compression"))] + unreachable!("should not take this branch if compression is disabled") } else { self.decoder .decode(&mut DecodeBuf::new(&mut self.buf, *len)) diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index f6fc09541..2fb53a329 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -1,7 +1,6 @@ -use super::{ - compression::{compress, CompressionEncoding, SingleMessageCompressionOverride}, - EncodeBuf, Encoder, HEADER_SIZE, -}; +#[cfg(feature = "compression")] +use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride}; +use super::{EncodeBuf, Encoder, HEADER_SIZE}; use crate::{Code, Status}; use bytes::{BufMut, Bytes, BytesMut}; use futures_core::{Stream, TryStream}; @@ -19,22 +18,31 @@ pub(super) const BUFFER_SIZE: usize = 8 * 1024; pub(crate) fn encode_server( encoder: T, source: U, - compression_encoding: Option, - compression_override: SingleMessageCompressionOverride, + #[cfg(feature = "compression")] compression_encoding: Option, + #[cfg(feature = "compression")] compression_override: SingleMessageCompressionOverride, ) -> EncodeBody>> where T: Encoder + Send + Sync + 'static, T::Item: Send + Sync, U: Stream> + Send + Sync + 'static, { - let stream = encode(encoder, source, compression_encoding, compression_override).into_stream(); + let stream = encode( + encoder, + source, + #[cfg(feature = "compression")] + compression_encoding, + #[cfg(feature = "compression")] + compression_override, + ) + .into_stream(); + EncodeBody::new_server(stream) } pub(crate) fn encode_client( encoder: T, source: U, - compression_encoding: Option, + 
#[cfg(feature = "compression")] compression_encoding: Option, ) -> EncodeBody>> where T: Encoder + Send + Sync + 'static, @@ -44,7 +52,9 @@ where let stream = encode( encoder, source.map(Ok), + #[cfg(feature = "compression")] compression_encoding, + #[cfg(feature = "compression")] SingleMessageCompressionOverride::default(), ) .into_stream(); @@ -54,8 +64,8 @@ where fn encode( mut encoder: T, source: U, - compression_encoding: Option, - compression_override: SingleMessageCompressionOverride, + #[cfg(feature = "compression")] compression_encoding: Option, + #[cfg(feature = "compression")] compression_override: SingleMessageCompressionOverride, ) -> impl TryStream where T: Encoder, @@ -64,13 +74,18 @@ where async_stream::stream! { let mut buf = BytesMut::with_capacity(BUFFER_SIZE); + #[cfg(feature = "compression")] let (compression_enabled_for_stream, mut uncompression_buf) = match compression_encoding { Some(CompressionEncoding::Gzip) => (true, BytesMut::with_capacity(BUFFER_SIZE)), None => (false, BytesMut::new()), }; + #[cfg(feature = "compression")] let compress_item = compression_enabled_for_stream && compression_override == SingleMessageCompressionOverride::Inherit; + #[cfg(not(feature = "compression"))] + let compress_item = false; + futures_util::pin_mut!(source); loop { @@ -82,19 +97,25 @@ where } if compress_item { - uncompression_buf.clear(); + #[cfg(feature = "compression")] + { + uncompression_buf.clear(); - encoder.encode(item, &mut EncodeBuf::new(&mut uncompression_buf)) - .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; + encoder.encode(item, &mut EncodeBuf::new(&mut uncompression_buf)) + .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; + + let uncompressed_len = uncompression_buf.len(); - let uncompressed_len = uncompression_buf.len(); + compress( + compression_encoding.unwrap(), + &mut uncompression_buf, + &mut buf, + uncompressed_len, + ).map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; + } - compress( - compression_encoding.unwrap(), - &mut uncompression_buf, - &mut buf, - uncompressed_len, - ).map_err(|err| Status::internal(format!("Error compressing: {}", err)))?; + #[cfg(not(feature = "compression"))] + unreachable!("compression disabled, should not take this branch"); } else { encoder.encode(item, &mut EncodeBuf::new(&mut buf)) .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?; diff --git a/tonic/src/codec/mod.rs b/tonic/src/codec/mod.rs index 58b40324e..d0c9d8b05 100644 --- a/tonic/src/codec/mod.rs +++ b/tonic/src/codec/mod.rs @@ -4,6 +4,7 @@ //! and a protobuf codec based on prost. 
mod buffer; +#[cfg(feature = "compression")] pub(crate) mod compression; mod decode; mod encode; @@ -16,6 +17,8 @@ use std::io; pub(crate) use self::encode::{encode_client, encode_server}; pub use self::buffer::{DecodeBuf, EncodeBuf}; +#[cfg(feature = "compression")] +#[cfg_attr(docsrs, doc(cfg(feature = "compression")))] pub use self::compression::{CompressionEncoding, EnabledCompressionEncodings}; pub use self::decode::Streaming; #[cfg(feature = "prost")] diff --git a/tonic/src/codegen.rs b/tonic/src/codegen.rs index 321615c8b..9d3a06996 100644 --- a/tonic/src/codegen.rs +++ b/tonic/src/codegen.rs @@ -10,6 +10,7 @@ pub use std::sync::Arc; pub use std::task::{Context, Poll}; pub use tower_service::Service; pub type StdError = Box; +#[cfg(feature = "compression")] pub use crate::codec::{CompressionEncoding, EnabledCompressionEncodings}; pub use crate::service::interceptor::InterceptedService; pub use http_body::Body; diff --git a/tonic/src/lib.rs b/tonic/src/lib.rs index 4d91643c1..ebd7f3de8 100644 --- a/tonic/src/lib.rs +++ b/tonic/src/lib.rs @@ -28,6 +28,10 @@ //! - `tls-webpki-roots`: Add the standard trust roots from the `webpki-roots` crate to //! `rustls`-based gRPC clients. Not enabled by default. //! - `prost`: Enables the [`prost`] based gRPC [`Codec`] implementation. +//! - `compression`: Enables compressing requests, responses, and streams. Note +//! that you must enable the `compression` feature on both `tonic` and +//! `tonic-build` to use it. Depends on `tokio`'s multi-threaded runtime and +//! [flate2]. Not enabled by default. //! //! # Structure //! @@ -62,6 +66,7 @@ //! [`rustls`]: https://docs.rs/rustls //! [`client`]: client/index.html //! [`transport`]: transport/index.html +//! [flate2]: https://crates.io/crates/flate2 #![recursion_limit = "256"] #![allow(clippy::inconsistent_struct_constructor)] diff --git a/tonic/src/response.rs b/tonic/src/response.rs index b6cf4f23a..e54c47aff 100644 --- a/tonic/src/response.rs +++ b/tonic/src/response.rs @@ -1,6 +1,4 @@ -use crate::{ - codec::compression::SingleMessageCompressionOverride, metadata::MetadataMap, Extensions, -}; +use crate::{metadata::MetadataMap, Extensions}; /// A gRPC response and metadata from an RPC call. #[derive(Debug)] @@ -117,9 +115,11 @@ impl Response { /// /// **Note** this only has effect on responses to unary requests. Response streams will still /// be compressed according to the configuration of the server. + #[cfg(feature = "compression")] + #[cfg_attr(docsrs, doc(cfg(feature = "compression")))] pub fn disable_compression(&mut self) { self.extensions_mut() - .insert(SingleMessageCompressionOverride::Disable); + .insert(crate::codec::compression::SingleMessageCompressionOverride::Disable); } } diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index b472b300b..726e7f707 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -1,11 +1,10 @@ +#[cfg(feature = "compression")] +use crate::codec::compression::{ + CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride, +}; use crate::{ body::BoxBody, - codec::{ - compression::{ - CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride, - }, - encode_server, Codec, Streaming, - }, + codec::{encode_server, Codec, Streaming}, server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService}, Code, Request, Status, }; @@ -35,8 +34,10 @@ macro_rules! t { pub struct Grpc { codec: T, /// Which compression encodings does the server accept for requests? 
+ #[cfg(feature = "compression")] accept_compression_encodings: EnabledCompressionEncodings, /// Which compression encodings might the server use for responses. + #[cfg(feature = "compression")] send_compression_encodings: EnabledCompressionEncodings, } @@ -49,7 +50,9 @@ where pub fn new(codec: T) -> Self { Self { codec, + #[cfg(feature = "compression")] accept_compression_encodings: EnabledCompressionEncodings::default(), + #[cfg(feature = "compression")] send_compression_encodings: EnabledCompressionEncodings::default(), } } @@ -80,11 +83,19 @@ where /// /// let service = ExampleServer::new(Svc).accept_gzip(); /// ``` + #[cfg(feature = "compression")] + #[cfg_attr(docsrs, doc(cfg(feature = "compression")))] pub fn accept_gzip(mut self) -> Self { self.accept_compression_encodings.enable_gzip(); self } + #[doc(hidden)] + #[cfg(not(feature = "compression"))] + pub fn accept_gzip(self) -> Self { + panic!("`accept_gzip` called on a server but the `compression` feature is not enabled on tonic"); + } + /// Enable sending `gzip` compressed responses. /// /// Requires the client to also support receiving compressed responses. @@ -110,11 +121,20 @@ where /// /// let service = ExampleServer::new(Svc).send_gzip(); /// ``` + #[cfg(feature = "compression")] + #[cfg_attr(docsrs, doc(cfg(feature = "compression")))] pub fn send_gzip(mut self) -> Self { self.send_compression_encodings.enable_gzip(); self } + #[doc(hidden)] + #[cfg(not(feature = "compression"))] + pub fn send_gzip(self) -> Self { + panic!("`send_gzip` called on a server but the `compression` feature is not enabled on tonic"); + } + + #[cfg(feature = "compression")] #[doc(hidden)] pub fn apply_compression_config( self, @@ -136,6 +156,13 @@ where this } + #[cfg(not(feature = "compression"))] + #[doc(hidden)] + #[allow(unused_variables)] + pub fn apply_compression_config(self, accept_encodings: (), send_encodings: ()) -> Self { + self + } + /// Handle a single unary gRPC request. pub async fn unary( &mut self, @@ -147,6 +174,7 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { + #[cfg(feature = "compression")] let accept_encoding = CompressionEncoding::from_accept_encoding_header( req.headers(), self.send_compression_encodings, @@ -158,7 +186,9 @@ where return self .map_response::>>>( Err(status), + #[cfg(feature = "compression")] accept_encoding, + #[cfg(feature = "compression")] SingleMessageCompressionOverride::default(), ); } @@ -169,9 +199,16 @@ where .await .map(|r| r.map(|m| stream::once(future::ok(m)))); + #[cfg(feature = "compression")] let compression_override = compression_override_from_response(&response); - self.map_response(response, accept_encoding, compression_override) + self.map_response( + response, + #[cfg(feature = "compression")] + accept_encoding, + #[cfg(feature = "compression")] + compression_override, + ) } /// Handle a server side streaming request. 
@@ -186,6 +223,7 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { + #[cfg(feature = "compression")] let accept_encoding = CompressionEncoding::from_accept_encoding_header( req.headers(), self.send_compression_encodings, @@ -196,7 +234,9 @@ where Err(status) => { return self.map_response::( Err(status), + #[cfg(feature = "compression")] accept_encoding, + #[cfg(feature = "compression")] SingleMessageCompressionOverride::default(), ); } @@ -206,9 +246,11 @@ where self.map_response( response, + #[cfg(feature = "compression")] accept_encoding, // disabling compression of individual stream items must be done on // the items themselves + #[cfg(feature = "compression")] SingleMessageCompressionOverride::default(), ) } @@ -225,6 +267,7 @@ where B::Error: Into + Send + 'static, T: std::fmt::Debug, { + #[cfg(feature = "compression")] let accept_encoding = CompressionEncoding::from_accept_encoding_header( req.headers(), self.send_compression_encodings, @@ -237,9 +280,16 @@ where .await .map(|r| r.map(|m| stream::once(future::ok(m)))); + #[cfg(feature = "compression")] let compression_override = compression_override_from_response(&response); - self.map_response(response, accept_encoding, compression_override) + self.map_response( + response, + #[cfg(feature = "compression")] + accept_encoding, + #[cfg(feature = "compression")] + compression_override, + ) } /// Handle a bi-directional streaming gRPC request. @@ -254,6 +304,7 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { + #[cfg(feature = "compression")] let accept_encoding = CompressionEncoding::from_accept_encoding_header( req.headers(), self.send_compression_encodings, @@ -265,7 +316,9 @@ where self.map_response( response, + #[cfg(feature = "compression")] accept_encoding, + #[cfg(feature = "compression")] SingleMessageCompressionOverride::default(), ) } @@ -278,12 +331,18 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { + #[cfg(feature = "compression")] let request_compression_encoding = self.request_encoding_if_supported(&request)?; let (parts, body) = request.into_parts(); + + #[cfg(feature = "compression")] let stream = Streaming::new_request(self.codec.decoder(), body, request_compression_encoding); + #[cfg(not(feature = "compression"))] + let stream = Streaming::new_request(self.codec.decoder(), body); + futures_util::pin_mut!(stream); let message = stream @@ -308,17 +367,24 @@ where B: Body + Send + Sync + 'static, B::Error: Into + Send, { + #[cfg(feature = "compression")] let encoding = self.request_encoding_if_supported(&request)?; + + #[cfg(feature = "compression")] let request = request.map(|body| Streaming::new_request(self.codec.decoder(), body, encoding)); + + #[cfg(not(feature = "compression"))] + let request = request.map(|body| Streaming::new_request(self.codec.decoder(), body)); + Ok(Request::from_http(request)) } fn map_response( &mut self, response: Result, Status>, - accept_encoding: Option, - compression_override: SingleMessageCompressionOverride, + #[cfg(feature = "compression")] accept_encoding: Option, + #[cfg(feature = "compression")] compression_override: SingleMessageCompressionOverride, ) -> http::Response where B: TryStream + Send + Sync + 'static, @@ -336,6 +402,7 @@ where http::header::HeaderValue::from_static("application/grpc"), ); + #[cfg(feature = "compression")] if let Some(encoding) = accept_encoding { // Set the content encoding parts.headers.insert( @@ -347,13 +414,16 @@ where let body = encode_server( self.codec.encoder(), body.into_stream(), + 
#[cfg(feature = "compression")] accept_encoding, + #[cfg(feature = "compression")] compression_override, ); http::Response::from_parts(parts, BoxBody::new(body)) } + #[cfg(feature = "compression")] fn request_encoding_if_supported( &self, request: &http::Request, @@ -367,20 +437,27 @@ where impl fmt::Debug for Grpc { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Grpc") - .field("codec", &self.codec) - .field( - "accept_compression_encodings", - &self.accept_compression_encodings, - ) - .field( - "send_compression_encodings", - &self.send_compression_encodings, - ) - .finish() + let mut f = f.debug_struct("Grpc"); + + f.field("codec", &self.codec); + + #[cfg(feature = "compression")] + f.field( + "accept_compression_encodings", + &self.accept_compression_encodings, + ); + + #[cfg(feature = "compression")] + f.field( + "send_compression_encodings", + &self.send_compression_encodings, + ); + + f.finish() } } +#[cfg(feature = "compression")] fn compression_override_from_response( res: &Result, E>, ) -> SingleMessageCompressionOverride { From 9c230ba54ba25c98a31f93ebd7e9d4365e282aa6 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 29 Jun 2021 16:41:14 +0200 Subject: [PATCH 20/29] Add some docs --- tests/compression/src/lib.rs | 4 ---- tonic-build/src/client.rs | 7 +++++-- tonic-build/src/server.rs | 4 ++-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/compression/src/lib.rs b/tests/compression/src/lib.rs index 4a174ddb5..772be4ef9 100644 --- a/tests/compression/src/lib.rs +++ b/tests/compression/src/lib.rs @@ -1,9 +1,5 @@ #![allow(unused_imports)] -// TODO(david): document that using a multi threaded tokio runtime is -// required (because of `block_in_place`) -// TODO(david): send_gzip on channel, but disabling compression of a message - use self::util::*; use futures::{Stream, StreamExt}; use std::{ diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index 309141452..b44c1d0a2 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -58,13 +58,16 @@ pub fn generate( #service_ident::new(InterceptedService::new(inner, interceptor)) } - // TODO(david): docs + /// Compress requests with `gzip`. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. pub fn send_gzip(mut self) -> Self { self.inner = self.inner.send_gzip(); self } - // TODO(david): docs + /// Enable decompressing responses with `gzip`. pub fn accept_gzip(mut self) -> Self { self.inner = self.inner.accept_gzip(); self diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index 1004ee701..f8962432f 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -46,13 +46,13 @@ pub fn generate( let configure_compression_methods = if compression_enabled { quote! { - // TODO(david): docs + /// Enable decompressing requests with `gzip`. pub fn accept_gzip(mut self) -> Self { self.accept_compression_encodings.enable_gzip(); self } - // TODO(david): docs + /// Compress responses with `gzip`, if the client supports it. 
pub fn send_gzip(mut self) -> Self { self.send_compression_encodings.enable_gzip(); self From f048f874976a8d4ac067dfb9c5583e12f44a99d9 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 29 Jun 2021 16:45:42 +0200 Subject: [PATCH 21/29] Make flate2 optional dependency --- tonic/Cargo.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 50b03f9ce..9223208de 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -39,7 +39,7 @@ tls-roots-common = ["tls"] tls-roots = ["tls-roots-common", "rustls-native-certs"] tls-webpki-roots = ["tls-roots-common", "webpki-roots"] prost = ["prost1", "prost-derive"] -compression = ["tokio/rt-multi-thread"] +compression = ["tokio/rt-multi-thread", "flate2"] # [[bench]] # name = "bench_main" @@ -60,7 +60,6 @@ tokio-util = { version = "0.6", features = ["codec"] } async-stream = "0.3" http-body = "0.4.2" pin-project = "1.0" -flate2 = "1.0" # prost prost1 = { package = "prost", version = "0.7", optional = true } @@ -82,6 +81,9 @@ tokio-rustls = { version = "0.22", optional = true } rustls-native-certs = { version = "0.5", optional = true } webpki-roots = { version = "0.21.1", optional = true } +# compression +flate2 = { version = "1.0", optional = true } + [dev-dependencies] tokio = { version = "1.0", features = ["rt", "macros"] } static_assertions = "1.0" From ab1e953a1d6e90403b4deab41eb36d82fae181ef Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 29 Jun 2021 16:49:29 +0200 Subject: [PATCH 22/29] Fix docs wording --- tonic/src/response.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tonic/src/response.rs b/tonic/src/response.rs index e54c47aff..89fc98706 100644 --- a/tonic/src/response.rs +++ b/tonic/src/response.rs @@ -110,11 +110,12 @@ impl Response { /// Disable compression of the response body. /// - /// This disables compression of this response's body, even if compression is enabled on the - /// server. + /// This disables compression of the body of this response, even if compression is enabled on + /// the server. /// - /// **Note** this only has effect on responses to unary requests. Response streams will still - /// be compressed according to the configuration of the server. + /// **Note**: This only has effect on responses to unary requests and responses to client to + /// server streams. Response streams (server to client stream and bidirectional streams) will + /// still be compressed according to the configuration of the server. #[cfg(feature = "compression")] #[cfg_attr(docsrs, doc(cfg(feature = "compression")))] pub fn disable_compression(&mut self) { From ed62228258f290f4bbd4259a7c0d0f97931b302f Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 29 Jun 2021 17:00:02 +0200 Subject: [PATCH 23/29] Format --- tonic/src/client/grpc.rs | 4 +++- tonic/src/server/grpc.rs | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 46357de49..6e8e2d99f 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -85,7 +85,9 @@ impl Grpc { #[doc(hidden)] #[cfg(not(feature = "compression"))] pub fn send_gzip(self) -> Self { - panic!("`send_gzip` called on a server but the `compression` feature is not enabled on tonic"); + panic!( + "`send_gzip` called on a server but the `compression` feature is not enabled on tonic" + ); } /// Enable accepting `gzip` compressed responses. 
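// Usage sketch (client and server types taken from this PR's test suite;
// channel construction omitted): with the `compression` feature enabled on
// both `tonic` and `tonic-build`, compression is switched on per client and
// per service. The client advertises what it can decode via the
// `grpc-accept-encoding` request header, and the server marks compressed
// responses with `grpc-encoding`.
//
// let mut client = test_client::TestClient::new(channel).send_gzip().accept_gzip();
// let svc = test_server::TestServer::new(Svc::default()).accept_gzip().send_gzip();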
diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index 726e7f707..21b542981 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -131,7 +131,9 @@ where #[doc(hidden)] #[cfg(not(feature = "compression"))] pub fn send_gzip(self) -> Self { - panic!("`send_gzip` called on a server but the `compression` feature is not enabled on tonic"); + panic!( + "`send_gzip` called on a server but the `compression` feature is not enabled on tonic" + ); } #[cfg(feature = "compression")] From f726acbacecb2941d135ae5c20a783b8687a5852 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 29 Jun 2021 17:15:12 +0200 Subject: [PATCH 24/29] Reply with which encodings are supported --- tests/compression/src/compressing_request.rs | 8 +++---- tonic/src/codec/compression.rs | 22 +++++++++++++++----- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index 318c2861a..f4aaa2586 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -90,8 +90,8 @@ async fn client_enabled_server_disabled() { "Content is compressed with `gzip` which isn't supported" ); - // TODO(david): include header with which encodings are supported as per the spec: - // - // > The server will then include a grpc-accept-encoding response header which specifies the - // algorithms that the server accepts. + assert_eq!( + status.metadata().get("grpc-accept-encoding").unwrap(), + "identity" + ); } diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 1880dbb15..c1f438a64 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -1,5 +1,5 @@ use super::encode::BUFFER_SIZE; -use crate::Status; +use crate::{metadata::MetadataValue, Status}; use bytes::{Buf, BufMut, BytesMut}; use flate2::read::{GzDecoder, GzEncoder}; use std::fmt; @@ -80,10 +80,22 @@ impl CompressionEncoding { match header_value_str { "gzip" if gzip => Ok(Some(CompressionEncoding::Gzip)), - other => Err(Status::unimplemented(format!( - "Content is compressed with `{}` which isn't supported", - other - ))), + other => { + let mut status = Status::unimplemented(format!( + "Content is compressed with `{}` which isn't supported", + other + )); + + let header_value = enabled_encodings + .into_accept_encoding_header_value() + .map(MetadataValue::unchecked_from_header_value) + .unwrap_or_else(|| MetadataValue::from_static("identity")); + status + .metadata_mut() + .insert(ACCEPT_ENCODING_HEADER, header_value); + + Err(status) + } } } From 00f6989e4280a3c9a84c4edc0d9c3d620606fa94 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 1 Jul 2021 17:11:40 +0200 Subject: [PATCH 25/29] Convert tests to use mocked io --- tests/compression/src/bidirectional_stream.rs | 20 ++- tests/compression/src/client_stream.rs | 70 ++++------ tests/compression/src/compressing_request.rs | 34 ++--- tests/compression/src/compressing_response.rs | 126 +++++++----------- tests/compression/src/lib.rs | 8 +- tests/compression/src/server_stream.rs | 56 +++----- tests/compression/src/util.rs | 61 ++++++++- tonic/src/codec/compression.rs | 3 +- 8 files changed, 184 insertions(+), 194 deletions(-) diff --git a/tests/compression/src/bidirectional_stream.rs b/tests/compression/src/bidirectional_stream.rs index ca5397790..83e2031b3 100644 --- a/tests/compression/src/bidirectional_stream.rs +++ b/tests/compression/src/bidirectional_stream.rs @@ -2,13 +2,12 @@ use super::*; 
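// Note: these tests use `flavor = "multi_thread"` because the gzip codec calls
// `tokio::task::block_in_place`, which panics when invoked from a
// current-thread runtime.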
#[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_enabled() { + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); + let svc = test_server::TestServer::new(Svc::default()) .accept_gzip() .send_gzip(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); fn assert_right_encoding(req: http::Request) -> http::Request { @@ -33,18 +32,15 @@ async fn client_enabled_server_enabled() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel) + let mut client = test_client::TestClient::new(mock_io_channel(client).await) .send_gzip() .accept_gzip(); @@ -73,6 +69,6 @@ async fn client_enabled_server_enabled() { .expect("stream empty") .expect("item was error"); - let bytes_sent = bytes_sent_counter.load(Relaxed); - assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); + let bytes_sent = bytes_sent_counter.load(SeqCst); + assert!(dbg!(bytes_sent) < UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/client_stream.rs b/tests/compression/src/client_stream.rs index 8851fb773..30bad6f0f 100644 --- a/tests/compression/src/client_stream.rs +++ b/tests/compression/src/client_stream.rs @@ -3,10 +3,9 @@ use http_body::Body as _; #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_enabled() { - let svc = test_server::TestServer::new(Svc::default()).accept_gzip(); + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + let svc = test_server::TestServer::new(Svc::default()).accept_gzip(); let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); @@ -26,18 +25,15 @@ async fn client_enabled_server_enabled() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel).send_gzip(); + let mut client = test_client::TestClient::new(mock_io_channel(client).await).send_gzip(); let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); @@ -45,16 +41,15 @@ async fn client_enabled_server_enabled() { client.compress_input_client_stream(req).await.unwrap(); - let bytes_sent = bytes_sent_counter.load(Relaxed); + let bytes_sent = bytes_sent_counter.load(SeqCst); assert!(dbg!(bytes_sent) < UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] async fn client_disabled_server_enabled() { - let svc = test_server::TestServer::new(Svc::default()).accept_gzip(); + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + let svc = 
test_server::TestServer::new(Svc::default()).accept_gzip(); let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); @@ -74,18 +69,15 @@ async fn client_disabled_server_enabled() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel); + let mut client = test_client::TestClient::new(mock_io_channel(client).await); let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); @@ -93,31 +85,27 @@ async fn client_disabled_server_enabled() { client.compress_input_client_stream(req).await.unwrap(); - let bytes_sent = bytes_sent_counter.load(Relaxed); + let bytes_sent = bytes_sent_counter.load(SeqCst); assert!(dbg!(bytes_sent) > UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_disabled() { - let svc = test_server::TestServer::new(Svc::default()); + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + let svc = test_server::TestServer::new(Svc::default()); tokio::spawn(async move { Server::builder() .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel).send_gzip(); + let mut client = test_client::TestClient::new(mock_io_channel(client).await).send_gzip(); let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); @@ -134,10 +122,9 @@ async fn client_enabled_server_disabled() { #[tokio::test(flavor = "multi_thread")] async fn compressing_response_from_client_stream() { - let svc = test_server::TestServer::new(Svc::default()).send_gzip(); + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + let svc = test_server::TestServer::new(Svc::default()).send_gzip(); let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); @@ -156,24 +143,21 @@ async fn compressing_response_from_client_stream() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel).accept_gzip(); + let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); let stream = futures::stream::iter(vec![]); let req = Request::new(Box::pin(stream)); let res = client.compress_output_client_stream(req).await.unwrap(); assert_eq!(res.metadata().get("grpc-encoding").unwrap(), 
"gzip"); - let bytes_sent = bytes_sent_counter.load(Relaxed); + let bytes_sent = bytes_sent_counter.load(SeqCst); assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index f4aaa2586..dfdad2257 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -3,10 +3,9 @@ use http_body::Body as _; #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_enabled() { - let svc = test_server::TestServer::new(Svc::default()).accept_gzip(); + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + let svc = test_server::TestServer::new(Svc::default()).accept_gzip(); let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); @@ -30,18 +29,15 @@ async fn client_enabled_server_enabled() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel).send_gzip(); + let mut client = test_client::TestClient::new(mock_io_channel(client).await).send_gzip(); for _ in 0..3 { client @@ -50,32 +46,28 @@ async fn client_enabled_server_enabled() { }) .await .unwrap(); - let bytes_sent = bytes_sent_counter.load(Relaxed); + let bytes_sent = bytes_sent_counter.load(SeqCst); assert!(dbg!(bytes_sent) < UNCOMPRESSED_MIN_BODY_SIZE); } } #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_disabled() { - let svc = test_server::TestServer::new(Svc::default()); + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + let svc = test_server::TestServer::new(Svc::default()); tokio::spawn(async move { Server::builder() .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel).send_gzip(); + let mut client = test_client::TestClient::new(mock_io_channel(client).await).send_gzip(); let status = client .compress_input_unary(SomeData { diff --git a/tests/compression/src/compressing_response.rs b/tests/compression/src/compressing_response.rs index eb3c73b91..305e97c50 100644 --- a/tests/compression/src/compressing_response.rs +++ b/tests/compression/src/compressing_response.rs @@ -2,6 +2,8 @@ use super::*; #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_enabled() { + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); + #[derive(Clone, Copy)] struct AssertCorrectAcceptEncoding(S); @@ -31,9 +33,6 @@ async fn client_enabled_server_enabled() { let svc = test_server::TestServer::new(Svc::default()).send_gzip(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let bytes_sent_counter = 
Arc::new(AtomicUsize::new(0)); tokio::spawn({ @@ -52,33 +51,29 @@ async fn client_enabled_server_enabled() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel).accept_gzip(); + let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); for _ in 0..3 { let res = client.compress_output_unary(()).await.unwrap(); assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); - let bytes_sent = bytes_sent_counter.load(Relaxed); + let bytes_sent = bytes_sent_counter.load(SeqCst); assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } } #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_disabled() { - let svc = test_server::TestServer::new(Svc::default()); + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + let svc = test_server::TestServer::new(Svc::default()); let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); @@ -98,29 +93,28 @@ async fn client_enabled_server_disabled() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel).accept_gzip(); + let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); let res = client.compress_output_unary(()).await.unwrap(); assert!(res.metadata().get("grpc-encoding").is_none()); - let bytes_sent = bytes_sent_counter.load(Relaxed); + let bytes_sent = bytes_sent_counter.load(SeqCst); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] async fn client_disabled() { + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); + #[derive(Clone, Copy)] struct AssertCorrectAcceptEncoding(S); @@ -147,9 +141,6 @@ async fn client_disabled() { let svc = test_server::TestServer::new(Svc::default()).send_gzip(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ @@ -168,33 +159,29 @@ async fn client_disabled() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel); + let mut client = test_client::TestClient::new(mock_io_channel(client).await); let res = client.compress_output_unary(()).await.unwrap(); assert!(res.metadata().get("grpc-encoding").is_none()); - let bytes_sent = bytes_sent_counter.load(Relaxed); + let bytes_sent = bytes_sent_counter.load(SeqCst); assert!(bytes_sent > 
UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] async fn server_replying_with_unsupported_encoding() { - let svc = test_server::TestServer::new(Svc::default()).send_gzip(); + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + let svc = test_server::TestServer::new(Svc::default()).send_gzip(); fn add_weird_content_encoding(mut response: http::Response) -> http::Response { response @@ -211,17 +198,14 @@ async fn server_replying_with_unsupported_encoding() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel).accept_gzip(); + let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); let status: Status = client.compress_output_unary(()).await.unwrap_err(); assert_eq!(status.code(), tonic::Code::Unimplemented); @@ -233,14 +217,13 @@ async fn server_replying_with_unsupported_encoding() { #[tokio::test(flavor = "multi_thread")] async fn disabling_compression_on_single_response() { + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); + let svc = test_server::TestServer::new(Svc { disable_compressing_on_response: true, }) .send_gzip(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ @@ -258,35 +241,31 @@ async fn disabling_compression_on_single_response() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel).accept_gzip(); + let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); let res = client.compress_output_unary(()).await.unwrap(); assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); - let bytes_sent = bytes_sent_counter.load(Relaxed); + let bytes_sent = bytes_sent_counter.load(SeqCst); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] async fn disabling_compression_on_response_but_keeping_compression_on_stream() { + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); + let svc = test_server::TestServer::new(Svc { disable_compressing_on_response: true, }) .send_gzip(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ @@ -304,18 +283,15 @@ async fn disabling_compression_on_response_but_keeping_compression_on_stream() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = 
Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel).accept_gzip(); + let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); let res = client.compress_output_server_stream(()).await.unwrap(); @@ -328,26 +304,25 @@ async fn disabling_compression_on_response_but_keeping_compression_on_stream() { .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(bytes_sent_counter.load(Relaxed)) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(dbg!(bytes_sent_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); stream .next() .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(bytes_sent_counter.load(Relaxed)) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(dbg!(bytes_sent_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] async fn disabling_compression_on_response_from_client_stream() { + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); + let svc = test_server::TestServer::new(Svc { disable_compressing_on_response: true, }) .send_gzip(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ @@ -365,24 +340,21 @@ async fn disabling_compression_on_response_from_client_stream() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel).accept_gzip(); + let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); let stream = futures::stream::iter(vec![]); let req = Request::new(Box::pin(stream)); let res = client.compress_output_client_stream(req).await.unwrap(); assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); - let bytes_sent = bytes_sent_counter.load(Relaxed); + let bytes_sent = bytes_sent_counter.load(SeqCst); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/lib.rs b/tests/compression/src/lib.rs index 772be4ef9..b14a34491 100644 --- a/tests/compression/src/lib.rs +++ b/tests/compression/src/lib.rs @@ -1,20 +1,22 @@ #![allow(unused_imports)] use self::util::*; +use crate::util::{mock_io_channel, MockStream}; use futures::{Stream, StreamExt}; +use std::convert::TryFrom; use std::{ pin::Pin, sync::{ - atomic::{AtomicUsize, Ordering::Relaxed}, + atomic::{AtomicUsize, Ordering::SeqCst}, Arc, }, }; use tokio::net::TcpListener; use tonic::{ - transport::{Channel, Server}, + transport::{Channel, Endpoint, Server, Uri}, Request, Response, Status, Streaming, }; -use tower::{layer::layer_fn, Service, ServiceBuilder}; +use tower::{layer::layer_fn, service_fn, Service, ServiceBuilder}; use tower_http::{map_request_body::MapRequestBodyLayer, map_response_body::MapResponseBodyLayer}; mod bidirectional_stream; diff --git a/tests/compression/src/server_stream.rs b/tests/compression/src/server_stream.rs index b63f9d9eb..ffa46acbb 100644 --- a/tests/compression/src/server_stream.rs +++ b/tests/compression/src/server_stream.rs @@ -3,10 +3,9 @@ use tonic::Streaming; #[tokio::test(flavor = "multi_thread")] async fn 
client_enabled_server_enabled() { - let svc = test_server::TestServer::new(Svc::default()).send_gzip(); + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + let svc = test_server::TestServer::new(Svc::default()).send_gzip(); let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); @@ -25,18 +24,15 @@ async fn client_enabled_server_enabled() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel).accept_gzip(); + let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); let res = client.compress_output_server_stream(()).await.unwrap(); @@ -49,22 +45,21 @@ async fn client_enabled_server_enabled() { .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(bytes_sent_counter.load(Relaxed)) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(dbg!(bytes_sent_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); stream .next() .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(bytes_sent_counter.load(Relaxed)) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(dbg!(bytes_sent_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] async fn client_disabled_server_enabled() { - let svc = test_server::TestServer::new(Svc::default()).send_gzip(); + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + let svc = test_server::TestServer::new(Svc::default()).send_gzip(); let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); @@ -83,18 +78,15 @@ async fn client_disabled_server_enabled() { .into_inner(), ) .add_service(svc) - .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel); + let mut client = test_client::TestClient::new(mock_io_channel(client).await); let res = client.compress_output_server_stream(()).await.unwrap(); @@ -107,15 +99,14 @@ async fn client_disabled_server_enabled() { .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(bytes_sent_counter.load(Relaxed)) > UNCOMPRESSED_MIN_BODY_SIZE); + assert!(dbg!(bytes_sent_counter.load(SeqCst)) > UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] async fn client_enabled_server_disabled() { - let svc = test_server::TestServer::new(Svc::default()); + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + let svc = test_server::TestServer::new(Svc::default()); let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); @@ -134,18 +125,15 @@ async fn client_enabled_server_disabled() { .into_inner(), ) .add_service(svc) - 
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>( + MockStream(server), + )])) .await .unwrap(); } }); - let channel = Channel::builder(format!("http://{}", addr).parse().unwrap()) - .connect() - .await - .unwrap(); - - let mut client = test_client::TestClient::new(channel).accept_gzip(); + let mut client = test_client::TestClient::new(mock_io_channel(client).await).accept_gzip(); let res = client.compress_output_server_stream(()).await.unwrap(); @@ -158,5 +146,5 @@ async fn client_enabled_server_disabled() { .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(bytes_sent_counter.load(Relaxed)) > UNCOMPRESSED_MIN_BODY_SIZE); + assert!(dbg!(bytes_sent_counter.load(SeqCst)) > UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs index f16d210b7..75e59f035 100644 --- a/tests/compression/src/util.rs +++ b/tests/compression/src/util.rs @@ -1,3 +1,4 @@ +use super::*; use bytes::Bytes; use futures::ready; use http_body::Body; @@ -5,11 +6,13 @@ use pin_project::pin_project; use std::{ pin::Pin, sync::{ - atomic::{AtomicUsize, Ordering::Relaxed}, + atomic::{AtomicUsize, Ordering::SeqCst}, Arc, }, task::{Context, Poll}, }; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tonic::transport::{server::Connected, Channel}; use tower_http::map_request_body::MapRequestBodyLayer; /// A body that tracks how many bytes passes through it @@ -35,7 +38,8 @@ where let counter: Arc = this.counter.clone(); match ready!(this.inner.poll_data(cx)) { Some(Ok(chunk)) => { - counter.fetch_add(chunk.len(), Relaxed); + println!("response body chunk size = {}", chunk.len()); + counter.fetch_add(chunk.len(), SeqCst); Poll::Ready(Some(Ok(chunk))) } x => Poll::Ready(x), @@ -69,7 +73,8 @@ pub fn measure_request_body_size_layer( tokio::spawn(async move { while let Some(chunk) = body.data().await { let chunk = chunk.unwrap(); - bytes_sent_counter.fetch_add(chunk.len(), Relaxed); + println!("request body chunk size = {}", chunk.len()); + bytes_sent_counter.fetch_add(chunk.len(), SeqCst); tx.send_data(chunk).await.unwrap(); } @@ -81,3 +86,53 @@ pub fn measure_request_body_size_layer( new_body }) } + +#[derive(Debug)] +pub struct MockStream(pub tokio::io::DuplexStream); + +impl Connected for MockStream { + type ConnectInfo = (); + + fn connect_info(&self) -> Self::ConnectInfo {} +} + +impl AsyncRead for MockStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl AsyncWrite for MockStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} + +pub async fn mock_io_channel(client: tokio::io::DuplexStream) -> Channel { + let mut client = Some(client); + + Endpoint::try_from("http://[::]:50051") + .unwrap() + .connect_with_connector(service_fn(move |_: Uri| { + let client = client.take().unwrap(); + async move { Ok::<_, std::io::Error>(MockStream(client)) } + })) + .await + .unwrap() +} diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index c1f438a64..e6f3e8a4d 100644 --- 
a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -137,7 +137,8 @@ pub(crate) fn compress( ); let mut out_writer = out_buf.writer(); - tokio::task::block_in_place(|| std::io::copy(&mut gzip_encoder, &mut out_writer))?; + let len = + tokio::task::block_in_place(|| std::io::copy(&mut gzip_encoder, &mut out_writer))?; } } From 1a784802cbbceaccca981d571e0357fd2ea80348 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 1 Jul 2021 17:22:40 +0200 Subject: [PATCH 26/29] Fix lints --- tests/compression/src/util.rs | 1 + tonic/src/codec/compression.rs | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs index 75e59f035..75df07f9b 100644 --- a/tests/compression/src/util.rs +++ b/tests/compression/src/util.rs @@ -124,6 +124,7 @@ impl AsyncWrite for MockStream { } } +#[allow(dead_code)] pub async fn mock_io_channel(client: tokio::io::DuplexStream) -> Channel { let mut client = Some(client); diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index e6f3e8a4d..c1f438a64 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -137,8 +137,7 @@ pub(crate) fn compress( ); let mut out_writer = out_buf.writer(); - let len = - tokio::task::block_in_place(|| std::io::copy(&mut gzip_encoder, &mut out_writer))?; + tokio::task::block_in_place(|| std::io::copy(&mut gzip_encoder, &mut out_writer))?; } } From 9eaffb8e11c49baee65d434b00b451eb693f5bfb Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 1 Jul 2021 17:38:03 +0200 Subject: [PATCH 27/29] Use separate counters --- tests/compression/src/bidirectional_stream.rs | 16 +++--- tests/compression/src/client_stream.rs | 28 ++++++----- tests/compression/src/compressing_request.rs | 8 +-- tests/compression/src/compressing_response.rs | 50 +++++++++---------- tests/compression/src/server_stream.rs | 26 +++++----- 5 files changed, 68 insertions(+), 60 deletions(-) diff --git a/tests/compression/src/bidirectional_stream.rs b/tests/compression/src/bidirectional_stream.rs index 83e2031b3..42106d1cb 100644 --- a/tests/compression/src/bidirectional_stream.rs +++ b/tests/compression/src/bidirectional_stream.rs @@ -8,7 +8,8 @@ async fn client_enabled_server_enabled() { .accept_gzip() .send_gzip(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let request_bytes_counter = Arc::new(AtomicUsize::new(0)); + let response_bytes_counter = Arc::new(AtomicUsize::new(0)); fn assert_right_encoding(req: http::Request) -> http::Request { assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); @@ -16,17 +17,20 @@ async fn client_enabled_server_enabled() { } tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let request_bytes_counter = request_bytes_counter.clone(); + let response_bytes_counter = response_bytes_counter.clone(); async move { Server::builder() .layer( ServiceBuilder::new() .map_request(assert_right_encoding) - .layer(measure_request_body_size_layer(bytes_sent_counter.clone())) + .layer(measure_request_body_size_layer( + request_bytes_counter.clone(), + )) .layer(MapResponseBodyLayer::new(move |body| { util::CountBytesBody { inner: body, - counter: bytes_sent_counter.clone(), + counter: response_bytes_counter.clone(), } })) .into_inner(), @@ -69,6 +73,6 @@ async fn client_enabled_server_enabled() { .expect("stream empty") .expect("item was error"); - let bytes_sent = bytes_sent_counter.load(SeqCst); - assert!(dbg!(bytes_sent) < UNCOMPRESSED_MIN_BODY_SIZE); + 
assert!(dbg!(request_bytes_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(dbg!(response_bytes_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/client_stream.rs b/tests/compression/src/client_stream.rs index 30bad6f0f..1dfb0787c 100644 --- a/tests/compression/src/client_stream.rs +++ b/tests/compression/src/client_stream.rs @@ -7,7 +7,7 @@ async fn client_enabled_server_enabled() { let svc = test_server::TestServer::new(Svc::default()).accept_gzip(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let request_bytes_counter = Arc::new(AtomicUsize::new(0)); fn assert_right_encoding(req: http::Request) -> http::Request { assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); @@ -15,13 +15,15 @@ async fn client_enabled_server_enabled() { } tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let request_bytes_counter = request_bytes_counter.clone(); async move { Server::builder() .layer( ServiceBuilder::new() .map_request(assert_right_encoding) - .layer(measure_request_body_size_layer(bytes_sent_counter.clone())) + .layer(measure_request_body_size_layer( + request_bytes_counter.clone(), + )) .into_inner(), ) .add_service(svc) @@ -41,7 +43,7 @@ async fn client_enabled_server_enabled() { client.compress_input_client_stream(req).await.unwrap(); - let bytes_sent = bytes_sent_counter.load(SeqCst); + let bytes_sent = request_bytes_counter.load(SeqCst); assert!(dbg!(bytes_sent) < UNCOMPRESSED_MIN_BODY_SIZE); } @@ -51,7 +53,7 @@ async fn client_disabled_server_enabled() { let svc = test_server::TestServer::new(Svc::default()).accept_gzip(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let request_bytes_counter = Arc::new(AtomicUsize::new(0)); fn assert_right_encoding(req: http::Request) -> http::Request { assert!(req.headers().get("grpc-encoding").is_none()); @@ -59,13 +61,15 @@ async fn client_disabled_server_enabled() { } tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let request_bytes_counter = request_bytes_counter.clone(); async move { Server::builder() .layer( ServiceBuilder::new() .map_request(assert_right_encoding) - .layer(measure_request_body_size_layer(bytes_sent_counter.clone())) + .layer(measure_request_body_size_layer( + request_bytes_counter.clone(), + )) .into_inner(), ) .add_service(svc) @@ -85,7 +89,7 @@ async fn client_disabled_server_enabled() { client.compress_input_client_stream(req).await.unwrap(); - let bytes_sent = bytes_sent_counter.load(SeqCst); + let bytes_sent = request_bytes_counter.load(SeqCst); assert!(dbg!(bytes_sent) > UNCOMPRESSED_MIN_BODY_SIZE); } @@ -126,10 +130,10 @@ async fn compressing_response_from_client_stream() { let svc = test_server::TestServer::new(Svc::default()).send_gzip(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let response_bytes_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let response_bytes_counter = response_bytes_counter.clone(); async move { Server::builder() .layer( @@ -137,7 +141,7 @@ async fn compressing_response_from_client_stream() { .layer(MapResponseBodyLayer::new(move |body| { util::CountBytesBody { inner: body, - counter: bytes_sent_counter.clone(), + counter: response_bytes_counter.clone(), } })) .into_inner(), @@ -158,6 +162,6 @@ async fn compressing_response_from_client_stream() { let res = client.compress_output_client_stream(req).await.unwrap(); assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); 
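The move away from a single shared counter is what makes the per-direction assertions meaningful: before this patch the bidirectional test fed both measure_request_body_size_layer and CountBytesBody from the same bytes_sent_counter, so request and response bytes landed in one total. Condensed into a sketch that only uses helpers already defined in this test crate (measure_request_body_size_layer and util::CountBytesBody from util.rs, UNCOMPRESSED_MIN_BODY_SIZE from lib.rs), the wiring these tests now share looks roughly like this:

    let request_bytes_counter = Arc::new(AtomicUsize::new(0));
    let response_bytes_counter = Arc::new(AtomicUsize::new(0));

    let layer = {
        let request_bytes_counter = request_bytes_counter.clone();
        let response_bytes_counter = response_bytes_counter.clone();
        ServiceBuilder::new()
            // counts the (possibly gzip-compressed) request body reaching the server
            .layer(measure_request_body_size_layer(request_bytes_counter))
            // wraps the response body so every chunk sent back is counted
            .layer(MapResponseBodyLayer::new(move |body| util::CountBytesBody {
                inner: body,
                counter: response_bytes_counter.clone(),
            }))
            .into_inner()
    };

    // `layer` is handed to `Server::builder().layer(layer)`; afterwards each direction
    // can be asserted on independently, e.g.
    // assert!(request_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
    // assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);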
- let bytes_sent = bytes_sent_counter.load(SeqCst); + let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index dfdad2257..50dcc9f7d 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -7,7 +7,7 @@ async fn client_enabled_server_enabled() { let svc = test_server::TestServer::new(Svc::default()).accept_gzip(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let request_bytes_counter = Arc::new(AtomicUsize::new(0)); fn assert_right_encoding(req: http::Request) -> http::Request { assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); @@ -15,7 +15,7 @@ async fn client_enabled_server_enabled() { } tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let request_bytes_counter = request_bytes_counter.clone(); async move { Server::builder() .layer( @@ -23,7 +23,7 @@ async fn client_enabled_server_enabled() { .layer( ServiceBuilder::new() .map_request(assert_right_encoding) - .layer(measure_request_body_size_layer(bytes_sent_counter)) + .layer(measure_request_body_size_layer(request_bytes_counter)) .into_inner(), ) .into_inner(), @@ -46,7 +46,7 @@ async fn client_enabled_server_enabled() { }) .await .unwrap(); - let bytes_sent = bytes_sent_counter.load(SeqCst); + let bytes_sent = request_bytes_counter.load(SeqCst); assert!(dbg!(bytes_sent) < UNCOMPRESSED_MIN_BODY_SIZE); } } diff --git a/tests/compression/src/compressing_response.rs b/tests/compression/src/compressing_response.rs index 305e97c50..d6b6b3d1c 100644 --- a/tests/compression/src/compressing_response.rs +++ b/tests/compression/src/compressing_response.rs @@ -33,10 +33,10 @@ async fn client_enabled_server_enabled() { let svc = test_server::TestServer::new(Svc::default()).send_gzip(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let response_bytes_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let response_bytes_counter = response_bytes_counter.clone(); async move { Server::builder() .layer( @@ -45,7 +45,7 @@ async fn client_enabled_server_enabled() { .layer(MapResponseBodyLayer::new(move |body| { util::CountBytesBody { inner: body, - counter: bytes_sent_counter.clone(), + counter: response_bytes_counter.clone(), } })) .into_inner(), @@ -64,7 +64,7 @@ async fn client_enabled_server_enabled() { for _ in 0..3 { let res = client.compress_output_unary(()).await.unwrap(); assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); - let bytes_sent = bytes_sent_counter.load(SeqCst); + let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } } @@ -75,10 +75,10 @@ async fn client_enabled_server_disabled() { let svc = test_server::TestServer::new(Svc::default()); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let response_bytes_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let response_bytes_counter = response_bytes_counter.clone(); async move { Server::builder() // no compression enable on the server so responses should not be compressed @@ -87,7 +87,7 @@ async fn client_enabled_server_disabled() { .layer(MapResponseBodyLayer::new(move |body| { util::CountBytesBody { inner: body, - counter: bytes_sent_counter.clone(), + counter: response_bytes_counter.clone(), } })) .into_inner(), 
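Every test in this file then follows the same end-to-end shape: enable gzip on the server and/or the client, make a call, and assert both on the grpc-encoding response metadata and on the counted body size. A minimal sketch of that shape, assuming the generated test_server/test_client modules and the Svc implementation from lib.rs, plus a free local address for the server to bind:

    async fn gzip_on_both_sides(addr: std::net::SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
        use tonic::transport::{Channel, Server};

        let svc = test_server::TestServer::new(Svc::default())
            .accept_gzip()  // decompress gzip-encoded requests
            .send_gzip();   // gzip responses for clients that advertise support

        // A real test would wait until the server is accepting connections before dialing.
        tokio::spawn(Server::builder().add_service(svc).serve(addr));

        let channel = Channel::builder(format!("http://{}", addr).parse()?)
            .connect()
            .await?;

        let mut client = test_client::TestClient::new(channel)
            .send_gzip()    // gzip outgoing request messages
            .accept_gzip(); // advertise gzip via grpc-accept-encoding

        let res = client.compress_output_unary(()).await?;
        assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
        Ok(())
    }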
@@ -107,7 +107,7 @@ async fn client_enabled_server_disabled() { assert!(res.metadata().get("grpc-encoding").is_none()); - let bytes_sent = bytes_sent_counter.load(SeqCst); + let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } @@ -141,10 +141,10 @@ async fn client_disabled() { let svc = test_server::TestServer::new(Svc::default()).send_gzip(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let response_bytes_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let response_bytes_counter = response_bytes_counter.clone(); async move { Server::builder() .layer( @@ -153,7 +153,7 @@ async fn client_disabled() { .layer(MapResponseBodyLayer::new(move |body| { util::CountBytesBody { inner: body, - counter: bytes_sent_counter.clone(), + counter: response_bytes_counter.clone(), } })) .into_inner(), @@ -173,7 +173,7 @@ async fn client_disabled() { assert!(res.metadata().get("grpc-encoding").is_none()); - let bytes_sent = bytes_sent_counter.load(SeqCst); + let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } @@ -224,10 +224,10 @@ async fn disabling_compression_on_single_response() { }) .send_gzip(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let response_bytes_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let response_bytes_counter = response_bytes_counter.clone(); async move { Server::builder() .layer( @@ -235,7 +235,7 @@ async fn disabling_compression_on_single_response() { .layer(MapResponseBodyLayer::new(move |body| { util::CountBytesBody { inner: body, - counter: bytes_sent_counter.clone(), + counter: response_bytes_counter.clone(), } })) .into_inner(), @@ -253,7 +253,7 @@ async fn disabling_compression_on_single_response() { let res = client.compress_output_unary(()).await.unwrap(); assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); - let bytes_sent = bytes_sent_counter.load(SeqCst); + let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } @@ -266,10 +266,10 @@ async fn disabling_compression_on_response_but_keeping_compression_on_stream() { }) .send_gzip(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let response_bytes_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let response_bytes_counter = response_bytes_counter.clone(); async move { Server::builder() .layer( @@ -277,7 +277,7 @@ async fn disabling_compression_on_response_but_keeping_compression_on_stream() { .layer(MapResponseBodyLayer::new(move |body| { util::CountBytesBody { inner: body, - counter: bytes_sent_counter.clone(), + counter: response_bytes_counter.clone(), } })) .into_inner(), @@ -304,14 +304,14 @@ async fn disabling_compression_on_response_but_keeping_compression_on_stream() { .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(bytes_sent_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(dbg!(response_bytes_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); stream .next() .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(bytes_sent_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(dbg!(response_bytes_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] @@ -323,10 +323,10 @@ async fn 
disabling_compression_on_response_from_client_stream() { }) .send_gzip(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let response_bytes_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let response_bytes_counter = response_bytes_counter.clone(); async move { Server::builder() .layer( @@ -334,7 +334,7 @@ async fn disabling_compression_on_response_from_client_stream() { .layer(MapResponseBodyLayer::new(move |body| { util::CountBytesBody { inner: body, - counter: bytes_sent_counter.clone(), + counter: response_bytes_counter.clone(), } })) .into_inner(), @@ -355,6 +355,6 @@ async fn disabling_compression_on_response_from_client_stream() { let res = client.compress_output_client_stream(req).await.unwrap(); assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); - let bytes_sent = bytes_sent_counter.load(SeqCst); + let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/server_stream.rs b/tests/compression/src/server_stream.rs index ffa46acbb..c80de703b 100644 --- a/tests/compression/src/server_stream.rs +++ b/tests/compression/src/server_stream.rs @@ -7,10 +7,10 @@ async fn client_enabled_server_enabled() { let svc = test_server::TestServer::new(Svc::default()).send_gzip(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let response_bytes_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let response_bytes_counter = response_bytes_counter.clone(); async move { Server::builder() .layer( @@ -18,7 +18,7 @@ async fn client_enabled_server_enabled() { .layer(MapResponseBodyLayer::new(move |body| { util::CountBytesBody { inner: body, - counter: bytes_sent_counter.clone(), + counter: response_bytes_counter.clone(), } })) .into_inner(), @@ -45,14 +45,14 @@ async fn client_enabled_server_enabled() { .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(bytes_sent_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(dbg!(response_bytes_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); stream .next() .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(bytes_sent_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(dbg!(response_bytes_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] @@ -61,10 +61,10 @@ async fn client_disabled_server_enabled() { let svc = test_server::TestServer::new(Svc::default()).send_gzip(); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let response_bytes_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let response_bytes_counter = response_bytes_counter.clone(); async move { Server::builder() .layer( @@ -72,7 +72,7 @@ async fn client_disabled_server_enabled() { .layer(MapResponseBodyLayer::new(move |body| { util::CountBytesBody { inner: body, - counter: bytes_sent_counter.clone(), + counter: response_bytes_counter.clone(), } })) .into_inner(), @@ -99,7 +99,7 @@ async fn client_disabled_server_enabled() { .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(bytes_sent_counter.load(SeqCst)) > UNCOMPRESSED_MIN_BODY_SIZE); + assert!(dbg!(response_bytes_counter.load(SeqCst)) > UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] @@ -108,10 +108,10 @@ async fn client_enabled_server_disabled() { let svc = 
test_server::TestServer::new(Svc::default()); - let bytes_sent_counter = Arc::new(AtomicUsize::new(0)); + let response_bytes_counter = Arc::new(AtomicUsize::new(0)); tokio::spawn({ - let bytes_sent_counter = bytes_sent_counter.clone(); + let response_bytes_counter = response_bytes_counter.clone(); async move { Server::builder() .layer( @@ -119,7 +119,7 @@ async fn client_enabled_server_disabled() { .layer(MapResponseBodyLayer::new(move |body| { util::CountBytesBody { inner: body, - counter: bytes_sent_counter.clone(), + counter: response_bytes_counter.clone(), } })) .into_inner(), @@ -146,5 +146,5 @@ async fn client_enabled_server_disabled() { .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(bytes_sent_counter.load(SeqCst)) > UNCOMPRESSED_MIN_BODY_SIZE); + assert!(dbg!(response_bytes_counter.load(SeqCst)) > UNCOMPRESSED_MIN_BODY_SIZE); } From e1e13a1a58fe555f56c27483566fc98cc6029649 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 1 Jul 2021 17:43:37 +0200 Subject: [PATCH 28/29] Don't make a long stream --- tests/compression/src/lib.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/compression/src/lib.rs b/tests/compression/src/lib.rs index b14a34491..38c12037a 100644 --- a/tests/compression/src/lib.rs +++ b/tests/compression/src/lib.rs @@ -76,7 +76,9 @@ impl test_server::Test for Svc { _req: Request<()>, ) -> Result, Status> { let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); - let stream = futures::stream::repeat(SomeData { data }).map(Ok::<_, Status>); + let stream = futures::stream::repeat(SomeData { data }) + .take(2) + .map(Ok::<_, Status>); Ok(self.prepare_response(Response::new(Box::pin(stream)))) } @@ -120,7 +122,9 @@ impl test_server::Test for Svc { } let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); - let stream = futures::stream::repeat(SomeData { data }).map(Ok::<_, Status>); + let stream = futures::stream::repeat(SomeData { data }) + .take(2) + .map(Ok::<_, Status>); Ok(self.prepare_response(Response::new(Box::pin(stream)))) } } From 5458db1806b123e924a6279e356dbcdc8d856a89 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 1 Jul 2021 20:33:15 +0200 Subject: [PATCH 29/29] Address review feedback --- tests/compression/src/bidirectional_stream.rs | 6 +++--- tests/compression/src/client_stream.rs | 4 ++-- tests/compression/src/compressing_request.rs | 2 +- tests/compression/src/compressing_response.rs | 4 ++-- tests/compression/src/server_stream.rs | 8 ++++---- tonic/Cargo.toml | 2 +- tonic/src/client/grpc.rs | 2 +- tonic/src/codec/compression.rs | 4 ++-- tonic/src/lib.rs | 3 +-- tonic/src/server/grpc.rs | 1 - 10 files changed, 17 insertions(+), 19 deletions(-) diff --git a/tests/compression/src/bidirectional_stream.rs b/tests/compression/src/bidirectional_stream.rs index 42106d1cb..53dc83393 100644 --- a/tests/compression/src/bidirectional_stream.rs +++ b/tests/compression/src/bidirectional_stream.rs @@ -50,7 +50,7 @@ async fn client_enabled_server_enabled() { let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); let stream = futures::stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); - let req = Request::new(Box::pin(stream)); + let req = Request::new(stream); let res = client .compress_input_output_bidirectional_stream(req) @@ -73,6 +73,6 @@ async fn client_enabled_server_enabled() { .expect("stream empty") .expect("item was error"); - assert!(dbg!(request_bytes_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); - assert!(dbg!(response_bytes_counter.load(SeqCst)) < 
UNCOMPRESSED_MIN_BODY_SIZE); + assert!(request_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/client_stream.rs b/tests/compression/src/client_stream.rs index 1dfb0787c..620f917c9 100644 --- a/tests/compression/src/client_stream.rs +++ b/tests/compression/src/client_stream.rs @@ -44,7 +44,7 @@ async fn client_enabled_server_enabled() { client.compress_input_client_stream(req).await.unwrap(); let bytes_sent = request_bytes_counter.load(SeqCst); - assert!(dbg!(bytes_sent) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] @@ -90,7 +90,7 @@ async fn client_disabled_server_enabled() { client.compress_input_client_stream(req).await.unwrap(); let bytes_sent = request_bytes_counter.load(SeqCst); - assert!(dbg!(bytes_sent) > UNCOMPRESSED_MIN_BODY_SIZE); + assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index 50dcc9f7d..de4124110 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -47,7 +47,7 @@ async fn client_enabled_server_enabled() { .await .unwrap(); let bytes_sent = request_bytes_counter.load(SeqCst); - assert!(dbg!(bytes_sent) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } } diff --git a/tests/compression/src/compressing_response.rs b/tests/compression/src/compressing_response.rs index d6b6b3d1c..e60903dc6 100644 --- a/tests/compression/src/compressing_response.rs +++ b/tests/compression/src/compressing_response.rs @@ -304,14 +304,14 @@ async fn disabling_compression_on_response_but_keeping_compression_on_stream() { .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(response_bytes_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); stream .next() .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(response_bytes_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] diff --git a/tests/compression/src/server_stream.rs b/tests/compression/src/server_stream.rs index c80de703b..2d302bf31 100644 --- a/tests/compression/src/server_stream.rs +++ b/tests/compression/src/server_stream.rs @@ -45,14 +45,14 @@ async fn client_enabled_server_enabled() { .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(response_bytes_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); stream .next() .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(response_bytes_counter.load(SeqCst)) < UNCOMPRESSED_MIN_BODY_SIZE); + assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] @@ -99,7 +99,7 @@ async fn client_disabled_server_enabled() { .await .expect("stream empty") .expect("item was error"); - assert!(dbg!(response_bytes_counter.load(SeqCst)) > UNCOMPRESSED_MIN_BODY_SIZE); + assert!(response_bytes_counter.load(SeqCst) > UNCOMPRESSED_MIN_BODY_SIZE); } #[tokio::test(flavor = "multi_thread")] @@ -146,5 +146,5 @@ async fn client_enabled_server_disabled() { .await 
.expect("stream empty") .expect("item was error"); - assert!(dbg!(response_bytes_counter.load(SeqCst)) > UNCOMPRESSED_MIN_BODY_SIZE); + assert!(response_bytes_counter.load(SeqCst) > UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 9223208de..8cd6eea64 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -39,7 +39,7 @@ tls-roots-common = ["tls"] tls-roots = ["tls-roots-common", "rustls-native-certs"] tls-webpki-roots = ["tls-roots-common", "webpki-roots"] prost = ["prost1", "prost-derive"] -compression = ["tokio/rt-multi-thread", "flate2"] +compression = ["flate2"] # [[bench]] # name = "bench_main" diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 6e8e2d99f..c89c4c60f 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -86,7 +86,7 @@ impl Grpc { #[cfg(not(feature = "compression"))] pub fn send_gzip(self) -> Self { panic!( - "`send_gzip` called on a server but the `compression` feature is not enabled on tonic" + "`send_gzip` called on a client but the `compression` feature is not enabled on tonic" ); } diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index c1f438a64..8f4c279fd 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -137,7 +137,7 @@ pub(crate) fn compress( ); let mut out_writer = out_buf.writer(); - tokio::task::block_in_place(|| std::io::copy(&mut gzip_encoder, &mut out_writer))?; + std::io::copy(&mut gzip_encoder, &mut out_writer)?; } } @@ -162,7 +162,7 @@ pub(crate) fn decompress( let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]); let mut out_writer = out_buf.writer(); - tokio::task::block_in_place(|| std::io::copy(&mut gzip_decoder, &mut out_writer))?; + std::io::copy(&mut gzip_decoder, &mut out_writer)?; } } diff --git a/tonic/src/lib.rs b/tonic/src/lib.rs index ebd7f3de8..7fb77a63b 100644 --- a/tonic/src/lib.rs +++ b/tonic/src/lib.rs @@ -30,8 +30,7 @@ //! - `prost`: Enables the [`prost`] based gRPC [`Codec`] implementation. //! - `compression`: Enables compressing requests, responses, and streams. Note //! that you must enable the `compression` feature on both `tonic` and -//! `tonic-build` to use it. Depends on `tokio`'s multi-threaded runtime and -//! [flate2]. Not enabled by default. +//! `tonic-build` to use it. Depends on [flate2]. Not enabled by default. //! //! # Structure //! diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index 21b542981..7978e2b22 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -267,7 +267,6 @@ where S: ClientStreamingService, B: Body + Send + Sync + 'static, B::Error: Into + Send + 'static, - T: std::fmt::Debug, { #[cfg(feature = "compression")] let accept_encoding = CompressionEncoding::from_accept_encoding_header(