diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index a5bd1bb8835..5fa460bbc43 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -2,7 +2,7 @@ use crate::loom::sync::{Arc, Mutex}; use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task}; use crate::sync::AtomicWaker; -use crate::util::VecDequeCell; +use crate::util::{RcCell, VecDequeCell}; use std::cell::Cell; use std::collections::VecDeque; @@ -261,7 +261,11 @@ pin_project! { } } -thread_local!(static CURRENT: Cell>> = Cell::new(None)); +#[cfg(any(loom, tokio_no_const_thread_local))] +thread_local!(static CURRENT: RcCell = RcCell::new()); + +#[cfg(not(any(loom, tokio_no_const_thread_local)))] +thread_local!(static CURRENT: RcCell = const { RcCell::new() }); cfg_rt! { /// Spawns a `!Send` future on the local task set. @@ -311,12 +315,10 @@ cfg_rt! { F::Output: 'static { CURRENT.with(|maybe_cx| { - let ctx = clone_rc(maybe_cx); - match ctx { + match maybe_cx.get() { None => panic!("`spawn_local` called from outside of a `task::LocalSet`"), Some(cx) => cx.spawn(future, name) } - }) } } @@ -336,7 +338,7 @@ pub struct LocalEnterGuard(Option>); impl Drop for LocalEnterGuard { fn drop(&mut self) { CURRENT.with(|ctx| { - ctx.replace(self.0.take()); + ctx.set(self.0.take()); }) } } @@ -615,12 +617,12 @@ impl LocalSet { fn with(&self, f: impl FnOnce() -> T) -> T { CURRENT.with(|ctx| { struct Reset<'a> { - ctx_ref: &'a Cell>>, + ctx_ref: &'a RcCell, val: Option>, } impl<'a> Drop for Reset<'a> { fn drop(&mut self) { - self.ctx_ref.replace(self.val.take()); + self.ctx_ref.set(self.val.take()); } } let old = ctx.replace(Some(self.context.clone())); @@ -822,19 +824,11 @@ impl Future for RunUntil<'_, T> { } } -fn clone_rc(rc: &Cell>>) -> Option> { - let value = rc.take(); - let cloned = value.clone(); - rc.set(value); - cloned -} - impl Shared { /// Schedule the provided task on the scheduler. fn schedule(&self, task: task::Notified>) { CURRENT.with(|maybe_cx| { - let ctx = clone_rc(maybe_cx); - match ctx { + match maybe_cx.get() { Some(cx) if cx.shared.ptr_eq(self) => { cx.queue.push_back(task); } @@ -861,14 +855,11 @@ impl Shared { impl task::Schedule for Arc { fn release(&self, task: &Task) -> Option> { - CURRENT.with(|maybe_cx| { - let ctx = clone_rc(maybe_cx); - match ctx { - None => panic!("scheduler context missing"), - Some(cx) => { - assert!(cx.shared.ptr_eq(self)); - cx.owned.remove(task) - } + CURRENT.with(|maybe_cx| match maybe_cx.get() { + None => panic!("scheduler context missing"), + Some(cx) => { + assert!(cx.shared.ptr_eq(self)); + cx.owned.remove(task) } }) } @@ -889,15 +880,13 @@ impl task::Schedule for Arc { // This hook is only called from within the runtime, so // `CURRENT` should match with `&self`, i.e. there is no // opportunity for a nested scheduler to be called. - CURRENT.with(|maybe_cx| { - let ctx = clone_rc(maybe_cx); - match ctx { + CURRENT.with(|maybe_cx| match maybe_cx.get() { Some(cx) if Arc::ptr_eq(self, &cx.shared) => { cx.unhandled_panic.set(true); cx.owned.close_and_shutdown_all(); } _ => unreachable!("runtime core not set in CURRENT thread-local"), - }}) + }) } } } diff --git a/tokio/src/util/mod.rs b/tokio/src/util/mod.rs index 41a3bce051f..7ea4840454d 100644 --- a/tokio/src/util/mod.rs +++ b/tokio/src/util/mod.rs @@ -56,6 +56,9 @@ cfg_rt! { mod vec_deque_cell; pub(crate) use vec_deque_cell::VecDequeCell; + + mod rc_cell; + pub(crate) use rc_cell::RcCell; } cfg_rt_multi_thread! { diff --git a/tokio/src/util/rc_cell.rs b/tokio/src/util/rc_cell.rs new file mode 100644 index 00000000000..97c02053c59 --- /dev/null +++ b/tokio/src/util/rc_cell.rs @@ -0,0 +1,57 @@ +use crate::loom::cell::UnsafeCell; + +use std::rc::Rc; + +/// This is exactly like `Cell>>`, except that it provides a `get` +/// method even though `Rc` is not `Copy`. +pub(crate) struct RcCell { + inner: UnsafeCell>>, +} + +impl RcCell { + #[cfg(not(loom))] + pub(crate) const fn new() -> Self { + Self { + inner: UnsafeCell::new(None), + } + } + + // The UnsafeCell in loom does not have a const `new` fn. + #[cfg(loom)] + pub(crate) fn new() -> Self { + Self { + inner: UnsafeCell::new(None), + } + } + + /// Safety: This method may not be called recursively. + #[inline] + unsafe fn with_inner(&self, f: F) -> R + where + F: FnOnce(&mut Option>) -> R, + { + // safety: This type is not Sync, so concurrent calls of this method + // cannot happen. Furthermore, the caller guarantees that the method is + // not called recursively. Finally, this is the only place that can + // create mutable references to the inner Rc. This ensures that any + // mutable references created here are exclusive. + self.inner.with_mut(|ptr| f(&mut *ptr)) + } + + pub(crate) fn get(&self) -> Option> { + // safety: The `Rc::clone` method will not call any unknown user-code, + // so it will not result in a recursive call to `with_inner`. + unsafe { self.with_inner(|rc| rc.clone()) } + } + + pub(crate) fn replace(&self, val: Option>) -> Option> { + // safety: No destructors or other unknown user-code will run inside the + // `with_inner` call, so no recursive call to `with_inner` can happen. + unsafe { self.with_inner(|rc| std::mem::replace(rc, val)) } + } + + pub(crate) fn set(&self, val: Option>) { + let old = self.replace(val); + drop(old); + } +}