Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a task set #3908

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tokio-stream/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ io-util = ["tokio/io-util"]
fs = ["tokio/fs"]
sync = ["tokio/sync", "tokio-util"]
signal = ["tokio/signal"]
rt = ["tokio/rt"]

[dependencies]
futures-core = { version = "0.3.0" }
Expand Down
10 changes: 10 additions & 0 deletions tokio-stream/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ macro_rules! cfg_sync {
}
}

macro_rules! cfg_rt {
($($item:item)*) => {
$(
#[cfg(feature = "rt")]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
$item
)*
}
}

macro_rules! cfg_signal {
($($item:item)*) => {
$(
Expand Down
5 changes: 5 additions & 0 deletions tokio-stream/src/wrappers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,8 @@ cfg_fs! {
mod read_dir;
pub use read_dir::ReadDirStream;
}

cfg_rt! {
mod task_set;
pub use task_set::TaskSetStream;
}
48 changes: 48 additions & 0 deletions tokio-stream/src/wrappers/task_set.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use crate::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::task::TaskSet;

/// A wrapper around [`TaskSet`] that implements [`Stream`].
/// It automatically propagates panics. You should manully poll
/// a task set if you want to handle them.
///
/// [`TaskSet`]: struct@tokio::task::TaskSet
/// [`Stream`]: trait@crate::Stream
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub struct TaskSetStream<T> {
inner: TaskSet<T>,
}

impl<T> TaskSetStream<T> {
/// Create a new `TaskSetStream`.
pub fn new(task_set: TaskSet<T>) -> Self {
Self { inner: task_set }
}

/// Get back the inner `TaskSet`.
pub fn into_inner(self) -> TaskSet<T> {
self.inner
}
}

impl<T> Stream for TaskSetStream<T> {
type Item = T;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
Pin::into_inner(self).inner.poll_next_finished(cx)
}
}

impl<T> AsRef<TaskSet<T>> for TaskSetStream<T> {
fn as_ref(&self) -> &TaskSet<T> {
&self.inner
}
}

impl<T> AsMut<TaskSet<T>> for TaskSetStream<T> {
fn as_mut(&mut self) -> &mut TaskSet<T> {
&mut self.inner
}
}
3 changes: 3 additions & 0 deletions tokio/src/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,9 @@ cfg_rt! {
mod unconstrained;
pub use unconstrained::{unconstrained, Unconstrained};

mod task_set;
pub use task_set::TaskSet;

cfg_trace! {
mod builder;
pub use builder::Builder;
Expand Down
150 changes: 150 additions & 0 deletions tokio/src/task/task_set.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use crate::{
future::poll_fn,
runtime::Handle,
task::{JoinError, JoinHandle}
};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};

/// A collection of running tasks.
///
/// It simplifies joining and helps to propagate panics.
#[derive(Debug)]
pub struct TaskSet<T> {
unfinished: Vec<JoinHandle<T>>,
// Used for all spawns.
handle: Handle,
}

impl<T> TaskSet<T> {
/// Creates a new empty TaskSet.
///
/// This function must be called inside the Tokio runtime context,
/// otherwise it will panic.
pub fn new() -> Self {
TaskSet::new_with(Handle::current())
}

/// Creates a new empty TaskSet.
///
/// Unlike the `new` function, it explicitly accepts runtime Handle
/// and never panicks.
pub fn new_with(handle: Handle) -> Self {
TaskSet {
unfinished: Vec::new(),
handle,
}
}

/// Returns true if there are no unwaited tasks remaining.
pub fn is_empty(&self) -> bool {
self.unfinished.is_empty()
}

/// Tries to wait for a finished task, with ability to handle failures.
///
/// # Panics
/// Panics if the set is empty.
///
/// # Return value
/// - Pending, if all tasks are still running
/// - Ready(None), if the set is empty.
/// - Ready(Some(Ok(x))), if a task resolved with resut x
/// - Ready(Some(Err(e))), if a task failed with error err.
pub fn try_poll_next_finished(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<T, JoinError>>> {
if self.is_empty() {
return Poll::Ready(None);
}
for (i, task) in self.unfinished.iter_mut().enumerate() {
let task = Pin::new(task);
if let Poll::Ready(result) = task.poll(cx) {
self.unfinished.swap_remove(i);
return Poll::Ready(Some(result));
}
}
Poll::Pending
}

/// Tries to wait for a finished task.
/// This function panics on task failure and ignores cancelled tasks.
///
/// # Return value
/// - Ready(None), if the set is empty or all remaining tasks were cancelled.
/// - Pending, if all tasks are still running
/// - Ready(Some(x)), if a task resolved with resut x
///
/// Unlike `try_poll_next_finished`, it is possible that None will be returned
/// after Some(Pending), e.g. if all tasks get cancelled.
pub fn poll_next_finished(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
loop {
match self.try_poll_next_finished(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Ok(x))) => return Poll::Ready(Some(x)),
Poll::Ready(Some(Err(err))) => {
if err.is_cancelled() {
continue;
}
std::panic::resume_unwind(err.into_panic());
}
Poll::Ready(None) => return Poll::Ready(None),
}
}
}

/// Waits for a finished task, with ability to handle failures.
///
/// # Return value
/// - None, if the set is empty.
/// - Some(Ok(x)), when a task resolves with resut x
/// - Some(Err(e))), when a task failes with error err.
pub async fn try_wait_next_finished(&mut self) -> Option<Result<T, JoinError>> {
poll_fn(|cx| self.try_poll_next_finished(cx)).await
}

/// Waits for a next finished task.
/// This function panics on task failure and ignores cancelled tasks.
/// # Return value
/// - None, when the set is empty.
/// - Some(x), if a task resolved with result x
pub async fn wait_next_finished(&mut self) -> Option<T> {
poll_fn(|cx| self.poll_next_finished(cx)).await
}

/// Cancels all running tasks.
pub fn cancel(&self) {
for t in &self.unfinished {
t.abort();
}
}
}

impl<T: Send + 'static> TaskSet<T> {
fn track(&mut self, handle: JoinHandle<T>) {
self.unfinished.push(handle);
}
/// Spawns a future onto this task set
pub fn spawn<F: Future<Output = T> + Send + 'static>(&mut self, fut: F) {
let handle = self.handle.spawn(fut);
self.track(handle);
}

/// Spawns a blocking task onto this task set
pub fn spawn_blocking<R: FnOnce() -> T + Send + 'static>(&mut self, func: R) {
let handle = self.handle.spawn_blocking(func);
self.track(handle);
}
}

impl TaskSet<()> {
/// Waits for all running tasks to complete or propagates panic.
/// `is_empty` returns true after this method completes.
pub async fn wait_all(&mut self) {
while self.wait_next_finished().await.is_some() {}
}
}