Skip to content

Commit

Permalink
sync: make notify_waiters calls atomic (#5458)
Browse files Browse the repository at this point in the history
  • Loading branch information
satakuma authored Feb 19, 2023
1 parent 0f17d69 commit 795754a
Show file tree
Hide file tree
Showing 4 changed files with 474 additions and 44 deletions.
190 changes: 148 additions & 42 deletions tokio/src/sync/notify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::Mutex;
use crate::util::linked_list::{self, LinkedList};
use crate::util::linked_list::{self, GuardedLinkedList, LinkedList};
use crate::util::WakeList;

use std::cell::UnsafeCell;
Expand All @@ -20,6 +20,7 @@ use std::sync::atomic::Ordering::SeqCst;
use std::task::{Context, Poll, Waker};

type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
type GuardedWaitList = GuardedLinkedList<Waiter, <Waiter as linked_list::Link>::Target>;

/// Notifies a single task to wake up.
///
Expand Down Expand Up @@ -198,10 +199,16 @@ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
/// [`Semaphore`]: crate::sync::Semaphore
#[derive(Debug)]
pub struct Notify {
// This uses 2 bits to store one of `EMPTY`,
// `state` uses 2 bits to store one of `EMPTY`,
// `WAITING` or `NOTIFIED`. The rest of the bits
// are used to store the number of times `notify_waiters`
// was called.
//
// Throughout the code there are two assumptions:
// - state can be transitioned *from* `WAITING` only if
// `waiters` lock is held
// - number of times `notify_waiters` was called can
// be modified only if `waiters` lock is held
state: AtomicUsize,
waiters: Mutex<WaitList>,
}
Expand Down Expand Up @@ -229,6 +236,17 @@ struct Waiter {
_p: PhantomPinned,
}

impl Waiter {
fn new() -> Waiter {
Waiter {
pointers: linked_list::Pointers::new(),
waker: None,
notified: None,
_p: PhantomPinned,
}
}
}

generate_addr_of_methods! {
impl<> Waiter {
unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> {
Expand All @@ -237,6 +255,59 @@ generate_addr_of_methods! {
}
}

/// List used in `Notify::notify_waiters`. It wraps a guarded linked list
/// and gates the access to it on `notify.waiters` mutex. It also empties
/// the list on drop.
struct NotifyWaitersList<'a> {
list: GuardedWaitList,
is_empty: bool,
notify: &'a Notify,
}

impl<'a> NotifyWaitersList<'a> {
fn new(
unguarded_list: WaitList,
guard: Pin<&'a mut UnsafeCell<Waiter>>,
notify: &'a Notify,
) -> NotifyWaitersList<'a> {
// Safety: pointer to the guarding waiter is not null.
let guard_ptr = unsafe { NonNull::new_unchecked(guard.get()) };
let list = unguarded_list.into_guarded(guard_ptr);
NotifyWaitersList {
list,
is_empty: false,
notify,
}
}

/// Removes the last element from the guarded list. Modifying this list
/// requires an exclusive access to the main list in `Notify`.
fn pop_back_locked(&mut self, _waiters: &mut WaitList) -> Option<NonNull<Waiter>> {
let result = self.list.pop_back();
if result.is_none() {
// Save information about emptiness to avoid waiting for lock
// in the destructor.
self.is_empty = true;
}
result
}
}

impl Drop for NotifyWaitersList<'_> {
fn drop(&mut self) {
// If the list is not empty, we unlink all waiters from it.
// We do not wake the waiters to avoid double panics.
if !self.is_empty {
let _lock_guard = self.notify.waiters.lock();
while let Some(mut waiter) = self.list.pop_back() {
// Safety: we hold the lock.
let waiter = unsafe { waiter.as_mut() };
waiter.notified = Some(NotificationType::AllWaiters);
}
}
}
}

/// Future returned from [`Notify::notified()`].
///
/// This future is fused, so once it has completed, any future calls to poll
Expand All @@ -249,6 +320,9 @@ pub struct Notified<'a> {
/// The current state of the receiving process.
state: State,

