diff --git a/tokio-stream/Cargo.toml b/tokio-stream/Cargo.toml index 83f8551826c..cbdda8a62bd 100644 --- a/tokio-stream/Cargo.toml +++ b/tokio-stream/Cargo.toml @@ -26,6 +26,7 @@ io-util = ["tokio/io-util"] fs = ["tokio/fs"] sync = ["tokio/sync", "tokio-util"] signal = ["tokio/signal"] +task = ["tokio-util/task"] [dependencies] futures-core = { version = "0.3.0" } diff --git a/tokio-stream/src/macros.rs b/tokio-stream/src/macros.rs index 1e3b61bac72..5733a6edbe1 100644 --- a/tokio-stream/src/macros.rs +++ b/tokio-stream/src/macros.rs @@ -18,6 +18,16 @@ macro_rules! cfg_io_util { } } +macro_rules! cfg_task { + ($($item:item)*) => { + $( + #[cfg(feature = "task")] + #[cfg_attr(docsrs, doc(cfg(feature = "task")))] + $item + )* + } +} + macro_rules! cfg_net { ($($item:item)*) => { $( diff --git a/tokio-stream/src/wrappers.rs b/tokio-stream/src/wrappers.rs index 62cabe4f7d0..4a495d15b87 100644 --- a/tokio-stream/src/wrappers.rs +++ b/tokio-stream/src/wrappers.rs @@ -48,6 +48,11 @@ cfg_net! { pub use unix_listener::UnixListenerStream; } +cfg_task! { + mod task; + pub use task::TaskSetStream; +} + cfg_io_util! { mod split; pub use split::SplitStream; diff --git a/tokio-stream/src/wrappers/task.rs b/tokio-stream/src/wrappers/task.rs new file mode 100644 index 00000000000..70fe68cc431 --- /dev/null +++ b/tokio-stream/src/wrappers/task.rs @@ -0,0 +1,47 @@ +use crate::Stream; +use std::task::Context; +use tokio::macros::support::{Pin, Poll}; +use tokio::task::JoinError; +use tokio_util::taskset::TaskSet; + +/// A wrapper around [`TaskSet`] that implements [`Stream`]. +/// +/// [`TcpListener`]: struct@tokio_util::task::TaskSet +/// [`Stream`]: trait@crate::Stream +#[derive(Debug, Default)] +#[cfg_attr(docsrs, doc(cfg(feature = "task")))] +pub struct TaskSetStream { + inner: TaskSet, +} + +impl TaskSetStream { + /// Create a new `TaskSetStream`. + pub fn new(task_set: TaskSet) -> Self { + Self { inner: task_set } + } + + /// Get back the inner `TaskSet`. + pub fn into_inner(self) -> TaskSet { + self.inner + } +} + +impl Stream for TaskSetStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_next_finished(cx) + } +} + +impl AsRef> for TaskSetStream { + fn as_ref(&self) -> &TaskSet { + &self.inner + } +} + +impl AsMut> for TaskSetStream { + fn as_mut(&mut self) -> &mut TaskSet { + &mut self.inner + } +} diff --git a/tokio-util/Cargo.toml b/tokio-util/Cargo.toml index 676b0e2ec94..ad0aa3bcd35 100644 --- a/tokio-util/Cargo.toml +++ b/tokio-util/Cargo.toml @@ -23,7 +23,7 @@ categories = ["asynchronous"] default = [] # Shorthand for enabling everything -full = ["codec", "compat", "io-util", "time", "net", "rt"] +full = ["codec", "compat", "io-util", "time", "net", "rt", "task"] net = ["tokio/net"] compat = ["futures-io",] @@ -32,6 +32,7 @@ time = ["tokio/time","slab"] io = [] io-util = ["io", "tokio/rt", "tokio/io-util"] rt = ["tokio/rt"] +task = ["tokio/rt", "futures-util"] __docs_rs = ["futures-util"] @@ -43,6 +44,7 @@ futures-core = "0.3.0" futures-sink = "0.3.0" futures-io = { version = "0.3.0", optional = true } futures-util = { version = "0.3.0", optional = true } +futures = { version = "0.3.0", optional = true } log = "0.4" pin-project-lite = "0.2.0" slab = { version = "0.4.1", optional = true } # Backs `DelayQueue` @@ -52,6 +54,7 @@ tokio = { version = "1.0.0", path = "../tokio", features = ["full"] } tokio-test = { version = "0.4.0", path = "../tokio-test" } tokio-stream = { version = "0.1", path = "../tokio-stream" } +scopeguard = "1.1.0" async-stream = "0.3.0" futures = "0.3.0" futures-test = "0.3.5" diff --git a/tokio-util/src/cfg.rs b/tokio-util/src/cfg.rs index 4035255aff0..e27795658c4 100644 --- a/tokio-util/src/cfg.rs +++ b/tokio-util/src/cfg.rs @@ -8,6 +8,16 @@ macro_rules! cfg_codec { } } +macro_rules! cfg_task { + ($($item:item)*) => { + $( + #[cfg(feature = "task")] + #[cfg_attr(docsrs, doc(cfg(feature = "task")))] + $item + )* + } +} + macro_rules! cfg_compat { ($($item:item)*) => { $( diff --git a/tokio-util/src/lib.rs b/tokio-util/src/lib.rs index 0b3e5962343..ffc1d0a698e 100644 --- a/tokio-util/src/lib.rs +++ b/tokio-util/src/lib.rs @@ -45,6 +45,10 @@ cfg_rt! { pub mod context; } +cfg_task! { + pub mod taskset; +} + cfg_time! { pub mod time; } diff --git a/tokio-util/src/taskset.rs b/tokio-util/src/taskset.rs new file mode 100644 index 00000000000..2a50dbc4852 --- /dev/null +++ b/tokio-util/src/taskset.rs @@ -0,0 +1,349 @@ +//! TaskSet API +//! +//! This module includes a TaskSet API for scoping the lifetimes of tasks and joining tasks in a +//! manner similar to [`futures_util::stream::FuturesUnordered`] + +use futures_util::FutureExt; +use std::future::Future; +use std::task::{Context, Poll}; +use tokio::runtime::Handle; +use tokio::task::{JoinError, JoinHandle}; + +/// A [`TaskSet`] is a mechanism for regulating the lifetimes of a set of asynchronous tasks. +/// +/// # Task Lifetimes +/// Tasks spawned on a [`TaskSet`] live at most slightly longer than the set that spawns them. +/// The reason for this is that aborting a task is not an instantaneous operation. +/// +/// Tasks could be running in another worker thread at the time they are cancelled, preventing +/// them from being canceled immediately. As a result, when you abort those tasks using the join +/// handles, Tokio marks those futures for cancellation. There are, however ways to await the +/// cancellation of tasks, but they only work in an async context. +/// +/// If you need to cancel a set and wait for the cancellation to complete, use [`Self::shutdown`]. +/// That function won't complete until *all* tasks have exited. +#[derive(Debug)] +pub struct TaskSet { + tasks: Vec>, + handle: Handle, +} + +impl TaskSet { + /// Constructs a new TaskSet. + /// + /// # Panic + /// Panics if invoked outside the context of a tokio runtime. + pub fn new() -> Self { + Self::new_with_handle(Handle::current()) + } + + /// Constructs a new TaskSet which will spawn tasks using the supplied handle. + pub fn new_with_handle(handle: Handle) -> Self { + Self { + tasks: Vec::new(), + handle, + } + } + + /// Returns the amount of unjoined tasks in the set. + pub fn len(&self) -> usize { + self.tasks.len() + } + + /// Returns true if there are no tasks left in the set. + pub fn is_empty(&self) -> bool { + self.tasks.is_empty() + } + + /// Consume the set, waiting for all tasks on it to complete. + /// + /// # Output + /// The output of this function orders the results of tasks in the same order with which they + /// were spawned. + /// + /// # Cancellation Safety + /// When canceled, any tasks which have yet to complete will become detached, and the results + /// of completed tasks will be discarded. + /// + /// # Examples + /// ``` + /// # use tokio_util::taskset::TaskSet; + /// # tokio_test::block_on(async { + /// let mut set = TaskSet::new(); + /// + /// const NUMS: [u8; 5] = [0, 1, 2, 3, 4]; + /// + /// for x in NUMS.iter() { + /// set.spawn(async move { *x }); + /// } + /// + /// let joined: Vec<_> = set.join_all().await.into_iter().map(|x| x.unwrap()).collect(); + /// + /// assert_eq!(NUMS, joined.as_slice()); + /// # }); + /// ``` + pub async fn join_all(mut self) -> Vec> { + let mut output = Vec::with_capacity(self.tasks.len()); + + for task in self.tasks.iter_mut() { + output.push(task.await) + } + + self.tasks.clear(); + + output + } + + /// Join onto the next available task, removing it from the set. + /// + /// # Cancellation Safety + /// This function is cancellation-safe. If dropped, it will be as if this was never called. + /// + /// # Examples + /// ``` + /// # use tokio_util::taskset::TaskSet; + /// # tokio_test::block_on(async { + /// let mut set = TaskSet::new(); + /// + /// set.spawn(async { 5u8 }); + /// set.spawn(std::future::pending()); + /// + /// let joined = set.next_finished().await.unwrap().unwrap(); + /// assert_eq!(5u8, joined); + /// # }); + /// ``` + pub async fn next_finished(&mut self) -> Option> { + if self.tasks.is_empty() { + None + } else { + futures_util::future::poll_fn(|cx| self.poll_next_finished(cx)).await + } + } + + /// Shutdown the task set, cancelling all running futures and returning the results from + /// joining the tasks. + /// + /// # Output + /// Like [`Self::join_all`], the ordering of tasks is preserved. To check if a task was + /// cancelled, use [`JoinError::is_cancelled`]. + /// + /// # Cancellation Safety + /// When cancelled, tasks will still have been aborted, but you will have no way of waiting + /// for the aborts to complete. + /// + /// # Examples + /// ## Verifying Task Shutdown + /// ``` + /// # use tokio_util::taskset::TaskSet; + /// # use std::sync::Arc; + /// # use std::sync::atomic::{AtomicU64, Ordering}; + /// # tokio_test::block_on(async { + /// const NUM_TASKS: u64 = 1024; + /// + /// let counter = Arc::new(AtomicU64::default()); + /// + /// let mut set: TaskSet<()> = TaskSet::new(); + /// + /// for _ in 0..NUM_TASKS { + /// let guard = scopeguard::guard(counter.clone(), |x| { + /// x.fetch_add(1, Ordering::SeqCst); + /// }); + /// + /// set.spawn(async move { + /// let _guard = guard; + /// // we must never surrender to the forces of time + /// let _: () = std::future::pending().await; + /// }); + /// } + /// + /// for e in set.shutdown().await.into_iter().map(|x| x.unwrap_err()) { + /// assert!(e.is_cancelled()); + /// } + /// + /// assert_eq!(NUM_TASKS, counter.load(Ordering::Relaxed)); + /// # }); + /// ``` + /// ## Cancellation + /// ``` + /// # use tokio_util::taskset::TaskSet; + /// # use std::sync::Arc; + /// # use std::sync::atomic::{AtomicU64, Ordering}; + /// # tokio_test::block_on(async { + /// const NUM_TASKS: u64 = 64; + /// + /// let counter = Arc::new(AtomicU64::default()); + /// + /// let mut set: TaskSet<()> = TaskSet::new(); + /// + /// for _ in 0..NUM_TASKS { + /// let guard = scopeguard::guard(counter.clone(), |x| { + /// x.fetch_add(1, Ordering::SeqCst); + /// }); + /// + /// set.spawn(async move { + /// let _guard = guard; + /// let _: () = std::future::pending().await; + /// }); + /// } + /// + /// let shutdown = set.shutdown(); + /// + /// // cancel task + /// drop(shutdown); + /// + /// tokio::time::sleep(tokio::time::Duration::from_millis(5)).await; + /// assert_eq!(NUM_TASKS, counter.load(Ordering::Relaxed)); + /// # }); + /// ``` + pub fn shutdown(self) -> impl Future>> { + // abort all tasks *before* the cancellable future + self.abort_all(); + // this part can be cancelled without us leaking tasks + self.join_all() + } + + /// Poll the next finished task. + pub fn poll_next_finished( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + // implementation based off of futures_util::select_all + // o(n) due to scan in search of ready tasks + + if self.tasks.is_empty() { + Poll::Ready(None) + } else { + let item = + self.tasks + .iter_mut() + .enumerate() + .find_map(|(i, f)| match f.poll_unpin(cx) { + Poll::Pending => None, + Poll::Ready(e) => Some((i, e)), + }); + match item { + Some((idx, res)) => { + let _ = self.tasks.swap_remove(idx); + Poll::Ready(Some(res)) + } + None => Poll::Pending, + } + } + } + + fn abort_all(&self) { + for task in self.tasks.iter() { + task.abort(); + } + } +} + +impl TaskSet +where + T: Send + 'static, +{ + /// Spawn a task onto the set. + pub fn spawn(&mut self, f: F) + where + F: Future + Send + 'static, + { + let task = self.handle.spawn(f); + + self.tasks.push(task); + } +} + +impl Default for TaskSet { + fn default() -> Self { + Self::new() + } +} + +/// Tasks are aborted on drop. +/// +/// Tasks aborted this way are not instantly cancelled, and the time required depends on a lot of +/// things. If you need to make sure that tasks are immediately cancelled, use [`Self::shutdown`]. +impl Drop for TaskSet { + fn drop(&mut self) { + self.abort_all(); + } +} + +#[cfg(test)] +mod tests { + const NUM_TASKS: u64 = 64; + + use super::*; + use std::sync::atomic::{AtomicU64, Ordering}; + use std::sync::Arc; + + #[tokio::test(flavor = "current_thread")] + async fn test_current_thread_abort() { + test_abort().await + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_multi_thread_abort() { + test_abort().await + } + + #[tokio::test(flavor = "current_thread")] + async fn test_current_thread_drop() { + test_drop().await + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_multi_thread_drop() { + test_drop().await + } + + /// Ensure that the future returned by [`TaskSet::shutdown`] won't complete until all tasks + /// have been dropped. + async fn test_abort() { + let counter = Arc::new(AtomicU64::default()); + + let mut set: TaskSet<()> = TaskSet::new(); + + for _ in 0..NUM_TASKS { + let guard = scopeguard::guard(counter.clone(), |x| { + x.fetch_add(1, Ordering::SeqCst); + }); + + set.spawn(async move { + let _guard = guard; + let _: () = std::future::pending().await; + }); + } + + for e in set.shutdown().await.into_iter().map(|x| x.unwrap_err()) { + assert!(e.is_cancelled()); + } + + assert_eq!(NUM_TASKS, counter.load(Ordering::Relaxed)); + } + + /// Ensure that dropping the TaskSet successfully aborts all tasks on the set. + async fn test_drop() { + let counter = Arc::new(AtomicU64::default()); + + let mut set = TaskSet::new(); + + for _ in 0..NUM_TASKS { + let guard = scopeguard::guard(counter.clone(), |x| { + x.fetch_add(1, Ordering::SeqCst); + }); + + set.spawn(async move { + let _guard = guard; + let _: () = std::future::pending().await; + }); + } + + drop(set); + + // if these simple tasks haven't been dropped yet, we have a problem + tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + + assert_eq!(NUM_TASKS, counter.load(Ordering::Relaxed)); + } +}