Skip to content

Commit

Permalink
Address review
Browse files Browse the repository at this point in the history
This changes the way that Tids are accessed on 32-bit platforms from a seqlock to a simple
tls-address check (followed by a ThreadId comparison). Additionally, `try_current_id` is
changed to be infallible and is renamed to `current_id`.
  • Loading branch information
Sp00ph committed Jul 18, 2024
1 parent 26e800d commit 2296ffb
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 71 deletions.
102 changes: 44 additions & 58 deletions library/std/src/sync/reentrant_lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::fmt;
use crate::ops::Deref;
use crate::panic::{RefUnwindSafe, UnwindSafe};
use crate::sys::sync as sys;
use crate::thread::ThreadId;
use crate::thread::{current_id, ThreadId};

/// A re-entrant mutual exclusion lock
///
Expand Down Expand Up @@ -108,66 +108,66 @@ cfg_if!(
self.0.store(value, Relaxed);
}
}
} else if #[cfg(target_has_atomic = "32")] {
use crate::sync::atomic::{AtomicU32, Ordering::{Acquire, Relaxed, Release}};
} else {
/// Returns the address of a TLS variable. This is guaranteed to
/// be unique across all currently alive threads.
fn tls_addr() -> usize {
thread_local! { static X: u8 = const { 0u8 } };

X.with(|p| <*const u8>::addr(p))
}

use crate::sync::atomic::{
AtomicUsize,
Ordering,
};

struct Tid {
seq: AtomicU32,
low: AtomicU32,
high: AtomicU32,
// When a thread calls `set()`, this value gets updated to
// the address of a thread local on that thread. This is
// used as a first check in `contains()`; if the `tls_addr`
// doesn't match the TLS address of the current thread, then
// the ThreadId also can't match. Only if the TLS addresses do
// match do we read out the actual TID.
// Note also that we can use relaxed atomic operations here, because
// we only ever read from the tid if `tls_addr` matches the current
// TLS address. In that case, either the the tid has been set by
// the current thread, or by a thread that has terminated before
// the current thread was created. In either case, no further
// synchronization is needed (as per <https://github.com/rust-lang/miri/issues/3450>)
tls_addr: AtomicUsize,
tid: UnsafeCell<u64>,
}

unsafe impl Send for Tid {}
unsafe impl Sync for Tid {}

impl Tid {
const fn new() -> Self {
Self {
seq: AtomicU32::new(0),
low: AtomicU32::new(0),
high: AtomicU32::new(0),
}
Self { tls_addr: AtomicUsize::new(0), tid: UnsafeCell::new(0) }
}

#[inline]
// NOTE: This assumes that `owner` is the ID of the current
// thread, and may spuriously return `false` if that's not the case.
fn contains(&self, owner: ThreadId) -> bool {
// Synchronizes with the release-increment in `set()` to ensure
// we only read the data after it's been fully written.
let mut seq = self.seq.load(Acquire);
loop {
if seq % 2 == 0 {
let low = self.low.load(Relaxed);
let high = self.high.load(Relaxed);
// The acquire-increment in `set()` synchronizes with this release
// store to ensure that `get()` doesn't see data from a subsequent
// `set()` call.
match self.seq.compare_exchange_weak(seq, seq, Release, Acquire) {
Ok(_) => {
let tid = u64::from(low) | (u64::from(high) << 32);
return owner.as_u64().get() == tid;
},
Err(new) => seq = new,
}
} else {
// Another thread is currently writing to the seqlock. That thread
// must also be holding the mutex, so we can't currently be the lock owner.
return false;
}
}
// SAFETY: See the comments in the struct definition.
self.tls_addr.load(Ordering::Relaxed) == tls_addr()
&& unsafe { *self.tid.get() } == owner.as_u64().get()
}

#[inline]
// This may only be called from one thread at a time, otherwise
// concurrent `get()` calls may return teared data.
// This may only be called by one thread at a time.
fn set(&self, tid: Option<ThreadId>) {
// It's important that we set `self.tls_addr` to 0 if the
// tid is cleared. Otherwise, there might be race conditions between
// `set()` and `get()`.
let tls_addr = if tid.is_some() { tls_addr() } else { 0 };
let value = tid.map_or(0, |tid| tid.as_u64().get());
self.seq.fetch_add(1, Acquire);
self.low.store(value as u32, Relaxed);
self.high.store((value >> 32) as u32, Relaxed);
self.seq.fetch_add(1, Release);
self.tls_addr.store(tls_addr, Ordering::Relaxed);
unsafe { *self.tid.get() = value };
}
}
} else {
compile_error!("`ReentrantLock` requires at least 32 bit atomics!");
}
);