/// Number of calls to `notify_waiters` at the time of creation.
notify_waiters_calls: usize,

/// Entry in the waiter `LinkedList`.
waiter: UnsafeCell<Waiter>,
}
Expand All @@ -258,7 +332,7 @@ unsafe impl<'a> Sync for Notified<'a> {}

#[derive(Debug)]
enum State {
Init(usize),
Init,
Waiting,
Done,
}
Expand Down Expand Up @@ -383,17 +457,13 @@ impl Notify {
/// ```
pub fn notified(&self) -> Notified<'_> {
// we load the number of times notify_waiters
// was called and store that in our initial state
// was called and store that in the future.
let state = self.state.load(SeqCst);
Notified {
notify: self,
state: State::Init(state >> NOTIFY_WAITERS_SHIFT),
waiter: UnsafeCell::new(Waiter {
pointers: linked_list::Pointers::new(),
waker: None,
notified: None,
_p: PhantomPinned,
}),
state: State::Init,
notify_waiters_calls: get_num_notify_waiters_calls(state),
waiter: UnsafeCell::new(Waiter::new()),
}
}

Expand Down Expand Up @@ -500,12 +570,9 @@ impl Notify {
/// }
/// ```
pub fn notify_waiters(&self) {
let mut wakers = WakeList::new();

// There are waiters, the lock must be acquired to notify.
let mut waiters = self.waiters.lock();

// The state must be reloaded while the lock is held. The state may only
// The state must be loaded while the lock is held. The state may only
// transition out of WAITING while the lock is held.
let curr = self.state.load(SeqCst);

Expand All @@ -516,12 +583,30 @@ impl Notify {
return;
}

// At this point, it is guaranteed that the state will not
// concurrently change, as holding the lock is required to
// transition **out** of `WAITING`.
// Increment the number of times this method was called
// and transition to empty.
let new_state = set_state(inc_num_notify_waiters_calls(curr), EMPTY);
self.state.store(new_state, SeqCst);

// It is critical for `GuardedLinkedList` safety that the guard node is
// pinned in memory and is not dropped until the guarded list is dropped.
let guard = UnsafeCell::new(Waiter::new());
pin!(guard);

// We move all waiters to a secondary list. It uses a `GuardedLinkedList`
// underneath to allow every waiter to safely remove itself from it.
//
// * This list will be still guarded by the `waiters` lock.
// `NotifyWaitersList` wrapper makes sure we hold the lock to modify it.
// * This wrapper will empty the list on drop. It is critical for safety
// that we will not leave any list entry with a pointer to the local
// guard node after this function returns / panics.
let mut list = NotifyWaitersList::new(std::mem::take(&mut *waiters), guard, self);

let mut wakers = WakeList::new();
'outer: loop {
while wakers.can_push() {
match waiters.pop_back() {
match list.pop_back_locked(&mut waiters) {
Some(mut waiter) => {
// Safety: `waiters` lock is still held.
let waiter = unsafe { waiter.as_mut() };
Expand All @@ -540,20 +625,17 @@ impl Notify {
}
}

// Release the lock before notifying.
drop(waiters);

// One of the wakers may panic, but the remaining waiters will still
// be unlinked from the list in `NotifyWaitersList` destructor.
wakers.wake_all();

// Acquire the lock again.
waiters = self.waiters.lock();
}

// All waiters will be notified, the state must be transitioned to
// `EMPTY`. As transitioning **from** `WAITING` requires the lock to be
// held, a `store` is sufficient.
let new = set_state(inc_num_notify_waiters_calls(curr), EMPTY);
self.state.store(new, SeqCst);

// Release the lock before notifying
drop(waiters);

Expand Down Expand Up @@ -730,26 +812,32 @@ impl Notified<'_> {

/// A custom `project` implementation is used in place of `pin-project-lite`
/// as a custom drop implementation is needed.
fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &UnsafeCell<Waiter>) {
fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &usize, &UnsafeCell<Waiter>) {
unsafe {
// Safety: both `notify` and `state` are `Unpin`.
// Safety: `notify`, `state` and `notify_waiters_calls` are `Unpin`.

is_unpin::<&Notify>();
is_unpin::<AtomicUsize>();
is_unpin::<usize>();

let me = self.get_unchecked_mut();
(me.notify, &mut me.state, &me.waiter)
(
me.notify,
&mut me.state,
&me.notify_waiters_calls,
&me.waiter,
)
}
}

fn poll_notified(self: Pin<&mut Self>, waker: Option<&Waker>) -> Poll<()> {
use State::*;

let (notify, state, waiter) = self.project();
let (notify, state, notify_waiters_calls, waiter) = self.project();

loop {
match *state {
Init(initial_notify_waiters_calls) => {
Init => {
let curr = notify.state.load(SeqCst);

// Optimistically try acquiring a pending notification
Expand Down Expand Up @@ -779,7 +867,7 @@ impl Notified<'_> {

// if notify_waiters has been called after the future
// was created, then we are done
if get_num_notify_waiters_calls(curr) != initial_notify_waiters_calls {
if get_num_notify_waiters_calls(curr) != *notify_waiters_calls {
*state = Done;
return Poll::Ready(());
}
Expand Down Expand Up @@ -846,21 +934,37 @@ impl Notified<'_> {
return Poll::Pending;
}
Waiting => {
// Currently in the "Waiting" state, implying the caller has
// a waiter stored in the waiter list (guarded by
// `notify.waiters`). In order to access the waker fields,
// we must hold the lock.
// Currently in the "Waiting" state, implying the caller has a waiter stored in
// a waiter list (guarded by `notify.waiters`). In order to access the waker
// fields, we must acquire the lock.

let waiters = notify.waiters.lock();
let mut waiters = notify.waiters.lock();

// Load the state with the lock held.
let curr = notify.state.load(SeqCst);

// Safety: called while locked
let w = unsafe { &mut *waiter.get() };

if w.notified.is_some() {
// Our waker has been notified. Reset the fields and
// remove it from the list.
w.waker = None;
// Our waker has been notified and our waiter is already removed from
// the list. Reset the notification and convert to `Done`.
w.notified = None;
w.waker = None;
*state = Done;
} else if get_num_notify_waiters_calls(curr) != *notify_waiters_calls {
// Before we add a waiter to the list we check if these numbers are
// different while holding the lock. If these numbers are different now,
// it means that there is a call to `notify_waiters` in progress and this
// waiter must be contained by a guarded list used in `notify_waiters`.
// We can treat the waiter as notified and remove it from the list, as
// it would have been notified in the `notify_waiters` call anyways.

w.waker = None;

// Safety: we hold the lock, so we have an exclusive access to the list.
// The list is used in `notify_waiters`, so it must be guarded.
unsafe { waiters.remove(NonNull::new_unchecked(w)) };

*state = Done;
} else {
Expand Down Expand Up @@ -906,7 +1010,7 @@ impl Drop for Notified<'_> {
use State::*;

// Safety: The type only transitions to a "Waiting" state when pinned.
let (notify, state, waiter) = unsafe { Pin::new_unchecked(self).project() };
let (notify, state, _, waiter) = unsafe { Pin::new_unchecked(self).project() };

// This is where we ensure safety. The `Notified` value is being
// dropped, which means we must ensure that the waiter entry is no
Expand All @@ -917,8 +1021,10 @@ impl Drop for Notified<'_> {

// remove the entry from the list (if not already removed)
//
// safety: the waiter is only added to `waiters` by virtue of it
// being the only `LinkedList` available to the type.
// Safety: we hold the lock, so we have an exclusive access to every list the
// waiter may be contained in. If the node is not contained in the `waiters`
// list, then it is contained by a guarded list used by `notify_waiters` and
// in such case it must be a middle node.
unsafe { waiters.remove(NonNull::new_unchecked(waiter.get())) };

if waiters.is_empty() && get_state(notify_state) == WAITING {
Expand Down
Loading

0 comments on commit 795754a

Please sign in to comment.