Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ThreadId instead of TLS-address in ReentrantLock #124881

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 115 additions & 23 deletions library/std/src/sync/reentrant_lock.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#[cfg(all(test, not(target_os = "emscripten")))]
mod tests;

use cfg_if::cfg_if;

use crate::cell::UnsafeCell;
use crate::fmt;
use crate::ops::Deref;
use crate::panic::{RefUnwindSafe, UnwindSafe};
use crate::sync::atomic::{AtomicUsize, Ordering::Relaxed};
use crate::sys::sync as sys;
use crate::thread::{current_id, ThreadId};

/// A re-entrant mutual exclusion lock
///
Expand Down Expand Up @@ -53,8 +55,8 @@ use crate::sys::sync as sys;
//
// The 'owner' field tracks which thread has locked the mutex.
//
// We use current_thread_unique_ptr() as the thread identifier,
// which is just the address of a thread local variable.
// We use thread::current_id() as the thread identifier, which is just the
// current thread's ThreadId, so it's unique across the process lifetime.
//
// If `owner` is set to the identifier of the current thread,
// we assume the mutex is already locked and instead of locking it again,
Expand All @@ -72,14 +74,109 @@ use crate::sys::sync as sys;
// since we're not dealing with multiple threads. If it's not equal,
// synchronization is left to the mutex, making relaxed memory ordering for
// the `owner` field fine in all cases.
//
// On systems without 64 bit atomics we also store the address of a TLS variable
// along the 64-bit TID. We then first check that address against the address
// of that variable on the current thread, and only if they compare equal do we
// compare the actual TIDs. Because we only ever read the TID on the same thread
// that it was written on (or a thread sharing the TLS block with that writer thread),
// we don't need to further synchronize the TID accesses, so they can be regular 64-bit
// non-atomic accesses.
#[unstable(feature = "reentrant_lock", issue = "121440")]
pub struct ReentrantLock<T: ?Sized> {
mutex: sys::Mutex,
owner: AtomicUsize,
owner: Tid,
lock_count: UnsafeCell<u32>,
data: T,
}

cfg_if!(
if #[cfg(target_has_atomic = "64")] {
use crate::sync::atomic::{AtomicU64, Ordering::Relaxed};

struct Tid(AtomicU64);

impl Tid {
const fn new() -> Self {
Self(AtomicU64::new(0))
}

#[inline]
fn contains(&self, owner: ThreadId) -> bool {
owner.as_u64().get() == self.0.load(Relaxed)
}

#[inline]
// This is just unsafe to match the API of the Tid type below.
unsafe fn set(&self, tid: Option<ThreadId>) {
let value = tid.map_or(0, |tid| tid.as_u64().get());
self.0.store(value, Relaxed);
}
}
} else {
/// Returns the address of a TLS variable. This is guaranteed to
/// be unique across all currently alive threads.
fn tls_addr() -> usize {
thread_local! { static X: u8 = const { 0u8 } };

X.with(|p| <*const u8>::addr(p))
}

use crate::sync::atomic::{
AtomicUsize,
Ordering,
};

struct Tid {
// When a thread calls `set()`, this value gets updated to
// the address of a thread local on that thread. This is
// used as a first check in `contains()`; if the `tls_addr`
// doesn't match the TLS address of the current thread, then
// the ThreadId also can't match. Only if the TLS addresses do
// match do we read out the actual TID.
// Note also that we can use relaxed atomic operations here, because
// we only ever read from the tid if `tls_addr` matches the current
// TLS address. In that case, either the the tid has been set by
// the current thread, or by a thread that has terminated before
// the current thread was created. In either case, no further
// synchronization is needed (as per <https://github.com/rust-lang/miri/issues/3450>)
tls_addr: AtomicUsize,
tid: UnsafeCell<u64>,
}

unsafe impl Send for Tid {}
unsafe impl Sync for Tid {}

impl Tid {
const fn new() -> Self {
Self { tls_addr: AtomicUsize::new(0), tid: UnsafeCell::new(0) }
}

#[inline]
// NOTE: This assumes that `owner` is the ID of the current
// thread, and may spuriously return `false` if that's not the case.
fn contains(&self, owner: ThreadId) -> bool {
// SAFETY: See the comments in the struct definition.
self.tls_addr.load(Ordering::Relaxed) == tls_addr()
&& unsafe { *self.tid.get() } == owner.as_u64().get()
}

#[inline]
// This may only be called by one thread at a time, and can lead to
// race conditions otherwise.
unsafe fn set(&self, tid: Option<ThreadId>) {
// It's important that we set `self.tls_addr` to 0 if the tid is
// cleared. Otherwise, there might be race conditions between
// `set()` and `get()`.
let tls_addr = if tid.is_some() { tls_addr() } else { 0 };
let value = tid.map_or(0, |tid| tid.as_u64().get());
self.tls_addr.store(tls_addr, Ordering::Relaxed);
unsafe { *self.tid.get() = value };
}
}
}
);

