Skip to content

Commit

Permalink
feat: Thread-local queue push take 3
Browse files Browse the repository at this point in the history
This commit attempts to re-introduce the thread-local optimization. It
stores the local queues in a multiplex hash map keyed by the thread ID
that it started in. It also sets it up so the thread can be woken up by
a unique runner ID.

cc #64

Signed-off-by: John Nunley <[email protected]>
  • Loading branch information
notgull committed May 14, 2024
1 parent 444d0c1 commit de1112b
Showing 1 changed file with 136 additions and 29 deletions.
165 changes: 136 additions & 29 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@
)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]

use std::collections::HashMap;
use std::fmt;
use std::marker::PhantomData;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock, TryLockError};
use std::task::{Poll, Waker};
use std::thread::{self, ThreadId};

use async_task::{Builder, Runnable};
use concurrent_queue::ConcurrentQueue;
Expand Down Expand Up @@ -347,8 +349,32 @@ impl<'a> Executor<'a> {
fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
let state = self.state_as_arc();

// TODO: If possible, push into the current local queue and notify the ticker.
move |runnable| {
move |mut runnable| {
// If possible, push into the current local queue and notify the ticker.
if let Some(local_queue) = state
.local_queues
.read()
.unwrap()
.get(&thread::current().id())
.and_then(|list| list.first())
{
match local_queue.queue.push(runnable) {
Ok(()) => {
if let Some(waker) = state
.sleepers
.lock()
.unwrap()
.notify_runner(local_queue.runner_id)
{
waker.wake();
}
return;
}

Err(r) => runnable = r.into_inner(),
}
}

state.queue.push(runnable).unwrap();
state.notify();
}
Expand Down Expand Up @@ -665,7 +691,9 @@ struct State {
queue: ConcurrentQueue<Runnable>,

/// Local queues created by runners.
local_queues: RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>,
///
/// These are keyed by the thread that the runner originated in.
local_queues: RwLock<HashMap<ThreadId, Vec<Arc<LocalQueue>>>>,

/// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
notified: AtomicBool,
Expand All @@ -682,7 +710,7 @@ impl State {
const fn new() -> State {
State {
queue: ConcurrentQueue::unbounded(),
local_queues: RwLock::new(Vec::new()),
local_queues: RwLock::new(HashMap::new()),
notified: AtomicBool::new(true),
sleepers: Mutex::new(Sleepers {
count: 0,
Expand Down Expand Up @@ -756,36 +784,57 @@ struct Sleepers {
/// IDs and wakers of sleeping unnotified tickers.
///
/// A sleeping ticker is notified when its waker is missing from this list.
wakers: Vec<(usize, Waker)>,
wakers: Vec<Sleeper>,

/// Reclaimed IDs.
free_ids: Vec<usize>,
}

/// A single sleeping ticker.
struct Sleeper {
/// ID of the sleeping ticker.
id: usize,

/// Waker associated with this ticker.
waker: Waker,

/// Specific runner ID for targeted wakeups.
runner: Option<usize>,
}

impl Sleepers {
/// Inserts a new sleeping ticker.
fn insert(&mut self, waker: &Waker) -> usize {
fn insert(&mut self, waker: &Waker, runner: Option<usize>) -> usize {
let id = match self.free_ids.pop() {
Some(id) => id,
None => self.count + 1,
};
self.count += 1;
self.wakers.push((id, waker.clone()));
self.wakers.push(Sleeper {
id,
waker: waker.clone(),
runner,
});
id
}

/// Re-inserts a sleeping ticker's waker if it was notified.
///
/// Returns `true` if the ticker was notified.
fn update(&mut self, id: usize, waker: &Waker) -> bool {
fn update(&mut self, id: usize, waker: &Waker, runner: Option<usize>) -> bool {
for item in &mut self.wakers {
if item.0 == id {
item.1.clone_from(waker);
if item.id == id {
debug_assert_eq!(item.runner, runner);
item.waker.clone_from(waker);
return false;
}
}

self.wakers.push((id, waker.clone()));
self.wakers.push(Sleeper {
id,
waker: waker.clone(),
runner,
});
true
}

Expand All @@ -797,7 +846,7 @@ impl Sleepers {
self.free_ids.push(id);

for i in (0..self.wakers.len()).rev() {
if self.wakers[i].0 == id {
if self.wakers[i].id == id {
self.wakers.remove(i);
return false;
}
Expand All @@ -815,7 +864,20 @@ impl Sleepers {
/// If a ticker was notified already or there are no tickers, `None` will be returned.
fn notify(&mut self) -> Option<Waker> {
if self.wakers.len() == self.count {
self.wakers.pop().map(|item| item.1)
self.wakers.pop().map(|item| item.waker)
} else {
None
}
}

/// Notify a specific waker that was previously sleeping.
fn notify_runner(&mut self, runner: usize) -> Option<Waker> {
if let Some(posn) = self
.wakers
.iter()
.position(|sleeper| sleeper.runner == Some(runner))
{
Some(self.wakers.swap_remove(posn).waker)
} else {
None
}
Expand All @@ -834,12 +896,28 @@ struct Ticker<'a> {
/// 2a) Sleeping and unnotified.
/// 2b) Sleeping and notified.
sleeping: usize,

/// Unique runner ID, if this is a runner.
runner: Option<usize>,
}

impl Ticker<'_> {
/// Creates a ticker.
fn new(state: &State) -> Ticker<'_> {
Ticker { state, sleeping: 0 }
Ticker {
state,
sleeping: 0,
runner: None,
}
}

/// Creates a ticker for a runner.
fn for_runner(state: &State, runner: usize) -> Ticker<'_> {
Ticker {
state,
sleeping: 0,
runner: Some(runner),
}
}

/// Moves the ticker into sleeping and unnotified state.
Expand All @@ -851,12 +929,12 @@ impl Ticker<'_> {
match self.sleeping {
// Move to sleeping state.
0 => {
self.sleeping = sleepers.insert(waker);
self.sleeping = sleepers.insert(waker, self.runner);
}

// Already sleeping, check if notified.
id => {
if !sleepers.update(id, waker) {
if !sleepers.update(id, waker, self.runner) {
return false;
}
}
Expand Down Expand Up @@ -946,8 +1024,11 @@ struct Runner<'a> {
/// Inner ticker.
ticker: Ticker<'a>,

/// The ID of the thread we originated from.
origin_id: ThreadId,

/// The local queue.
local: Arc<ConcurrentQueue<Runnable>>,
local: Arc<LocalQueue>,

/// Bumped every time a runnable task is found.
ticks: usize,
Expand All @@ -956,16 +1037,26 @@ struct Runner<'a> {
impl Runner<'_> {
/// Creates a runner and registers it in the executor state.
fn new(state: &State) -> Runner<'_> {
static ID_GENERATOR: AtomicUsize = AtomicUsize::new(0);
let runner_id = ID_GENERATOR.fetch_add(1, Ordering::SeqCst);

let origin_id = thread::current().id();
let runner = Runner {
state,
ticker: Ticker::new(state),
local: Arc::new(ConcurrentQueue::bounded(512)),
ticker: Ticker::for_runner(state, runner_id),
local: Arc::new(LocalQueue {
queue: ConcurrentQueue::bounded(512),
runner_id,
}),
ticks: 0,
origin_id,
};
state
.local_queues
.write()
.unwrap()
.entry(origin_id)
.or_default()
.push(runner.local.clone());
runner
}
Expand All @@ -976,13 +1067,13 @@ impl Runner<'_> {
.ticker
.runnable_with(|| {
// Try the local queue.
if let Ok(r) = self.local.pop() {
if let Ok(r) = self.local.queue.pop() {
return Some(r);
}

// Try stealing from the global queue.
if let Ok(r) = self.state.queue.pop() {
steal(&self.state.queue, &self.local);
steal(&self.state.queue, &self.local.queue);
return Some(r);
}

Expand All @@ -994,7 +1085,8 @@ impl Runner<'_> {
let start = rng.usize(..n);
let iter = local_queues
.iter()
.chain(local_queues.iter())
.flat_map(|(_, list)| list)
.chain(local_queues.iter().flat_map(|(_, list)| list))
.skip(start)
.take(n);

Expand All @@ -1003,8 +1095,8 @@ impl Runner<'_> {

// Try stealing from each local queue in the list.
for local in iter {
steal(local, &self.local);
if let Ok(r) = self.local.pop() {
steal(&local.queue, &self.local.queue);
if let Ok(r) = self.local.queue.pop() {
return Some(r);
}
}
Expand All @@ -1018,7 +1110,7 @@ impl Runner<'_> {

if self.ticks % 64 == 0 {
// Steal tasks from the global queue to ensure fair task scheduling.
steal(&self.state.queue, &self.local);
steal(&self.state.queue, &self.local.queue);
}

runnable
Expand All @@ -1032,15 +1124,26 @@ impl Drop for Runner<'_> {
.local_queues
.write()
.unwrap()
.get_mut(&self.origin_id)
.unwrap()
.retain(|local| !Arc::ptr_eq(local, &self.local));

// Re-schedule remaining tasks in the local queue.
while let Ok(r) = self.local.pop() {
while let Ok(r) = self.local.queue.pop() {
r.schedule();
}
}
}

/// Data associated with a local queue.
struct LocalQueue {
/// Concurrent queue of active tasks.
queue: ConcurrentQueue<Runnable>,

/// Unique ID associated with this runner.
runner_id: usize,
}

/// Steals some items from one queue into another.
fn steal<T>(src: &ConcurrentQueue<T>, dest: &ConcurrentQueue<T>) {
// Half of `src`'s length rounded up.
Expand Down Expand Up @@ -1104,14 +1207,18 @@ fn debug_state(state: &State, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Re
}

/// Debug wrapper for the local runners.
struct LocalRunners<'a>(&'a RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>);
struct LocalRunners<'a>(&'a RwLock<HashMap<ThreadId, Vec<Arc<LocalQueue>>>>);

impl fmt::Debug for LocalRunners<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0.try_read() {
Ok(lock) => f
.debug_list()
.entries(lock.iter().map(|queue| queue.len()))
.entries(
lock.iter()
.flat_map(|(_, list)| list)
.map(|queue| queue.queue.len()),
)
.finish(),
Err(TryLockError::WouldBlock) => f.write_str("<locked>"),
Err(TryLockError::Poisoned(_)) => f.write_str("<poisoned>"),
Expand Down

0 comments on commit de1112b

Please sign in to comment.