diff --git a/tests/integration_tests/tests/connection.rs b/tests/integration_tests/tests/connection.rs index 788a04318..4ffc2c3b2 100644 --- a/tests/integration_tests/tests/connection.rs +++ b/tests/integration_tests/tests/connection.rs @@ -3,7 +3,22 @@ use integration_tests::pb::{test_client::TestClient, test_server, Input, Output} use std::sync::{Arc, Mutex}; use std::time::Duration; use tokio::sync::oneshot; -use tonic::{transport::Server, Request, Response, Status}; +use tonic::{ + transport::{Endpoint, Server}, + Request, Response, Status, +}; + +struct Svc(Arc>>>); + +#[tonic::async_trait] +impl test_server::Test for Svc { + async fn unary_call(&self, _: Request) -> Result, Status> { + let mut l = self.0.lock().unwrap(); + l.take().unwrap().send(()).unwrap(); + + Ok(Response::new(Output {})) + } +} #[tokio::test] async fn connect_returns_err() { @@ -14,18 +29,6 @@ async fn connect_returns_err() { #[tokio::test] async fn connect_returns_err_via_call_after_connected() { - struct Svc(Arc>>>); - - #[tonic::async_trait] - impl test_server::Test for Svc { - async fn unary_call(&self, _: Request) -> Result, Status> { - let mut l = self.0.lock().unwrap(); - l.take().unwrap().send(()).unwrap(); - - Ok(Response::new(Output {})) - } - } - let (tx, rx) = oneshot::channel(); let sender = Arc::new(Mutex::new(Some(tx))); let svc = test_server::TestServer::new(Svc(sender)); @@ -53,3 +56,37 @@ async fn connect_returns_err_via_call_after_connected() { jh.await.unwrap(); } + +#[tokio::test] +async fn connect_lazy_reconnects_after_first_failure() { + let (tx, rx) = oneshot::channel(); + let sender = Arc::new(Mutex::new(Some(tx))); + let svc = test_server::TestServer::new(Svc(sender)); + + let channel = Endpoint::from_static("http://127.0.0.1:1339") + .connect_lazy() + .unwrap(); + + let mut client = TestClient::new(channel); + + // First call should fail, the server is not running + client.unary_call(Request::new(Input {})).await.unwrap_err(); + + // Start the server now, second call should succeed + let jh = tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1339".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + tokio::time::delay_for(Duration::from_millis(100)).await; + client.unary_call(Request::new(Input {})).await.unwrap(); + + // The server shut down, third call should fail + tokio::time::delay_for(Duration::from_millis(100)).await; + client.unary_call(Request::new(Input {})).await.unwrap_err(); + + jh.await.unwrap(); +} diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 5889f64c5..0e379996e 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -234,7 +234,7 @@ impl Endpoint { #[cfg(not(feature = "tls"))] let connector = service::connector(http); - Channel::new(connector, self.clone()) + Ok(Channel::new(connector, self.clone())) } /// Connect with a custom connector. diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index e32a774db..63b3080ee 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -130,7 +130,7 @@ impl Channel { (Self::balance(list, DEFAULT_BUFFER_SIZE), tx) } - pub(crate) fn new(connector: C, endpoint: Endpoint) -> Result + pub(crate) fn new(connector: C, endpoint: Endpoint) -> Self where C: Service + Send + 'static, C::Error: Into + Send, @@ -139,10 +139,10 @@ impl Channel { { let buffer_size = endpoint.buffer_size.clone().unwrap_or(DEFAULT_BUFFER_SIZE); - let svc = Connection::new(connector, endpoint).map_err(super::Error::from_source)?; + let svc = Connection::lazy(connector, endpoint); let svc = Buffer::new(Either::A(svc), buffer_size); - Ok(Channel { svc }) + Channel { svc } } pub(crate) async fn connect(connector: C, endpoint: Endpoint) -> Result diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index a3a934973..c3e8769fe 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -29,7 +29,7 @@ pub(crate) struct Connection { } impl Connection { - pub(crate) fn new(connector: C, endpoint: Endpoint) -> Result + fn new(connector: C, endpoint: Endpoint, is_lazy: bool) -> Self where C: Service + Send + 'static, C::Error: Into + Send, @@ -61,13 +61,13 @@ impl Connection { .into_inner(); let connector = HyperConnect::new(connector, settings); - let conn = Reconnect::new(connector, endpoint.uri.clone()); + let conn = Reconnect::new(connector, endpoint.uri.clone(), is_lazy); let inner = stack.layer(conn); - Ok(Self { + Self { inner: BoxService::new(inner), - }) + } } pub(crate) async fn connect(connector: C, endpoint: Endpoint) -> Result @@ -77,7 +77,17 @@ impl Connection { C::Future: Unpin + Send, C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, { - Self::new(connector, endpoint)?.ready_oneshot().await + Self::new(connector, endpoint, false).ready_oneshot().await + } + + pub(crate) fn lazy(connector: C, endpoint: Endpoint) -> Self + where + C: Service + Send + 'static, + C::Error: Into + Send, + C::Future: Unpin + Send, + C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + { + Self::new(connector, endpoint, true) } } diff --git a/tonic/src/transport/service/reconnect.rs b/tonic/src/transport/service/reconnect.rs index 1d6394a2b..e01074b00 100644 --- a/tonic/src/transport/service/reconnect.rs +++ b/tonic/src/transport/service/reconnect.rs @@ -19,6 +19,7 @@ where target: Target, error: Option, has_been_connected: bool, + is_lazy: bool, } #[derive(Debug)] @@ -32,13 +33,14 @@ impl Reconnect where M: Service, { - pub(crate) fn new(mk_service: M, target: Target) -> Self { + pub(crate) fn new(mk_service: M, target: Target, is_lazy: bool) -> Self { Reconnect { mk_service, state: State::Idle, target, error: None, has_been_connected: false, + is_lazy, } } } @@ -89,11 +91,11 @@ where state = State::Idle; - if self.has_been_connected { + if !(self.has_been_connected || self.is_lazy) { + return Poll::Ready(Err(e.into())); + } else { self.error = Some(e.into()); break; - } else { - return Poll::Ready(Err(e.into())); } } }