#[unstable(feature = "reentrant_lock", issue = "121440")]
unsafe impl<T: Send + ?Sized> Send for ReentrantLock<T> {}
#[unstable(feature = "reentrant_lock", issue = "121440")]
Expand Down Expand Up @@ -131,7 +228,7 @@ impl<T> ReentrantLock<T> {
pub const fn new(t: T) -> ReentrantLock<T> {
ReentrantLock {
mutex: sys::Mutex::new(),
owner: AtomicUsize::new(0),
owner: Tid::new(),
lock_count: UnsafeCell::new(0),
data: t,
}
Expand Down Expand Up @@ -181,14 +278,16 @@ impl<T: ?Sized> ReentrantLock<T> {
/// assert_eq!(lock.lock().get(), 10);
/// ```
pub fn lock(&self) -> ReentrantLockGuard<'_, T> {
let this_thread = current_thread_unique_ptr();
// Safety: We only touch lock_count when we own the lock.
let this_thread = current_id();
// Safety: We only touch lock_count when we own the inner mutex.
// Additionally, we only call `self.owner.set()` while holding
// the inner mutex, so no two threads can call it concurrently.
unsafe {
if self.owner.load(Relaxed) == this_thread {
if self.owner.contains(this_thread) {
self.increment_lock_count().expect("lock count overflow in reentrant mutex");
} else {
self.mutex.lock();
self.owner.store(this_thread, Relaxed);
self.owner.set(Some(this_thread));
debug_assert_eq!(*self.lock_count.get(), 0);
*self.lock_count.get() = 1;
}
Expand Down Expand Up @@ -223,14 +322,16 @@ impl<T: ?Sized> ReentrantLock<T> {
///
/// This function does not block.
pub(crate) fn try_lock(&self) -> Option<ReentrantLockGuard<'_, T>> {
let this_thread = current_thread_unique_ptr();
// Safety: We only touch lock_count when we own the lock.
let this_thread = current_id();
// Safety: We only touch lock_count when we own the inner mutex.
// Additionally, we only call `self.owner.set()` while holding
// the inner mutex, so no two threads can call it concurrently.
unsafe {
if self.owner.load(Relaxed) == this_thread {
if self.owner.contains(this_thread) {
self.increment_lock_count()?;
Some(ReentrantLockGuard { lock: self })
} else if self.mutex.try_lock() {
self.owner.store(this_thread, Relaxed);
self.owner.set(Some(this_thread));
debug_assert_eq!(*self.lock_count.get(), 0);
*self.lock_count.get() = 1;
Some(ReentrantLockGuard { lock: self })
Expand Down Expand Up @@ -303,18 +404,9 @@ impl<T: ?Sized> Drop for ReentrantLockGuard<'_, T> {
unsafe {
*self.lock.lock_count.get() -= 1;
if *self.lock.lock_count.get() == 0 {
self.lock.owner.store(0, Relaxed);
self.lock.owner.set(None);
self.lock.mutex.unlock();
}
}
}
}

/// Get an address that is unique per running thread.
///
/// This can be used as a non-null usize-sized ID.
pub(crate) fn current_thread_unique_ptr() -> usize {
// Use a non-drop type to make sure it's still available during thread destruction.
thread_local! { static X: u8 = const { 0 } }
X.with(|x| <*const _>::addr(x))
}
32 changes: 29 additions & 3 deletions library/std/src/thread/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
mod tests;

use crate::any::Any;
use crate::cell::{OnceCell, UnsafeCell};
use crate::cell::{Cell, OnceCell, UnsafeCell};
use crate::env;
use crate::ffi::{CStr, CString};
use crate::fmt;
Expand Down Expand Up @@ -698,17 +698,22 @@ where
}

thread_local! {
// Invariant: `CURRENT` and `CURRENT_ID` will always be initialized together.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add a note about CURRENT_ID always holding the same ID as the thread in CURRENT, if it is initialized?

// If `CURRENT` is initialized, then `CURRENT_ID` will hold the same value
// as `CURRENT.id()`.
static CURRENT: OnceCell<Thread> = const { OnceCell::new() };
static CURRENT_ID: Cell<Option<ThreadId>> = const { Cell::new(None) };
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So now we redundantly store the thread ID twice, once in the Thread and once here?

At least we should explicitly document this invariant. Ideally we can avoid the redundancy...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure the redundancy can be avoided, seeing as we want to access the thread local without accessing the Thread and vice versa. And since the Thread instance can outlive the thread itself, we also can't store a reference to the thread local in the Thread instance. Maybe we could store something like a UnsendUnsync<&'static Thread> in a thread local pointing to the current Thread instance? That way we could at least avoid the Arc clone + drop that occurs with thread::current().id()...

}

/// Sets the thread handle for the current thread.
///
/// Aborts if the handle has been set already to reduce code size.
pub(crate) fn set_current(thread: Thread) {
let tid = thread.id();
// Using `unwrap` here can add ~3kB to the binary size. We have complete
// control over where this is called, so just abort if there is a bug.
CURRENT.with(|current| match current.set(thread) {
Ok(()) => {}
Ok(()) => CURRENT_ID.set(Some(tid)),
Err(_) => rtabort!("thread::set_current should only be called once per thread"),
});
}
Expand All @@ -718,7 +723,28 @@ pub(crate) fn set_current(thread: Thread) {
/// In contrast to the public `current` function, this will not panic if called
/// from inside a TLS destructor.
pub(crate) fn try_current() -> Option<Thread> {
CURRENT.try_with(|current| current.get_or_init(|| Thread::new_unnamed()).clone()).ok()
CURRENT
.try_with(|current| {
current
.get_or_init(|| {
let thread = Thread::new_unnamed();
CURRENT_ID.set(Some(thread.id()));
thread
})
.clone()
})
.ok()
}

/// Gets the id of the thread that invokes it.
#[inline]
pub(crate) fn current_id() -> ThreadId {
CURRENT_ID.get().unwrap_or_else(|| {
// If `CURRENT_ID` isn't initialized yet, then `CURRENT` must also not be initialized.
// `current()` will initialize both `CURRENT` and `CURRENT_ID` so subsequent calls to
// `current_id()` will succeed immediately.
current().id()
})
}

/// Gets a handle to the thread that invokes it.
Expand Down
Loading