Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

loom: add abstraction for RwLock to remove poisoning aspect #6807

Merged
merged 2 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion tokio/src/loom/mocked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub(crate) use loom::*;

pub(crate) mod sync {

pub(crate) use loom::sync::MutexGuard;
pub(crate) use loom::sync::{MutexGuard, RwLockReadGuard, RwLockWriteGuard};

#[derive(Debug)]
pub(crate) struct Mutex<T>(loom::sync::Mutex<T>);
Expand Down Expand Up @@ -30,6 +30,38 @@ pub(crate) mod sync {
self.0.get_mut().unwrap()
}
}

#[derive(Debug)]
pub(crate) struct RwLock<T>(loom::sync::RwLock<T>);

#[allow(dead_code)]
impl<T> RwLock<T> {
#[inline]
pub(crate) fn new(t: T) -> Self {
Self(loom::sync::RwLock::new(t))
}

#[inline]
pub(crate) fn read(&self) -> RwLockReadGuard<'_, T> {
self.0.read().unwrap()
}

#[inline]
pub(crate) fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
self.0.try_read().ok()
}

#[inline]
pub(crate) fn write(&self) -> RwLockWriteGuard<'_, T> {
self.0.write().unwrap()
}

#[inline]
pub(crate) fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
self.0.try_write().ok()
}
}

pub(crate) use loom::sync::*;

pub(crate) mod atomic {
Expand Down
6 changes: 5 additions & 1 deletion tokio/src/loom/std/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod barrier;
mod mutex;
#[cfg(all(feature = "parking_lot", not(miri)))]
mod parking_lot;
mod rwlock;
mod unsafe_cell;

pub(crate) mod cell {
Expand Down Expand Up @@ -64,11 +65,14 @@ pub(crate) mod sync {

#[cfg(not(all(feature = "parking_lot", not(miri))))]
#[allow(unused_imports)]
pub(crate) use std::sync::{Condvar, MutexGuard, RwLock, RwLockReadGuard, WaitTimeoutResult};
pub(crate) use std::sync::{Condvar, MutexGuard, RwLockReadGuard, WaitTimeoutResult};

#[cfg(not(all(feature = "parking_lot", not(miri))))]
pub(crate) use crate::loom::std::mutex::Mutex;

#[cfg(not(all(feature = "parking_lot", not(miri))))]
pub(crate) use crate::loom::std::rwlock::RwLock;

pub(crate) mod atomic {
pub(crate) use crate::loom::std::atomic_u16::AtomicU16;
pub(crate) use crate::loom::std::atomic_u32::AtomicU32;
Expand Down
16 changes: 12 additions & 4 deletions tokio/src/loom/std/parking_lot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,20 @@ impl<T> RwLock<T> {
RwLock(PhantomData, parking_lot::RwLock::new(t))
}

pub(crate) fn read(&self) -> LockResult<RwLockReadGuard<'_, T>> {
Ok(RwLockReadGuard(PhantomData, self.1.read()))
pub(crate) fn read(&self) -> RwLockReadGuard<'_, T> {
RwLockReadGuard(PhantomData, self.1.read())
}

pub(crate) fn write(&self) -> LockResult<RwLockWriteGuard<'_, T>> {
Ok(RwLockWriteGuard(PhantomData, self.1.write()))
pub(crate) fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
Some(RwLockReadGuard(PhantomData, self.1.read()))
}

pub(crate) fn write(&self) -> RwLockWriteGuard<'_, T> {
RwLockWriteGuard(PhantomData, self.1.write())
}

pub(crate) fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
Some(RwLockWriteGuard(PhantomData, self.1.write()))
}
}

Expand Down
48 changes: 48 additions & 0 deletions tokio/src/loom/std/rwlock.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use std::sync::{self, RwLockReadGuard, RwLockWriteGuard, TryLockError};

/// Adapter for `std::sync::RwLock` that removes the poisoning aspects
/// from its api.
#[derive(Debug)]
pub(crate) struct RwLock<T: ?Sized>(sync::RwLock<T>);

