Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - await tasks to cancel #6696

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 24 additions & 26 deletions crates/bevy_tasks/src/task_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ use std::{
future::Future,
marker::PhantomData,
mem,
pin::Pin,
sync::Arc,
thread::{self, JoinHandle},
};

use async_task::FallibleTask;
use concurrent_queue::ConcurrentQueue;
use futures_lite::{future, pin, FutureExt};

Expand Down Expand Up @@ -248,8 +248,8 @@ impl TaskPool {
let task_scope_executor = &async_executor::Executor::default();
let task_scope_executor: &'env async_executor::Executor =
unsafe { mem::transmute(task_scope_executor) };
let spawned: ConcurrentQueue<async_executor::Task<T>> = ConcurrentQueue::unbounded();
let spawned_ref: &'env ConcurrentQueue<async_executor::Task<T>> =
let spawned: ConcurrentQueue<FallibleTask<T>> = ConcurrentQueue::unbounded();
let spawned_ref: &'env ConcurrentQueue<FallibleTask<T>> =
unsafe { mem::transmute(&spawned) };

let scope = Scope {
Expand All @@ -267,10 +267,10 @@ impl TaskPool {
if spawned.is_empty() {
Vec::new()
} else {
let get_results = async move {
let mut results = Vec::with_capacity(spawned.len());
while let Ok(task) = spawned.pop() {
results.push(task.await);
let get_results = async {
let mut results = Vec::with_capacity(spawned_ref.len());
while let Ok(task) = spawned_ref.pop() {
results.push(task.await.unwrap());
}

results
Expand All @@ -279,23 +279,8 @@ impl TaskPool {
// Pin the futures on the stack.
pin!(get_results);

// SAFETY: This function blocks until all futures complete, so we do not read/write
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed some unneeded code here too. I can pull this out if necessary

// the data from futures outside of the 'scope lifetime. However,
// rust has no way of knowing this so we must convert to 'static
// here to appease the compiler as it is unable to validate safety.
let get_results: Pin<&mut (dyn Future<Output = Vec<T>> + 'static + Send)> = get_results;
let get_results: Pin<&'static mut (dyn Future<Output = Vec<T>> + 'static + Send)> =
unsafe { mem::transmute(get_results) };

// The thread that calls scope() will participate in driving tasks in the pool
// forward until the tasks that are spawned by this scope() call
// complete. (If the caller of scope() happens to be a thread in
// this thread pool, and we only have one thread in the pool, then
// simply calling future::block_on(spawned) would deadlock.)
let mut spawned = task_scope_executor.spawn(get_results);

loop {
if let Some(result) = future::block_on(future::poll_once(&mut spawned)) {
if let Some(result) = future::block_on(future::poll_once(&mut get_results)) {
break result;
};

Expand Down Expand Up @@ -378,7 +363,7 @@ impl Drop for TaskPool {
pub struct Scope<'scope, 'env: 'scope, T> {
executor: &'scope async_executor::Executor<'scope>,
task_scope_executor: &'scope async_executor::Executor<'scope>,
spawned: &'scope ConcurrentQueue<async_executor::Task<T>>,
spawned: &'scope ConcurrentQueue<FallibleTask<T>>,
// make `Scope` invariant over 'scope and 'env
scope: PhantomData<&'scope mut &'scope ()>,
env: PhantomData<&'env mut &'env ()>,
Expand All @@ -394,7 +379,7 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
///
/// For more information, see [`TaskPool::scope`].
pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
let task = self.executor.spawn(f);
let task = self.executor.spawn(f).fallible();
// ConcurrentQueue only errors when closed or full, but we never
// close and use an unbouded queue, so it is safe to unwrap
self.spawned.push(task).unwrap();
Expand All @@ -407,13 +392,26 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
///
/// For more information, see [`TaskPool::scope`].
pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
let task = self.task_scope_executor.spawn(f);
let task = self.task_scope_executor.spawn(f).fallible();
// ConcurrentQueue only errors when closed or full, but we never
// close and use an unbouded queue, so it is safe to unwrap
self.spawned.push(task).unwrap();
}
}

impl<'scope, 'env, T> Drop for Scope<'scope, 'env, T>
where
T: 'scope,
{
fn drop(&mut self) {
future::block_on(async {
while let Ok(task) = self.spawned.pop() {
task.cancel().await;
}
});
}
}

#[cfg(test)]
#[allow(clippy::disallowed_types)]
mod tests {
Expand Down