Skip to content

Commit

Permalink
enable partial recovery across threads
Browse files Browse the repository at this point in the history
Including the corner case where the active thread does not have
recovery.
  • Loading branch information
nikomatsakis committed Nov 11, 2021
1 parent 93ee78e commit 382147f
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 11 deletions.
50 changes: 41 additions & 9 deletions src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::revision::{AtomicRevision, Revision};
use crate::{Cancelled, Cycle, Database, DatabaseKeyIndex, Event, EventKind};
use log::debug;
use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive};
use parking_lot::{Mutex, MutexGuard, RwLock};
use parking_lot::{Mutex, RwLock};
use rustc_hash::FxHasher;
use std::hash::{BuildHasherDefault, Hash};
use std::sync::atomic::{AtomicUsize, Ordering};
Expand Down Expand Up @@ -44,10 +44,11 @@ pub struct Runtime {
shared_state: Arc<SharedState>,
}

#[derive(Copy, Clone, Debug)]
#[derive(Clone, Debug)]
pub(crate) enum WaitResult {
Completed,
Panicked,
Cycle(Cycle),
}

impl Default for Runtime {
Expand Down Expand Up @@ -266,14 +267,25 @@ impl Runtime {
.report_synthetic_read(durability, changed_at);
}

fn throw_cycle_error(
/// Handles a cycle in the dependency graph that was detected when the
/// current thread tried to block on `database_key_index` which is being
/// executed by `to_id`. If this function returns, then `to_id` no longer
/// depends on the current thread, and so we should continue executing
/// as normal. Otherwise, the function will throw a `Cycle` which is expected
/// to be caught by some frame on our stack. This occurs either if there is
/// a frame on our stack with cycle recovery (possibly the top one!) or if there
/// is no cycle recovery at all.
fn unblock_cycle_and_maybe_throw(
&self,
db: &dyn Database,
mut dg: MutexGuard<'_, DependencyGraph>,
dg: &mut DependencyGraph,
database_key_index: DatabaseKeyIndex,
to_id: RuntimeId,
) -> ! {
debug!("create_cycle_error(database_key={:?})", database_key_index);
) {
debug!(
"unblock_cycle_and_maybe_throw(database_key={:?})",
database_key_index
);

let mut from_stack = self.local_state.take_query_stack();
let from_id = self.id();
Expand Down Expand Up @@ -330,6 +342,7 @@ impl Runtime {
CycleRecoveryStrategy::Fallback => {
debug!("marking {:?} for fallback", aq.database_key_index.debug(db));
aq.take_inputs_from(&cycle_query);
assert!(aq.cycle.is_none());
aq.cycle = Some(cycle.clone());
}

Expand All @@ -339,9 +352,22 @@ impl Runtime {
}
});

// Unblock every thread that has cycle recovery with a `WaitResult::Cycle`.
// They will throw the cycle, which will be caught by the frame that has
// cycle recovery so that it can execute that recovery.
let (me_recovered, others_recovered) =
dg.maybe_unblock_runtimes_in_cycle(from_id, &from_stack, database_key_index, to_id);

self.local_state.restore_query_stack(from_stack);

cycle.throw()
if me_recovered || !others_recovered {
// if me_recovered: If the current thread has recovery, we want to throw
// so that it can begin.
//
// otherwise, if !others_recovered: then no threads have recovery, so we want
// to throw the cycle so that salsa can abort.
cycle.throw();
}
}

/// Block until `other_id` completes executing `database_key`;
Expand Down Expand Up @@ -375,8 +401,12 @@ impl Runtime {
) {
let mut dg = self.shared_state.dependency_graph.lock();

if self.id() == other_id || dg.depends_on(other_id, self.id()) {
self.throw_cycle_error(db, dg, database_key, other_id)
if dg.depends_on(other_id, self.id()) {
self.unblock_cycle_and_maybe_throw(db, &mut dg, database_key, other_id);

// If the above fn returns, then (via cycle recovery) it has unblocked the
// cycle, so we can continue.
assert!(!dg.depends_on(other_id, self.id()));
}

db.salsa_event(Event {
Expand Down Expand Up @@ -407,6 +437,8 @@ impl Runtime {
// cancelled. The assumption is that the panic will be detected
// by the other thread and responded to appropriately.
WaitResult::Panicked => Cancelled::PropagatedPanic.throw(),

WaitResult::Cycle(c) => c.throw(),
}
}

Expand Down
63 changes: 62 additions & 1 deletion src/runtime/dependency_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,67 @@ impl DependencyGraph {
.for_each(&mut closure);
}

/// Walks the chain of blocked runtimes that makes up the cycle and unblocks
/// every runtime other than the current one that has a query marked for
/// cycle recovery (i.e. an active query whose `cycle` field is set).
///
/// The walk starts at `to_id` (the runtime executing `database_key`) and
/// follows each runtime's `blocked_on_id` / `blocked_on_key` edge until it
/// arrives back at `from_id`. For every runtime visited, only the part of
/// its stack starting at the query it was entered through is inspected;
/// that part is searched from the end of the stack backwards. If a marked
/// query is found, the runtime is removed from the dependents of the key it
/// was blocked on and is woken up with `WaitResult::Cycle`.
///
/// The current runtime (`from_id`, with stack `from_stack`) is inspected
/// the same way but never unblocked here: it is up to the caller to decide
/// whether to throw or to block.
///
/// Returns a tuple `(current, others)` where:
/// * `current` is true if the current runtime's portion of the cycle
///   contains a query marked for recovery, and
/// * `others` is true if at least one other runtime was unblocked.
///
/// # Panics
///
/// Panics if a runtime on the path from `to_id` to `from_id` has no edge in
/// the graph, or if the key it is blocked on has no dependents entry.
pub(super) fn maybe_unblock_runtimes_in_cycle(
    &mut self,
    from_id: RuntimeId,
    from_stack: &QueryStack,
    database_key: DatabaseKeyIndex,
    to_id: RuntimeId,
) -> (bool, bool) {
    // See diagram in `for_each_cycle_participant`.
    let mut id = to_id;
    let mut key = database_key;
    let mut others_unblocked = false;
    while id != from_id {
        let edge = self.edges.get(&id).unwrap();

        // Number of stack entries that precede the query through which the
        // cycle enters this runtime; those entries are not part of the cycle.
        let prefix = edge
            .stack
            .iter()
            .take_while(|p| p.database_key_index != key)
            .count();
        let next_id = edge.blocked_on_id;
        let next_key = edge.blocked_on_key;

        // Look for the last query in this runtime's portion of the cycle
        // that carries a cycle. The value is cloned so that the borrow of
        // `edge` ends before `self` is mutated below.
        let recovery_cycle = edge.stack[prefix..]
            .iter()
            .rev()
            .find_map(|aq| aq.cycle.clone());

        if let Some(cycle) = recovery_cycle {
            // Remove `id` from the list of runtimes blocked on `next_key`:
            self.query_dependents
                .get_mut(&next_key)
                .unwrap()
                .retain(|r| *r != id);

            // Unblock runtime so that it can resume execution once lock is released:
            self.unblock_runtime(id, WaitResult::Cycle(cycle));

            others_unblocked = true;
        }

        id = next_id;
        key = next_key;
    }

    // Same inspection for the current runtime, whose stack is not stored in
    // the graph and is therefore passed in by the caller.
    let prefix = from_stack
        .iter()
        .take_while(|p| p.database_key_index != key)
        .count();
    let this_unblocked = from_stack[prefix..].iter().any(|aq| aq.cycle.is_some());

    (this_unblocked, others_unblocked)
}

/// Modifies the graph so that `from_id` is blocked
/// on `database_key`, which is being computed by
/// `to_id`.
Expand Down Expand Up @@ -198,7 +259,7 @@ impl DependencyGraph {
.unwrap_or_default();

for from_id in dependents {
self.unblock_runtime(from_id, wait_result);
self.unblock_runtime(from_id, wait_result.clone());
}
}

Expand Down
1 change: 1 addition & 0 deletions tests/parallel/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod frozen;
mod independent;
mod parallel_cycle_all_recover;
mod parallel_cycle_none_recover;
mod parallel_cycle_one_recovers;
mod race;
mod signal;
mod stress;
Expand Down
95 changes: 95 additions & 0 deletions tests/parallel/parallel_cycle_one_recovers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//! Test for cycle recovery spread across two threads.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.

use crate::setup::{Knobs, ParDatabaseImpl};
use salsa::ParallelDatabase;
use test_env_log::test;

// Recover cycle test:
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1 b1
// | wait for stage 1 (blocks)
// signal stage 1 |
// wait for stage 2 (blocks) (unblocked)
// | signal stage 2
// (unblocked) wait for stage 3 (blocks)
// a2 |
// b1 (blocks -> stage 3) |
// | (unblocked)
// | b2
// | a1 (cycle detected)
// a2 recovery fn executes |
// a1 completes normally |
// b2 completes, recovers
// b1 completes, recovers

/// Runs `a1` and `b1` on two threads, each with its own snapshot of the
/// database, so that the queries form a cycle spanning both threads in which
/// only `a2` has a recovery function. Both threads must finish with the
/// value produced by that recovery function.
#[test]
fn parallel_cycle_one_recovers() {
    let db = ParDatabaseImpl::default();
    db.knobs().signal_on_will_block.set(3);

    let snapshot_a = db.snapshot();
    let thread_a = std::thread::spawn(move || snapshot_a.a1(1));

    let snapshot_b = db.snapshot();
    let thread_b = std::thread::spawn(move || snapshot_b.b1(1));

    // `recover` computes `key * 20 + 2`, i.e. 22 for key 1. Every other
    // query simply returns the result of the query it calls, so both
    // threads are expected to end up with that same value.
    let result_a = thread_a.join().unwrap();
    assert_eq!(result_a, 22);

    let result_b = thread_b.join().unwrap();
    assert_eq!(result_b, 22);
}

// Query group used by the `parallel_cycle_one_recovers` test. The four
// queries call each other in a ring (a1 -> a2 -> b1 -> b2 -> a1); only `a2`
// declares a cycle recovery function.
#[salsa::query_group(ParallelCycleOneRecovers)]
pub(crate) trait TestDatabase: Knobs {
// Started by thread A in the test; calls `a2`.
fn a1(&self, key: i32) -> i32;

// Calls `b1`; the only query here with a recovery function (`recover`).
#[salsa::cycle(recover)]
fn a2(&self, key: i32) -> i32;

// Started by thread B in the test; calls `b2`.
fn b1(&self, key: i32) -> i32;

// Calls `a1`, which closes the ring.
fn b2(&self, key: i32) -> i32;
}

/// Cycle recovery function registered for `a2`.
///
/// Ignores the database and the cycle description, logs that it ran, and
/// returns `key * 20 + 2`.
fn recover(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
    log::debug!("recover");
    let scaled = *key * 20;
    scaled + 2
}

/// Query started by thread A in the test (and also called by `b2`).
///
/// Signals stage 1, waits for stage 2, then returns the result of `a2`.
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered:
// announce that this thread is running (stage 1), then wait for the
// stage 2 signal that `b1` sends.
db.signal(1);
db.wait_for(2);

db.a2(key)
}

/// Query declared with `#[salsa::cycle(recover)]`; its body simply returns
/// the result of `b1` for the same key.
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
db.b1(key)
}

/// Query started by thread B in the test (and also called by `a2`).
///
/// Waits for stage 1, signals stage 2, waits for stage 3, then returns the
/// result of `b2`.
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered:
// stage 1 is signalled by `a1`, and stage 2 lets `a1` proceed.
db.wait_for(1);
db.signal(2);

// Wait for thread A to block on this thread
// NOTE(review): stage 3 is presumably raised through the
// `signal_on_will_block` knob that the test sets to 3 -- confirm in setup.rs.
db.wait_for(3);

db.b2(key)
}

/// Returns the result of `a1` for the same key, which closes the
/// a1 -> a2 -> b1 -> b2 -> a1 ring of calls.
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
db.a1(key)
}
3 changes: 2 additions & 1 deletion tests/parallel/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ fn sum3_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
#[salsa::database(
Par,
crate::parallel_cycle_all_recover::ParallelCycleAllRecover,
crate::parallel_cycle_none_recover::ParallelCycleNoneRecover
crate::parallel_cycle_none_recover::ParallelCycleNoneRecover,
crate::parallel_cycle_one_recovers::ParallelCycleOneRecovers
)]
#[derive(Default)]
pub(crate) struct ParDatabaseImpl {
Expand Down

0 comments on commit 382147f

Please sign in to comment.