#[allow(dead_code)]
impl<T> RwLock<T> {
#[inline]
pub(crate) fn new(t: T) -> Self {
Self(sync::RwLock::new(t))
}

#[inline]
pub(crate) fn read(&self) -> RwLockReadGuard<'_, T> {
match self.0.read() {
Ok(guard) => guard,
Err(p_err) => p_err.into_inner(),
}
}

#[inline]
pub(crate) fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
match self.0.try_read() {
Ok(guard) => Some(guard),
Err(TryLockError::Poisoned(p_err)) => Some(p_err.into_inner()),
Err(TryLockError::WouldBlock) => None,
}
}

#[inline]
pub(crate) fn write(&self) -> RwLockWriteGuard<'_, T> {
match self.0.write() {
Ok(guard) => guard,
Err(p_err) => p_err.into_inner(),
}
}

#[inline]
pub(crate) fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
match self.0.try_write() {
Ok(guard) => Some(guard),
Err(TryLockError::Poisoned(p_err)) => Some(p_err.into_inner()),
Err(TryLockError::WouldBlock) => None,
}
}
}
34 changes: 6 additions & 28 deletions tokio/src/runtime/time/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@ pub(crate) use source::TimeSource;
mod wheel;

use crate::loom::sync::atomic::{AtomicBool, Ordering};
use crate::loom::sync::Mutex;
use crate::loom::sync::{Mutex, RwLock};
use crate::runtime::driver::{self, IoHandle, IoStack};
use crate::time::error::Error;
use crate::time::{Clock, Duration};
use crate::util::WakeList;

use crate::loom::sync::atomic::AtomicU64;
use std::fmt;
use std::sync::RwLock;
use std::{num::NonZeroU64, ptr::NonNull};