Expand Down Expand Up @@ -272,7 +272,7 @@ impl<T: ?Sized> ReentrantLock<T> {
/// assert_eq!(lock.lock().get(), 10);
/// ```
pub fn lock(&self) -> ReentrantLockGuard<'_, T> {
let this_thread = current_thread_id();
let this_thread = current_id();
// Safety: We only touch lock_count when we own the lock.
unsafe {
if self.owner.contains(this_thread) {
Expand Down Expand Up @@ -314,7 +314,7 @@ impl<T: ?Sized> ReentrantLock<T> {
///
/// This function does not block.
pub(crate) fn try_lock(&self) -> Option<ReentrantLockGuard<'_, T>> {
let this_thread = current_thread_id();
let this_thread = current_id();
// Safety: We only touch lock_count when we own the lock.
unsafe {
if self.owner.contains(this_thread) {
Expand Down Expand Up @@ -400,17 +400,3 @@ impl<T: ?Sized> Drop for ReentrantLockGuard<'_, T> {
}
}
}

/// Returns the current thread's ThreadId value, which is guaranteed
/// to be unique across the lifetime of the process.
///
/// Panics if called during a TLS destructor on a thread that hasn't
/// been assigned an ID.
pub(crate) fn current_thread_id() -> ThreadId {
#[cold]
fn no_tid() -> ! {
rtabort!("Thread hasn't been assigned an ID!")
}

crate::thread::try_current_id().unwrap_or_else(|| no_tid())
}
20 changes: 7 additions & 13 deletions library/std/src/thread/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -698,9 +698,7 @@ where
}

thread_local! {
// Invariant: `CURRENT` and `CURRENT_ID` will always be initialized
// together. However, while `CURRENT_ID` will be available during
// TLS constructors, `CURRENT` will not.
// Invariant: `CURRENT` and `CURRENT_ID` will always be initialized together.
static CURRENT: OnceCell<Thread> = const { OnceCell::new() };
static CURRENT_ID: Cell<Option<ThreadId>> = const { Cell::new(None) };
}
Expand Down Expand Up @@ -737,18 +735,14 @@ pub(crate) fn try_current() -> Option<Thread> {
}

/// Gets the id of the thread that invokes it.
///
/// If called from inside a TLS destructor and the thread was never
/// assigned an id, returns `None`.
#[inline]
pub(crate) fn try_current_id() -> Option<ThreadId> {
if CURRENT_ID.get().is_none() {
pub(crate) fn current_id() -> ThreadId {
CURRENT_ID.get().unwrap_or_else(|| {
// If `CURRENT_ID` isn't initialized yet, then `CURRENT` must also not be initialized.
// `try_current()` will try to initialize both `CURRENT` and `CURRENT_ID`.
// Subsequent calls to `try_current_id` will then no longer enter this if-branch.
let _ = try_current();
}
CURRENT_ID.get()
// `current()` will initialize both `CURRENT` and `CURRENT_ID` so subsequent calls to
// `current_id()` will succeed immediately.
current().id()
})
}

/// Gets a handle to the thread that invokes it.
Expand Down

0 comments on commit 2296ffb

Please sign in to comment.