diff --git a/src/client/dispatch.rs b/src/client/dispatch.rs index 1684804789..31541027e9 100644 --- a/src/client/dispatch.rs +++ b/src/client/dispatch.rs @@ -1,4 +1,4 @@ -use futures::{Async, Poll, Stream}; +use futures::{future, Async, Future, Poll, Stream}; use futures::sync::{mpsc, oneshot}; use want; @@ -202,6 +202,37 @@ impl Callback { } } } + + pub(crate) fn send_when( + self, + mut when: impl Future)>, + ) -> impl Future { + let mut cb = Some(self); + + // "select" on this callback being canceled, and the future completing + future::poll_fn(move || { + match when.poll() { + Ok(Async::Ready(res)) => { + cb.take() + .expect("polled after complete") + .send(Ok(res)); + Ok(().into()) + }, + Ok(Async::NotReady) => { + // check if the callback is canceled + try_ready!(cb.as_mut().unwrap().poll_cancel()); + trace!("send_when canceled"); + Ok(().into()) + }, + Err(err) => { + cb.take() + .expect("polled after complete") + .send(Err(err)); + Ok(().into()) + } + } + }) + } } #[cfg(test)] diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs index bb5cdcb666..491a916572 100644 --- a/src/proto/h2/client.rs +++ b/src/proto/h2/client.rs @@ -163,16 +163,15 @@ where let content_length = content_length_parse_all(res.headers()); let res = res.map(|stream| ::Body::h2(stream, content_length)); - cb.send(Ok(res)); + Ok(res) }, Err(err) => { debug!("client response error: {}", err); - cb.send(Err((::Error::new_h2(err), None))); + Err((::Error::new_h2(err), None)) } } - Ok(()) }); - self.executor.execute(fut)?; + self.executor.execute(cb.send_when(fut))?; continue; },