struct AtomicOptionNonZeroU64(AtomicU64);
Expand Down Expand Up @@ -199,12 +198,7 @@ impl Driver {

// Finds out the min expiration time to park.
let expiration_time = {
let mut wheels_lock = rt_handle
.time()
.inner
.wheels
.write()
.expect("Timer wheel shards poisoned");
let mut wheels_lock = rt_handle.time().inner.wheels.write();
let expiration_time = wheels_lock
.0
.iter_mut()
Expand Down Expand Up @@ -324,11 +318,7 @@ impl Handle {
// Returns the next wakeup time of this shard.
pub(self) fn process_at_sharded_time(&self, id: u32, mut now: u64) -> Option<u64> {
let mut waker_list = WakeList::new();
let mut wheels_lock = self
.inner
.wheels
.read()
.expect("Timer wheel shards poisoned");
let mut wheels_lock = self.inner.wheels.read();
let mut lock = wheels_lock.lock_sharded_wheel(id);

if now < lock.elapsed() {
Expand All @@ -355,11 +345,7 @@ impl Handle {

waker_list.wake_all();

wheels_lock = self
.inner
.wheels
.read()
.expect("Timer wheel shards poisoned");
wheels_lock = self.inner.wheels.read();
lock = wheels_lock.lock_sharded_wheel(id);
}
}
Expand All @@ -384,11 +370,7 @@ impl Handle {
/// `add_entry` must not be called concurrently.
pub(self) unsafe fn clear_entry(&self, entry: NonNull<TimerShared>) {
unsafe {
let wheels_lock = self
.inner
.wheels
.read()
.expect("Timer wheel shards poisoned");
let wheels_lock = self.inner.wheels.read();
let mut lock = wheels_lock.lock_sharded_wheel(entry.as_ref().shard_id());

if entry.as_ref().might_be_registered() {
Expand All @@ -412,11 +394,7 @@ impl Handle {
entry: NonNull<TimerShared>,
) {
let waker = unsafe {
let wheels_lock = self
.inner
.wheels
.read()
.expect("Timer wheel shards poisoned");
let wheels_lock = self.inner.wheels.read();

let mut lock = wheels_lock.lock_sharded_wheel(entry.as_ref().shard_id());

Expand Down
10 changes: 5 additions & 5 deletions tokio/src/sync/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ impl<T> Sender<T> {
tail.pos = tail.pos.wrapping_add(1);

// Get the slot
let mut slot = self.shared.buffer[idx].write().unwrap();
let mut slot = self.shared.buffer[idx].write();

// Track the position
slot.pos = pos;
Expand Down Expand Up @@ -695,7 +695,7 @@ impl<T> Sender<T> {
while low < high {
let mid = low + (high - low) / 2;
let idx = base_idx.wrapping_add(mid) & self.shared.mask;
if self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0 {
if self.shared.buffer[idx].read().rem.load(SeqCst) == 0 {
low = mid + 1;
} else {
high = mid;
Expand Down Expand Up @@ -737,7 +737,7 @@ impl<T> Sender<T> {
let tail = self.shared.tail.lock();

let idx = (tail.pos.wrapping_sub(1) & self.shared.mask as u64) as usize;
self.shared.buffer[idx].read().unwrap().rem.load(SeqCst) == 0
self.shared.buffer[idx].read().rem.load(SeqCst) == 0
}

/// Returns the number of active receivers.
Expand Down Expand Up @@ -1057,7 +1057,7 @@ impl<T> Receiver<T> {
let idx = (self.next & self.shared.mask as u64) as usize;

// The slot holding the next value to read
let mut slot = self.shared.buffer[idx].read().unwrap();
let mut slot = self.shared.buffer[idx].read();

if slot.pos != self.next {
// Release the `slot` lock before attempting to acquire the `tail`
Expand All @@ -1074,7 +1074,7 @@ impl<T> Receiver<T> {
let mut tail = self.shared.tail.lock();

// Acquire slot lock again
slot = self.shared.buffer[idx].read().unwrap();
slot = self.shared.buffer[idx].read();

// Make sure the position did not change. This could happen in the
// unlikely event that the buffer is wrapped between dropping the
Expand Down
10 changes: 5 additions & 5 deletions tokio/src/sync/watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ impl<T> Receiver<T> {
/// assert_eq!(*rx.borrow(), "hello");
/// ```
pub fn borrow(&self) -> Ref<'_, T> {
let inner = self.shared.value.read().unwrap();
let inner = self.shared.value.read();

// After obtaining a read-lock no concurrent writes could occur
// and the loaded version matches that of the borrowed reference.
Expand Down Expand Up @@ -622,7 +622,7 @@ impl<T> Receiver<T> {
/// [`changed`]: Receiver::changed
/// [`borrow`]: Receiver::borrow
pub fn borrow_and_update(&mut self) -> Ref<'_, T> {
let inner = self.shared.value.read().unwrap();
let inner = self.shared.value.read();

// After obtaining a read-lock no concurrent writes could occur
// and the loaded version matches that of the borrowed reference.
Expand Down Expand Up @@ -813,7 +813,7 @@ impl<T> Receiver<T> {
let mut closed = false;
loop {
{
let inner = self.shared.value.read().unwrap();
let inner = self.shared.value.read();

let new_version = self.shared.state.load().version();
let has_changed = self.version != new_version;
Expand Down Expand Up @@ -1087,7 +1087,7 @@ impl<T> Sender<T> {
{
{
// Acquire the write lock and update the value.
let mut lock = self.shared.value.write().unwrap();
let mut lock = self.shared.value.write();

// Update the value and catch possible panic inside func.
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| modify(&mut lock)));
Expand Down Expand Up @@ -1164,7 +1164,7 @@ impl<T> Sender<T> {
/// assert_eq!(*tx.borrow(), "hello");
/// ```
pub fn borrow(&self) -> Ref<'_, T> {
let inner = self.shared.value.read().unwrap();
let inner = self.shared.value.read();

// The sender/producer always sees the current version
let has_changed = false;
Expand Down
Loading