From fd17302bccf5ead94e54f7f7b9c74d5ccca7b200 Mon Sep 17 00:00:00 2001
From: threadexception
Date: Tue, 18 Jul 2023 17:37:36 +0200
Subject: [PATCH] Reduce size of Thread

Make insertion fully cold
---
 Cargo.toml       |  5 +--
 src/lib.rs       | 30 +++++++++-------
 src/thread_id.rs | 90 ++++++++++++++++++++++++++++++++----------------
 3 files changed, 81 insertions(+), 44 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 9d36f1a..616bdcc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,7 +13,7 @@ rust-version = "1.59"
 
 [features]
 # this feature provides performance improvements using nightly features
-nightly = []
+nightly = ["memoffset"]
 
 [badges]
 travis-ci = { repository = "Amanieu/thread_local-rs" }
@@ -23,9 +23,10 @@ once_cell = "1.5.2"
 # this is required to gate `nightly` related code paths
 cfg-if = "1.0.0"
 crossbeam-utils = "0.8.15"
+memoffset = { version = "0.9.0", optional = true }
 
 [dev-dependencies]
-criterion = "0.4.0"
+criterion = "0.4"
 
 [[bench]]
 name = "thread_local"
diff --git a/src/lib.rs b/src/lib.rs
index d67f67a..c7d1dc4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -188,10 +188,11 @@ impl<T: Send> ThreadLocal<T> {
     where
         F: FnOnce() -> T,
     {
-        unsafe {
-            self.get_or_try(|| Ok::<T, ()>(create()))
-                .unchecked_unwrap_ok()
+        if let Some(val) = self.get() {
+            return val;
         }
+
+        self.insert(create)
     }
 
     /// Returns the element for the current thread, or creates it if it doesn't
@@ -201,12 +202,11 @@ impl<T: Send> ThreadLocal<T> {
     where
         F: FnOnce() -> Result<T, E>,
     {
-        let thread = thread_id::get();
-        if let Some(val) = self.get_inner(thread) {
+        if let Some(val) = self.get() {
             return Ok(val);
         }
 
-        Ok(self.insert(create()?))
+        self.insert_maybe(create)
     }
 
     fn get_inner(&self, thread: Thread) -> Option<&T> {
@@ -227,14 +227,22 @@
     }
 
     #[cold]
-    fn insert(&self, data: T) -> &T {
+    fn insert_maybe<F: FnOnce() -> Result<T, E>, E>(&self, gen: F) -> Result<&T, E> {
+        let data = gen()?;
+        Ok(self.insert(|| data))
+    }
+
+    #[cold]
+    fn insert<F: FnOnce() -> T>(&self, gen: F) -> &T {
+        // call the generator here, so it is #[cold] as well.
+        let data = gen();
         let thread = thread_id::get();
         let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
         let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);
 
         // If the bucket doesn't already exist, we need to allocate it
         let bucket_ptr = if bucket_ptr.is_null() {
-            let new_bucket = allocate_bucket(thread.bucket_size);
+            let new_bucket = allocate_bucket(thread.bucket_size());
 
             match bucket_atomic_ptr.compare_exchange(
                 ptr::null_mut(),
@@ -247,7 +255,7 @@
                 // another thread stored a new bucket before we could,
                 // and we can free our bucket and use that one instead
                 Err(bucket_ptr) => {
-                    unsafe { deallocate_bucket(new_bucket, thread.bucket_size) }
+                    unsafe { deallocate_bucket(new_bucket, thread.bucket_size()) }
                     bucket_ptr
                 }
             }
@@ -496,9 +504,7 @@ impl<T: Send> Iterator for IntoIter<T> {
     fn next(&mut self) -> Option<T> {
         self.raw.next_mut(&mut self.thread_local).map(|entry| {
             *entry.present.get_mut() = false;
-            unsafe {
-                std::mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init()
-            }
+            unsafe { mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init() }
         })
     }
     fn size_hint(&self) -> (usize, Option<usize>) {
diff --git a/src/thread_id.rs b/src/thread_id.rs
index 8ee9cbf..9bbc361 100644
--- a/src/thread_id.rs
+++ b/src/thread_id.rs
@@ -49,38 +49,47 @@ static THREAD_ID_MANAGER: Lazy<Mutex<ThreadIdManager>> =
 /// A thread ID may be reused after a thread exits.
 #[derive(Clone, Copy)]
 pub(crate) struct Thread {
-    /// The thread ID obtained from the thread ID manager.
-    pub(crate) id: usize,
     /// The bucket this thread's local storage will be in.
     pub(crate) bucket: usize,
-    /// The size of the bucket this thread's local storage will be in.
-    pub(crate) bucket_size: usize,
     /// The index into the bucket this thread's local storage is in.
     pub(crate) index: usize,
 }
+
 impl Thread {
+    /// id: The thread ID obtained from the thread ID manager.
+    #[inline]
     fn new(id: usize) -> Self {
         let bucket = usize::from(POINTER_WIDTH) - ((id + 1).leading_zeros() as usize) - 1;
         let bucket_size = 1 << bucket;
         let index = id - (bucket_size - 1);
+        Self { bucket, index }
+    }
 
-        Self {
-            id,
-            bucket,
-            bucket_size,
-            index,
-        }
+    /// The size of the bucket this thread's local storage will be in.
+    #[inline]
+    pub fn bucket_size(&self) -> usize {
+        1 << self.bucket
     }
 }
 
 cfg_if::cfg_if! {
     if #[cfg(feature = "nightly")] {
+        use memoffset::offset_of;
+        use std::ptr::null;
+        use std::cell::UnsafeCell;
+
         // This is split into 2 thread-local variables so that we can check whether the
         // thread is initialized without having to register a thread-local destructor.
         //
         // This makes the fast path smaller.
         #[thread_local]
-        static mut THREAD: Option<Thread> = None;
+        static THREAD: UnsafeCell<ThreadWrapper> = UnsafeCell::new(ThreadWrapper {
+            self_ptr: null(),
+            thread: Thread {
+                index: 0,
+                bucket: 0,
+            },
+        });
         thread_local! { static THREAD_GUARD: ThreadGuard = const { ThreadGuard { id: Cell::new(0) } }; }
 
         // Guard to ensure the thread ID is released on thread exit.
@@ -97,17 +106,41 @@
                 // will go through get_slow which will either panic or
                 // initialize a new ThreadGuard.
                 unsafe {
-                    THREAD = None;
+                    (&mut *THREAD.get()).self_ptr = null();
                 }
                 THREAD_ID_MANAGER.lock().free(self.id.get());
             }
        }
 
+        /// Data which is unique to the current thread while it is running.
+        /// A thread ID may be reused after a thread exits.
+        ///
+        /// This wrapper exists to hide multiple accesses to the TLS data
+        /// from the backend as this can lead to inefficient codegen
+        /// (to be precise it can lead to multiple TLS address lookups)
+        #[derive(Clone, Copy)]
+        struct ThreadWrapper {
+            self_ptr: *const Thread,
+            thread: Thread,
+        }
+
+        impl ThreadWrapper {
+            /// The thread ID obtained from the thread ID manager.
+            #[inline]
+            fn new(id: usize) -> Self {
+                Self {
+                    self_ptr: ((THREAD.get().cast_const() as usize) + offset_of!(ThreadWrapper, thread)) as *const Thread,
+                    thread: Thread::new(id),
+                }
+            }
+        }
+
         /// Returns a thread ID for the current thread, allocating one if needed.
         #[inline]
         pub(crate) fn get() -> Thread {
-            if let Some(thread) = unsafe { THREAD } {
-                thread
+            let thread = unsafe { *THREAD.get() };
+            if !thread.self_ptr.is_null() {
+                unsafe { thread.self_ptr.read() }
             } else {
                 get_slow()
             }
@@ -116,12 +149,13 @@
         /// Out-of-line slow path for allocating a thread ID.
         #[cold]
         fn get_slow() -> Thread {
-            let new = Thread::new(THREAD_ID_MANAGER.lock().alloc());
+            let id = THREAD_ID_MANAGER.lock().alloc();
+            let new = ThreadWrapper::new(id);
             unsafe {
-                THREAD = Some(new);
+                *THREAD.get() = new;
             }
-            THREAD_GUARD.with(|guard| guard.id.set(new.id));
-            new
+            THREAD_GUARD.with(|guard| guard.id.set(id));
+            new.thread
         }
     } else {
         // This is split into 2 thread-local variables so that we can check whether the
         // thread is initialized without having to register a thread-local destructor.
@@ -164,9 +198,10 @@
         /// Out-of-line slow path for allocating a thread ID.
         #[cold]
         fn get_slow(thread: &Cell<Option<Thread>>) -> Thread {
-            let new = Thread::new(THREAD_ID_MANAGER.lock().alloc());
+            let id = THREAD_ID_MANAGER.lock().alloc();
+            let new = Thread::new(id);
             thread.set(Some(new));
-            THREAD_GUARD.with(|guard| guard.id.set(new.id));
+            THREAD_GUARD.with(|guard| guard.id.set(id));
             new
         }
     }
 }
@@ -175,32 +210,27 @@
 #[test]
 fn test_thread() {
     let thread = Thread::new(0);
-    assert_eq!(thread.id, 0);
     assert_eq!(thread.bucket, 0);
-    assert_eq!(thread.bucket_size, 1);
+    assert_eq!(thread.bucket_size(), 1);
     assert_eq!(thread.index, 0);
 
     let thread = Thread::new(1);
-    assert_eq!(thread.id, 1);
     assert_eq!(thread.bucket, 1);
-    assert_eq!(thread.bucket_size, 2);
+    assert_eq!(thread.bucket_size(), 2);
     assert_eq!(thread.index, 0);
 
     let thread = Thread::new(2);
-    assert_eq!(thread.id, 2);
     assert_eq!(thread.bucket, 1);
-    assert_eq!(thread.bucket_size, 2);
+    assert_eq!(thread.bucket_size(), 2);
     assert_eq!(thread.index, 1);
 
     let thread = Thread::new(3);
-    assert_eq!(thread.id, 3);
     assert_eq!(thread.bucket, 2);
-    assert_eq!(thread.bucket_size, 4);
+    assert_eq!(thread.bucket_size(), 4);
     assert_eq!(thread.index, 0);
 
     let thread = Thread::new(19);
-    assert_eq!(thread.id, 19);
     assert_eq!(thread.bucket, 4);
-    assert_eq!(thread.bucket_size, 16);
+    assert_eq!(thread.bucket_size(), 16);
     assert_eq!(thread.index, 4);
 }
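
Note on the public API this patch optimizes: `get_or` now checks `self.get()` on the
inlined fast path and only falls through to the `#[cold]` insert, which runs the
closure, on a thread's first access. A minimal usage sketch, assuming the crate is
used as `thread_local` from crates.io:

    use thread_local::ThreadLocal;

    fn main() {
        let tls: ThreadLocal<u32> = ThreadLocal::new();
        // First access on this thread misses and takes the #[cold] insert
        // path, which runs the closure exactly once.
        assert_eq!(*tls.get_or(|| 42), 42);
        // Later accesses stay on the small fast path; the closure is not run.
        assert_eq!(*tls.get_or(|| unreachable!()), 42);
    }

Passing the closure into `insert` (rather than calling it at the call site) keeps
the value's construction code inside the cold function as well, which is the point
of the "Make insertion fully cold" change.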
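
Why `bucket_size` can be dropped from `Thread`: thread IDs are laid out in buckets
of doubling size (1, 2, 4, 8, ...), so the size is always recomputable as
`1 << bucket`. A standalone sketch of the arithmetic (the helper name is
illustrative, not part of the patch):

    // id + 1 written in binary identifies the bucket (its highest set bit)
    // and the index within that bucket.
    fn bucket_and_index(id: usize) -> (usize, usize) {
        let bucket = (usize::BITS as usize) - ((id + 1).leading_zeros() as usize) - 1;
        let bucket_size = 1usize << bucket;
        (bucket, id - (bucket_size - 1))
    }

    fn main() {
        // Mirrors the values asserted in test_thread above.
        assert_eq!(bucket_and_index(0), (0, 0));  // bucket_size() == 1
        assert_eq!(bucket_and_index(2), (1, 1));  // bucket_size() == 2
        assert_eq!(bucket_and_index(19), (4, 4)); // bucket_size() == 16
    }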
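
On the `ThreadWrapper::self_ptr` trick in the nightly path: the fast path reads the
wrapper once and uses the embedded pointer both as an "initialized" flag and as the
address of the `Thread` value, so the compiler sees a single TLS access instead of
two. A simplified sketch of the idea on stable Rust, with the `#[thread_local]`
static replaced by an explicit slot parameter (all names here are illustrative):

    use std::cell::UnsafeCell;
    use std::ptr;

    #[derive(Clone, Copy)]
    struct Thread { bucket: usize, index: usize }

    #[derive(Clone, Copy)]
    struct ThreadWrapper {
        self_ptr: *const Thread, // null means "not initialized yet"
        thread: Thread,
    }

    fn get(slot: &UnsafeCell<ThreadWrapper>) -> Thread {
        // One read of the slot; the pointer doubles as the presence flag.
        let wrapper = unsafe { *slot.get() };
        if !wrapper.self_ptr.is_null() {
            unsafe { wrapper.self_ptr.read() }
        } else {
            init(slot)
        }
    }

    #[cold]
    fn init(slot: &UnsafeCell<ThreadWrapper>) -> Thread {
        let thread = Thread { bucket: 0, index: 0 };
        unsafe {
            *slot.get() = ThreadWrapper {
                // Point at the thread field stored inside the slot itself.
                self_ptr: ptr::addr_of!((*slot.get()).thread),
                thread,
            };
        }
        thread
    }

    fn main() {
        let slot = UnsafeCell::new(ThreadWrapper {
            self_ptr: ptr::null(),
            thread: Thread { bucket: 0, index: 0 },
        });
        let first = get(&slot); // cold init path
        let again = get(&slot); // fast path
        assert_eq!((first.bucket, first.index), (again.bucket, again.index));
    }

The real patch additionally stores the field's address via `offset_of!` at
initialization time, so the hot path never recomputes it.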