Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for scoped threads #312

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
317 changes: 270 additions & 47 deletions src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@ use std::{fmt, io};
use tracing::trace;

/// Mock implementation of `std::thread::JoinHandle`.
pub struct JoinHandle<T> {
result: Arc<Mutex<Option<std::thread::Result<T>>>>,
notify: rt::Notify,
thread: Thread,
}
pub struct JoinHandle<T>(JoinHandleInner<'static, T>);

/// Mock implementation of `std::thread::Thread`.
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -129,7 +125,7 @@ where
F: 'static,
T: 'static,
{
spawn_internal(f, None, None, location!())
JoinHandle(spawn_internal_static(f, None, None, location!()))
}

/// Mock implementation of `std::thread::park`.
Expand All @@ -143,43 +139,6 @@ pub fn park() {
rt::park(location!());
}

fn spawn_internal<F, T>(
f: F,
name: Option<String>,
stack_size: Option<usize>,
location: Location,
) -> JoinHandle<T>
where
F: FnOnce() -> T,
F: 'static,
T: 'static,
{
let result = Arc::new(Mutex::new(None));
let notify = rt::Notify::new(true, false);

let id = {
let name = name.clone();
let result = result.clone();
rt::spawn(stack_size, move || {
rt::execution(|execution| {
init_current(execution, name);
});

*result.lock().unwrap() = Some(Ok(f()));
notify.notify(location);
})
};

JoinHandle {
result,
notify,
thread: Thread {
id: ThreadId { id },
name,
},
}
}

impl Builder {
/// Generates the base configuration for spawning a thread, from which
/// configuration methods can be chained.
Expand Down Expand Up @@ -217,21 +176,53 @@ impl Builder {
F: Send + 'static,
T: Send + 'static,
{
Ok(spawn_internal(f, self.name, self.stack_size, location!()))
Ok(JoinHandle(spawn_internal_static(
f,
self.name,
self.stack_size,
location!(),
)))
}
}

impl Builder {
/// Spawns a new scoped thread using the settings set through this `Builder`.
pub fn spawn_scoped<'scope, 'env, F, T>(
self,
scope: &'scope Scope<'scope, 'env>,
f: F,
) -> io::Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
Ok(ScopedJoinHandle(
// Safety: the call to this function requires a `&'scope Scope`
// which can only be constructed by `scope()`, which ensures that
// all spawned threads are joined before the `Scope` is destroyed.
unsafe {
spawn_internal(
f,
self.name,
self.stack_size,
Some(&scope.data),
location!(),
)
},
))
}
}

impl<T> JoinHandle<T> {
/// Waits for the associated thread to finish.
#[track_caller]
pub fn join(self) -> std::thread::Result<T> {
self.notify.wait(location!());
self.result.lock().unwrap().take().unwrap()
self.0.join()
}

/// Gets a handle to the underlying [`Thread`]
pub fn thread(&self) -> &Thread {
&self.thread
self.0.thread()
}
}

Expand Down Expand Up @@ -312,3 +303,235 @@ impl<T: 'static> fmt::Debug for LocalKey<T> {
f.pad("LocalKey { .. }")
}
}

/// A scope for spawning scoped threads.
///
/// See [`scope`] for more details.
#[derive(Debug)]
pub struct Scope<'scope, 'env: 'scope> {
data: ScopeData,
scope: PhantomData<&'scope mut &'scope ()>,
env: PhantomData<&'env mut &'env ()>,
}

/// An owned permission to join on a scoped thread (block on its termination).
///
/// See [`Scope::spawn`] for details.
#[derive(Debug)]
pub struct ScopedJoinHandle<'scope, T>(JoinHandleInner<'scope, T>);

/// Create a scope for spawning scoped threads.
///
/// Mock implementation of [`std::thread::scope`].
#[track_caller]
pub fn scope<'env, F, T>(f: F) -> T
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
let scope = Scope {
data: ScopeData {
running_threads: Mutex::default(),
main_thread: current(),
},
env: PhantomData,
scope: PhantomData,
};

// Run `f`, but catch panics so we can make sure to wait for all the threads to join.
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(&scope)));

// Wait until all the threads are finished. This is required to fulfill
// the safety requirements of `spawn_internal`.
let running = loop {
{
let running = scope.data.running_threads.lock().unwrap();
if running.count == 0 {
break running;
}
}
park();
};

for notify in &running.notify_on_finished {
notify.wait(location!())
}

