sync: make notify_waiters calls atomic #5458

Merged
merged 10 commits into from
Feb 19, 2023
125 changes: 84 additions & 41 deletions tokio/src/sync/notify.rs
@@ -198,10 +198,16 @@ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>;
/// [`Semaphore`]: crate::sync::Semaphore
#[derive(Debug)]
pub struct Notify {
// This uses 2 bits to store one of `EMPTY`,
// `state` uses 2 bits to store one of `EMPTY`,
// `WAITING` or `NOTIFIED`. The rest of the bits
// are used to store the number of times `notify_waiters`
// was called.
//
// Throughout the code there are two assumptions:
// - state can be transitioned *from* `WAITING` only if
// `waiters` lock is held
// - number of times `notify_waiters` was called can
// be modified only if `waiters` lock is held
state: AtomicUsize,
waiters: Mutex<WaitList>,
}
@@ -229,6 +235,17 @@ struct Waiter {
_p: PhantomPinned,
}

impl Waiter {
fn new() -> Waiter {
Waiter {
pointers: linked_list::Pointers::new(),
waker: None,
notified: None,
_p: PhantomPinned,
}
}
}

generate_addr_of_methods! {
impl<> Waiter {
unsafe fn addr_of_pointers(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Waiter>> {
@@ -249,6 +266,9 @@ pub struct Notified<'a> {
/// The current state of the receiving process.
state: State,

/// Number of calls to `notify_waiters` at the time of creation.
notify_waiters_calls: usize,

/// Entry in the waiter `LinkedList`.
waiter: UnsafeCell<Waiter>,
}
@@ -258,7 +278,7 @@ unsafe impl<'a> Sync for Notified<'a> {}

#[derive(Debug)]
enum State {
Init(usize),
Init,
Waiting,
Done,
}
@@ -383,17 +403,13 @@ impl Notify {
/// ```
pub fn notified(&self) -> Notified<'_> {
// we load the number of times notify_waiters
// was called and store that in our initial state
// was called and store that in the future.
let state = self.state.load(SeqCst);
Notified {
notify: self,
state: State::Init(state >> NOTIFY_WAITERS_SHIFT),
waiter: UnsafeCell::new(Waiter {
pointers: linked_list::Pointers::new(),
waker: None,
notified: None,
_p: PhantomPinned,
}),
state: State::Init,
notify_waiters_calls: get_num_notify_waiters_calls(state),
waiter: UnsafeCell::new(Waiter::new()),
}
}

@@ -500,12 +516,9 @@ impl Notify {
/// }
/// ```
pub fn notify_waiters(&self) {
let mut wakers = WakeList::new();

// There are waiters, the lock must be acquired to notify.
let mut waiters = self.waiters.lock();

// The state must be reloaded while the lock is held. The state may only
// The state must be loaded while the lock is held. The state may only
// transition out of WAITING while the lock is held.
let curr = self.state.load(SeqCst);

@@ -516,12 +529,23 @@ impl Notify {
return;
}

// At this point, it is guaranteed that the state will not
// concurrently change, as holding the lock is required to
// transition **out** of `WAITING`.
// Increment the number of times this method was called
// and transition to empty.
let new_state = set_state(inc_num_notify_waiters_calls(curr), EMPTY);
Member

Why was this moved up? I'm not sure it matters, but it is not apparent it does not matter.

Member

The code is correct as long as the store happens while the mutex is held. Moving the store up increases the odds of a poll_notified succeeding early in some rare concurrent conditions.

Member Author

It was moved up for correctness. If a pending future could be polled between chunks and observe the old counter value, the inconsistency from the description (number 2) would become observable: that future would return Pending even though other waiters from the decoupled list might already have been notified. notify_waiters_poll_consistency_many checks such scenarios.
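
To illustrate the ordering argument in this thread, here is a minimal sketch, not Tokio's implementation. It assumes the layout described in the struct comment (two low bits for the state, the remaining bits for the `notify_waiters` call counter). The helper names mirror those in the diff, but their bodies, the constant values, and the `MiniNotify` type are assumptions made for illustration. The point is that once the incremented counter is published under the lock, any later poll comparing its snapshot against the counter sees a mismatch, regardless of how far the chunked waking has progressed.

```rust
use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
use std::sync::Mutex;

// Assumed layout: the two low bits hold the state, the rest hold the counter.
const NOTIFY_WAITERS_SHIFT: usize = 2;
const STATE_MASK: usize = (1 << NOTIFY_WAITERS_SHIFT) - 1;
const EMPTY: usize = 0;
const WAITING: usize = 1;

fn set_state(data: usize, state: usize) -> usize {
    (data & !STATE_MASK) | (state & STATE_MASK)
}

fn get_state(data: usize) -> usize {
    data & STATE_MASK
}

fn get_num_notify_waiters_calls(data: usize) -> usize {
    data >> NOTIFY_WAITERS_SHIFT
}

fn inc_num_notify_waiters_calls(data: usize) -> usize {
    data + (1 << NOTIFY_WAITERS_SHIFT)
}

/// Hypothetical stand-in for `Notify`; the waiter list is reduced to `()`.
struct MiniNotify {
    state: AtomicUsize,
    waiters: Mutex<()>,
}

impl MiniNotify {
    fn new_waiting() -> Self {
        MiniNotify {
            state: AtomicUsize::new(WAITING),
            waiters: Mutex::new(()),
        }
    }

    /// Counter snapshot taken when a waiter future is created.
    fn snapshot(&self) -> usize {
        get_num_notify_waiters_calls(self.state.load(SeqCst))
    }

    /// Publishes the incremented counter while the lock is held,
    /// before any (elided) chunked waking would start.
    fn begin_notify_waiters(&self) {
        let _lock = self.waiters.lock().unwrap();
        let curr = self.state.load(SeqCst);
        let new_state = set_state(inc_num_notify_waiters_calls(curr), EMPTY);
        self.state.store(new_state, SeqCst);
    }

    /// What a poll checks: has the counter moved since the snapshot?
    fn is_notified(&self, snapshot: usize) -> bool {
        get_num_notify_waiters_calls(self.state.load(SeqCst)) != snapshot
    }
}
```

In this sketch a waiter that took its snapshot before `begin_notify_waiters` reports `is_notified` as soon as the store lands, which is the property the early store is meant to provide.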

self.state.store(new_state, SeqCst);

let decoupled_list = std::mem::take(&mut *waiters);

let guard = UnsafeCell::new(Waiter::new());
Member

We probably want to pin guard to ensure it doesn't accidentally move (the pin! macro might work here). Also, could you add a big comment saying it is critical for safety that guard does not move and is not dropped until the guarded list is dropped?
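
As a rough sketch of the suggestion above, and not the code that was eventually merged, the standard library's `std::pin::pin!` macro (stable since Rust 1.68) can pin a value to the current stack frame so that it cannot be moved afterwards. `GuardNode` and `guard_address_is_stable` are hypothetical names used only for this illustration.

```rust
use std::marker::PhantomPinned;
use std::pin::{pin, Pin};
use std::ptr::NonNull;

/// Hypothetical stand-in for the guard node; `PhantomPinned` makes it `!Unpin`.
struct GuardNode {
    value: u32,
    _p: PhantomPinned,
}

/// Pins a guard on the stack, hands out a raw pointer to it (as a guarded
/// list would hold), and checks that the pointer still refers to the guard.
fn guard_address_is_stable() -> bool {
    // `pin!` yields a `Pin<&mut GuardNode>`; the underlying value has no
    // name of its own, so safe code cannot move it out.
    let guard: Pin<&mut GuardNode> = pin!(GuardNode {
        value: 7,
        _p: PhantomPinned,
    });

    // A raw pointer to the pinned guard.
    let raw: NonNull<GuardNode> = NonNull::from(&*guard);

    // ... a guarded list would be used here ...

    // Safety: `guard` is still alive and pinned, so `raw` is valid.
    let seen = unsafe { raw.as_ref() }.value;
    seen == 7 && raw.as_ptr() as *const GuardNode == &*guard as *const GuardNode
}
```

Pinning only rules out moves. Keeping the guard alive until the guarded list is dropped still depends on declaration order or an explicit `drop`, which is why the request for a prominent safety comment remains relevant.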

// Safety: the pointer is not null. Additionally, we have made sure
// that `guard` will not be moved until the guarded list is dropped.
let mut guarded_list =
unsafe { decoupled_list.into_guarded(NonNull::new_unchecked(guard.get())) };

let mut wakers = WakeList::new();
'outer: loop {
while wakers.can_push() {
match waiters.pop_back() {
match guarded_list.pop_back() {
Some(mut waiter) => {
// Safety: `waiters` lock is still held.
let waiter = unsafe { waiter.as_mut() };
@@ -540,6 +564,8 @@ impl Notify {
}
}

// Release the lock before notifying.
// `guarded_list` is no longer used.
drop(waiters);

wakers.wake_all();
Contributor

We must clean up the linked list even if this call panics.

Member

Good call.

Member Author

That's a good catch. I moved the list to a new struct, which makes sure the list is cleaned up on drop.
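
The general pattern behind that fix can be sketched as follows. This is a simplified illustration under stated assumptions, not the struct added in the PR: `DrainOnDrop`, `notify_all`, and the `Rc<Cell<bool>>` waiter type are hypothetical, and a `VecDeque` stands in for the intrusive linked list. The idea is that a `Drop` implementation runs during unwinding too, so the remaining waiters are cleaned up even if a wake call panics.

```rust
use std::cell::Cell;
use std::collections::VecDeque;
use std::panic::{self, AssertUnwindSafe};
use std::rc::Rc;

/// Hypothetical waiter: a shared flag recording whether it was notified.
type Waiter = Rc<Cell<bool>>;

/// Hypothetical wrapper that owns the decoupled waiters for one call.
struct DrainOnDrop {
    list: VecDeque<Waiter>,
}

impl Drop for DrainOnDrop {
    fn drop(&mut self) {
        // Runs on normal exit and during unwinding: mark whatever is left.
        while let Some(w) = self.list.pop_back() {
            w.set(true);
        }
    }
}

/// Notifies waiters one by one; `wake` stands in for a waker that may panic.
fn notify_all(waiters: Vec<Waiter>, wake: impl Fn(usize)) {
    let mut guarded = DrainOnDrop {
        list: waiters.into(),
    };
    let mut index = 0;
    while let Some(w) = guarded.list.pop_back() {
        w.set(true);
        wake(index); // if this panics, `guarded` is dropped while unwinding
        index += 1;
    }
}

/// Runs `notify_all` with a waker that panics on its first call and reports
/// whether every waiter was nevertheless marked as notified.
fn all_notified_despite_panic() -> bool {
    let waiters: Vec<Waiter> = (0..3).map(|_| Rc::new(Cell::new(false))).collect();
    let handles = waiters.clone();
    let result = panic::catch_unwind(AssertUnwindSafe(|| {
        notify_all(waiters, |i| {
            if i == 0 {
                panic!("waker panicked");
            }
        })
    }));
    result.is_err() && handles.iter().all(|w| w.get())
}
```

This sketch relies on the default unwinding panic strategy; with `panic = "abort"` no destructors run and the pattern offers no protection.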

@@ -548,12 +574,6 @@ impl Notify {
waiters = self.waiters.lock();
}

// All waiters will be notified, the state must be transitioned to
// `EMPTY`. As transitioning **from** `WAITING` requires the lock to be
// held, a `store` is sufficient.
let new = set_state(inc_num_notify_waiters_calls(curr), EMPTY);
self.state.store(new, SeqCst);

// Release the lock before notifying
drop(waiters);

@@ -730,26 +750,32 @@ impl Notified<'_> {

/// A custom `project` implementation is used in place of `pin-project-lite`
/// as a custom drop implementation is needed.
fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &UnsafeCell<Waiter>) {
fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &usize, &UnsafeCell<Waiter>) {
unsafe {
// Safety: both `notify` and `state` are `Unpin`.
// Safety: `notify`, `state` and `notify_waiters_calls` are `Unpin`.

is_unpin::<&Notify>();
is_unpin::<AtomicUsize>();
is_unpin::<usize>();

let me = self.get_unchecked_mut();
(me.notify, &mut me.state, &me.waiter)
(
me.notify,
&mut me.state,
&me.notify_waiters_calls,
&me.waiter,
)
}
}

fn poll_notified(self: Pin<&mut Self>, waker: Option<&Waker>) -> Poll<()> {
use State::*;

let (notify, state, waiter) = self.project();
let (notify, state, notify_waiters_calls, waiter) = self.project();

loop {
match *state {
Init(initial_notify_waiters_calls) => {
Init => {
let curr = notify.state.load(SeqCst);

// Optimistically try acquiring a pending notification
@@ -779,7 +805,7 @@ impl Notified<'_> {

// if notify_waiters has been called after the future
// was created, then we are done
if get_num_notify_waiters_calls(curr) != initial_notify_waiters_calls {
if get_num_notify_waiters_calls(curr) != *notify_waiters_calls {
*state = Done;
return Poll::Ready(());
}
@@ -846,21 +872,36 @@ impl Notified<'_> {
return Poll::Pending;
}
Waiting => {
// Currently in the "Waiting" state, implying the caller has
// a waiter stored in the waiter list (guarded by
// `notify.waiters`). In order to access the waker fields,
// we must hold the lock.
// Currently in the "Waiting" state, implying the caller has a waiter stored in
// a waiter list (guarded by `notify.waiters`). In order to access the waker
// fields, we must acquire the lock.

let mut waiters = notify.waiters.lock();

let waiters = notify.waiters.lock();
// Load the state with the lock held.
let curr = notify.state.load(SeqCst);

// Safety: called while locked
let w = unsafe { &mut *waiter.get() };

if w.notified.is_some() {
// Our waker has been notified. Reset the fields and
// remove it from the list.
w.waker = None;
// Our waker has been notified and our waiter is already removed from
// the list. Reset the notification and convert to `Done`.
w.notified = None;
*state = Done;
} else if get_num_notify_waiters_calls(curr) != *notify_waiters_calls {
// Before we add a waiter to the list we check if these numbers are
// different while holding the lock. If these numbers are different now,
// it means that there is a call to `notify_waiters` in progress and this
// waiter must be contained by a guarded list used in `notify_waiters`.
// We can treat the waiter as notified and remove it from the list, as
// it would have been notified by the `notify_waiters` call anyways.

w.waker.take();

// Safety: we hold the lock, so we have an exclusive access to the list.
// The list is used in `notify_waiters`, so it must be guarded.
unsafe { waiters.remove(NonNull::new_unchecked(w)) };

*state = Done;
} else {
@@ -906,7 +947,7 @@ impl Drop for Notified<'_> {
use State::*;

// Safety: The type only transitions to a "Waiting" state when pinned.
let (notify, state, waiter) = unsafe { Pin::new_unchecked(self).project() };
let (notify, state, _, waiter) = unsafe { Pin::new_unchecked(self).project() };

// This is where we ensure safety. The `Notified` value is being
// dropped, which means we must ensure that the waiter entry is no
@@ -917,8 +958,10 @@ impl Drop for Notified<'_> {

// remove the entry from the list (if not already removed)
//
// safety: the waiter is only added to `waiters` by virtue of it
// being the only `LinkedList` available to the type.
// Safety: we hold the lock, so we have an exclusive access to every list the
// waiter may be contained in. If the node is not contained in the `waiters`
// list, then it is contained by a guarded list used by `notify_waiters` and
// in such case it must be a middle node.
unsafe { waiters.remove(NonNull::new_unchecked(waiter.get())) };

if waiters.is_empty() && get_state(notify_state) == WAITING {