diff --git a/kubert/src/shutdown.rs b/kubert/src/shutdown.rs index 927ff76..9842b13 100644 --- a/kubert/src/shutdown.rs +++ b/kubert/src/shutdown.rs @@ -119,6 +119,9 @@ impl Shutdown { #[cfg(feature = "runtime")] impl CancelOnShutdown { /// Wraps a `Future` or `Stream` that completes when the shutdown watch fires. + /// + /// The inner `Future`/`Stream` is given the chance to complete before the shutdown watch is + /// polled so that it has a chance to complete its work before the task is cancelled. pub(crate) fn new(watch: Watch, inner: T) -> Self { // XXX Unfortunately the `Watch` API doesn't give us any means to poll for updates, so we // have to box the async call to poll it from the stream. @@ -136,11 +139,14 @@ impl> std::future::Future for CancelOnShutdo fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { let mut this = self.project(); - if this.shutdown.as_mut().poll(cx).is_ready() { + // Drive the future to completion. + if this.inner.poll(cx).is_ready() { return Poll::Ready(()); } - this.inner.poll(cx) + // If the future is pending, register interest in the shutdown watch and complete the future + // if it has fired. + this.shutdown.as_mut().poll(cx) } } @@ -153,10 +159,66 @@ impl futures_core::Stream for CancelOnShutdown { let mut this = self.project(); + // Process items from the stream until it is pending (or the stream ends). + if let Poll::Ready(next) = this.inner.poll_next(cx) { + return Poll::Ready(next); + } + + // If the stream is pending, register interest in the shutdown watch and end the stream if + // it has fired. if this.shutdown.as_mut().poll(cx).is_ready() { return Poll::Ready(None); } - this.inner.poll_next(cx) + Poll::Pending + } +} + +#[cfg(all(test, feature = "runtime"))] +mod test { + use super::CancelOnShutdown; + use tokio_stream::wrappers::ReceiverStream; + use tokio_test::{assert_pending, assert_ready, assert_ready_eq, task}; + + #[tokio::test] + async fn cancel_stream_drains() { + let (shutdown_tx, shutdown_rx) = drain::channel(); + + let (stream_tx, stream_rx) = tokio::sync::mpsc::channel(3); + let mut stream_rx = task::spawn(CancelOnShutdown::new( + shutdown_rx, + ReceiverStream::new(stream_rx), + )); + stream_tx.try_send(1).unwrap(); + stream_tx.try_send(2).unwrap(); + stream_tx.try_send(3).unwrap(); + + assert_ready_eq!(stream_rx.poll_next(), Some(1)); + + let mut drain = task::spawn(shutdown_tx.drain()); + assert_ready_eq!(stream_rx.poll_next(), Some(2)); + assert_ready_eq!(stream_rx.poll_next(), Some(3)); + assert_pending!(drain.poll()); + assert_ready_eq!(stream_rx.poll_next(), None); + assert_ready!(drain.poll()); + } + + #[tokio::test] + async fn cancel_future_ends() { + let (shutdown_tx, shutdown_rx) = drain::channel(); + + let (_tx, rx) = tokio::sync::oneshot::channel::<()>(); + let mut rx = task::spawn(CancelOnShutdown::new( + shutdown_rx, + Box::pin(async move { + rx.await.unwrap(); + }), + )); + assert_pending!(rx.poll()); + + let mut drain = task::spawn(shutdown_tx.drain()); + assert_pending!(drain.poll()); + assert_ready!(rx.poll()); + assert_ready!(drain.poll()); } }