// Throw any panic from `f`, or the return value of `f` if no thread panicked.
match result {
Err(e) => std::panic::resume_unwind(e),
Ok(result) => result,
}
}

impl<'scope, 'env> Scope<'scope, 'env> {
/// Spawns a new thread within a scope, returning a [`ScopedJoinHandle`] for it.
///
/// See [`std::thread::Scope`] and [`std::thread::scope`] for details.
pub fn spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
Builder::new()
.spawn_scoped(self, f)
.expect("failed to spawn thread")
}
}

impl<'scope, T> ScopedJoinHandle<'scope, T> {
/// Extracts a handle to the underlying thread.
pub fn thread(&self) -> &Thread {
self.0.thread()
}

/// Waits for the associated thread to finish.
pub fn join(self) -> std::thread::Result<T> {
self.0.join()
}
}

/// Handle for joining on a thread with a scope.
#[derive(Debug)]
struct JoinHandleInner<'scope, T> {
data: Arc<ThreadData<'scope, T>>,
thread: Thread,
}

/// Spawns a thread without a local scope.
fn spawn_internal_static<F, T>(
f: F,
name: Option<String>,
stack_size: Option<usize>,
location: Location,
) -> JoinHandleInner<'static, T>
where
F: FnOnce() -> T,
F: 'static,
T: 'static,
{
// Safety: the requirements of `spawn_internal` are trivially satisfied
// since there is no `scope`.
unsafe { spawn_internal(f, name, stack_size, None, location) }
}

/// Spawns a thread with an optional scope.
///
/// The caller must ensure that if `scope` is not None, the provided closure
/// finishes before `'scope` ends.
unsafe fn spawn_internal<'scope, F, T>(
f: F,
name: Option<String>,
stack_size: Option<usize>,
scope: Option<&'scope ScopeData>,
location: Location,
) -> JoinHandleInner<'scope, T>
where
F: FnOnce() -> T,
F: 'scope,
T: 'scope,
{
let scope_notify = scope
.clone()
.map(|scope| (scope.add_running_thread(), scope));
let thread_data = Arc::new(ThreadData::new());

let id = {
let name = name.clone();
// Hold a weak reference so that if the thread handle gets dropped, we
// don't try to store the result or notify anybody unnecessarily.
let weak_data = Arc::downgrade(&thread_data);

let body: Box<dyn FnOnce() + 'scope> = Box::new(move || {
rt::execution(|execution| {
init_current(execution, name);
});

// Ensure everything from the spawned thread's execution either gets
// stored in the thread handle or dropped before notifying that the
// thread has completed.
{
let result = f();
if let Some(thread_data) = weak_data.upgrade() {
*thread_data.result.lock().unwrap() = Some(Ok(result));
thread_data.notification.notify(location);
}
}

if let Some((notifier, scope)) = scope_notify {
notifier.notify(location!());
scope.remove_running_thread()
}
});
rt::spawn(
stack_size,
std::mem::transmute::<_, Box<dyn FnOnce()>>(body),
)
};

JoinHandleInner {
data: thread_data,
thread: Thread {
id: ThreadId { id },
name,
},
}
}

/// Data for a running thread.
#[derive(Debug)]
struct ThreadData<'scope, T> {
result: Mutex<Option<std::thread::Result<T>>>,
notification: rt::Notify,
_marker: PhantomData<Option<&'scope ScopeData>>,
}

impl<'scope, T> ThreadData<'scope, T> {
fn new() -> Self {
Self {
result: Mutex::new(None),
notification: rt::Notify::new(true, false),
_marker: PhantomData,
}
}
}

impl<'scope, T> JoinHandleInner<'scope, T> {
fn join(self) -> std::thread::Result<T> {
self.data.notification.wait(location!());
self.data.result.lock().unwrap().take().unwrap()
}

fn thread(&self) -> &Thread {
&self.thread
}
}

#[derive(Default, Debug)]
struct ScopeThreads {
count: usize,
notify_on_finished: Vec<rt::Notify>,
}

#[derive(Debug)]
struct ScopeData {
running_threads: Mutex<ScopeThreads>,
main_thread: Thread,
}

impl ScopeData {
fn add_running_thread(&self) -> rt::Notify {
let mut running = self.running_threads.lock().unwrap();
running.count += 1;
let notify = rt::Notify::new(true, false);
running.notify_on_finished.push(notify);
notify
}

fn remove_running_thread(&self) {
let mut running = self.running_threads.lock().unwrap();
running.count -= 1;
if running.count == 0 {
self.main_thread.unpark()
}
}
}
Loading