introduce parallel salsa #568

Open
wants to merge 1 commit into
base: master
1 change: 1 addition & 0 deletions Cargo.toml
@@ -21,6 +21,7 @@ salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" }
salsa-macros = { path = "components/salsa-macros" }
smallvec = "1"
lazy_static = "1"
+rayon = "1.10.0"

[dev-dependencies]
annotate-snippets = "0.11.4"
2 changes: 1 addition & 1 deletion examples/calc/db.rs
@@ -2,7 +2,7 @@ use std::sync::{Arc, Mutex};

// ANCHOR: db_struct
#[salsa::db]
-#[derive(Default)]
+#[derive(Default, Clone)]
pub struct CalcDatabaseImpl {
storage: salsa::Storage<Self>,

15 changes: 11 additions & 4 deletions examples/lazy-input/main.rs
@@ -1,6 +1,10 @@
#![allow(unreachable_patterns)]
// FIXME(rust-lang/rust#129031): regression in nightly
-use std::{path::PathBuf, sync::Mutex, time::Duration};
+use std::{
+path::PathBuf,
+sync::{Arc, Mutex},
+time::Duration,
+};

use crossbeam::channel::{unbounded, Sender};
use dashmap::{mapref::entry::Entry, DashMap};
@@ -77,11 +81,12 @@ trait Db: salsa::Database {
}

#[salsa::db]
+#[derive(Clone)]
Contributor

Nit: I kind of preferred the old salsa's ParallelDatabase trait rather than relying on Clone. It makes the contract more explicit.

Contributor Author

It's certainly an option, but I'm not sure ParallelDatabase works well when an end user of Salsa only has a &dyn HirDatabase; it should be safe and reasonable to use Rayon with queries that take &dyn HirDatabase. Like I said in my other comment, I'm going off of this document, but I'd be happy to change this design! I mostly care about making the example work with any database.

Member

This is what I was debating in those variations; there might be a hybrid here. That said, I think that ALL databases will effectively be parallel databases under this plan (that is ~required to allow queries to use parallelism internally, since they are written against a dyn Database).

I guess we could still make a ParallelDatabase trait and have them work against a dyn ParallelDatabase, but I'm not convinced there's a use case for "non-parallel databases".

Contributor Author

@davidbarsky davidbarsky Aug 27, 2024

At least in rust-analyzer's case, I can't really think of a situation where it wouldn't make sense to have every Database be a ParallelDatabase. I don't think we'd want to make everything parallel, but having the option to do so without extensive refactors is really valuable, in my view.

struct LazyInputDatabase {
storage: Storage<Self>,
-logs: Mutex<Vec<String>>,
+logs: Arc<Mutex<Vec<String>>>,
files: DashMap<PathBuf, File>,
-file_watcher: Mutex<Debouncer<RecommendedWatcher>>,
+file_watcher: Arc<Mutex<Debouncer<RecommendedWatcher>>>,
}

impl LazyInputDatabase {
@@ -90,7 +95,9 @@ impl LazyInputDatabase {
storage: Default::default(),
logs: Default::default(),
files: DashMap::new(),
-file_watcher: Mutex::new(new_debouncer(Duration::from_secs(1), tx).unwrap()),
+file_watcher: Arc::new(Mutex::new(
+new_debouncer(Duration::from_secs(1), tx).unwrap(),
+)),
}
}
}
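
The move from `Mutex<...>` to `Arc<Mutex<...>>` in this example appears to follow from the new `Clone` derive: `Mutex` itself is not `Clone`, and a forked database presumably should append to the same log rather than to a private copy. Below is a minimal std-only sketch of that sharing behaviour. The `Logs` struct and the `log_from_workers` helper are hypothetical stand-ins and are not part of salsa or of this PR.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical stand-in for the `logs` field of a cloneable database.
#[derive(Clone, Default)]
struct Logs {
    entries: Arc<Mutex<Vec<String>>>,
}

impl Logs {
    /// Append one message to the shared log.
    fn push(&self, msg: impl Into<String>) {
        self.entries.lock().unwrap().push(msg.into());
    }

    /// Copy out the current contents of the shared log.
    fn snapshot(&self) -> Vec<String> {
        self.entries.lock().unwrap().clone()
    }
}

/// Clone the handle into worker threads; every clone writes to the same Vec
/// because cloning an `Arc` only bumps a reference count.
fn log_from_workers(logs: &Logs, workers: usize) {
    let handles: Vec<_> = (0..workers)
        .map(|i| {
            let logs = logs.clone();
            thread::spawn(move || logs.push(format!("worker {i}")))
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
}
```

After `log_from_workers(&logs, 4)` returns, `logs.snapshot()` on the original handle holds all four entries (in an unspecified order), which is the property a cloned database needs for its log to stay coherent.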
2 changes: 1 addition & 1 deletion src/database.rs
@@ -90,7 +90,7 @@ impl dyn Database {
///
/// # Panics
///
-/// If the view has not been added to the database (see [`DatabaseView`][])
+/// If the view has not been added to the database (see [`crate::views::Views`]).
#[track_caller]
pub fn as_view<DbView: ?Sized + Database>(&self) -> &DbView {
self.zalsa().views().try_view_as(self).unwrap()
2 changes: 1 addition & 1 deletion src/database_impl.rs
@@ -3,7 +3,7 @@ use crate::{self as salsa, Database, Event, Storage};
#[salsa::db]
/// Default database implementation that you can use if you don't
/// require any custom user data.
-#[derive(Default)]
+#[derive(Default, Clone)]
pub struct DatabaseImpl {
storage: Storage<Self>,
}
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -16,6 +16,7 @@ mod input;
mod interned;
mod key;
mod nonce;
+mod par_map;
mod revision;
mod runtime;
mod salsa_struct;
@@ -45,6 +46,7 @@ pub use self::storage::Storage;
pub use self::update::Update;
pub use self::zalsa::IngredientIndex;
pub use crate::attach::with_attached_database;
+pub use par_map::par_map;
pub use salsa_macros::accumulator;
pub use salsa_macros::db;
pub use salsa_macros::input;
54 changes: 54 additions & 0 deletions src/par_map.rs
@@ -0,0 +1,54 @@
use std::ops::Deref;

use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator};

use crate::Database;

pub fn par_map<Db, D, E, C>(
db: &Db,
inputs: impl IntoParallelIterator<Item = D>,
op: fn(&Db, D) -> E,
) -> C
where
Db: Database + ?Sized,
D: Send,
E: Send + Sync,
C: FromParallelIterator<E>,
{
let parallel_db = ParallelDb::Ref(db.as_dyn_database());

inputs
.into_par_iter()
.map_with(parallel_db, |parallel_db, element| {
let db = parallel_db.as_view::<Db>();
op(db, element)
})
.collect()
}

/// This enum _must not_ be public or used outside of `par_map`.
enum ParallelDb<'db> {
Ref(&'db dyn Database),
Fork(Box<dyn Database + Send>),
}

/// SAFETY: the contents of the database are never accessed on the thread
/// where this wrapper type is created.
unsafe impl Send for ParallelDb<'_> {}

impl Deref for ParallelDb<'_> {
type Target = dyn Database;

fn deref(&self) -> &Self::Target {
match self {
ParallelDb::Ref(db) => *db,
ParallelDb::Fork(db) => db.as_dyn_database(),
}
}
}

impl Clone for ParallelDb<'_> {
fn clone(&self) -> Self {
ParallelDb::Fork(self.fork_db())
}
}
davidbarsky marked this conversation as resolved.
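
For readers unfamiliar with the `Ref`/`Fork` pattern in `par_map.rs`, here is a rough std-only sketch of the idea: the wrapper starts out borrowing the caller's database and turns into an owned fork whenever it is cloned for a worker. It uses `std::thread::scope` in place of rayon's `map_with`, and it requires `Send + Sync` on the trait so that no `unsafe impl Send` is needed; both are simplifications, not what the PR does. The `Database` trait, `Db` struct, `bump`, and `as_dyn` are hypothetical names for illustration only.

```rust
use std::thread;

/// Hypothetical stand-in for a database trait; only `fork_db` mirrors the PR.
trait Database: Send + Sync {
    fn bump(&self, x: u32) -> u32;
    fn fork_db(&self) -> Box<dyn Database>;
}

#[derive(Clone)]
struct Db {
    step: u32,
}

impl Database for Db {
    fn bump(&self, x: u32) -> u32 {
        x + self.step
    }

    fn fork_db(&self) -> Box<dyn Database> {
        Box::new(self.clone())
    }
}

/// Borrowed on the calling thread, owned after any clone.
enum ParallelDb<'db> {
    Ref(&'db dyn Database),
    Fork(Box<dyn Database>),
}

impl ParallelDb<'_> {
    fn as_dyn(&self) -> &dyn Database {
        match self {
            ParallelDb::Ref(db) => *db,
            ParallelDb::Fork(db) => &**db,
        }
    }
}

impl Clone for ParallelDb<'_> {
    /// Cloning never copies the borrow; it always produces a fresh fork.
    fn clone(&self) -> Self {
        ParallelDb::Fork(self.as_dyn().fork_db())
    }
}

/// Simplified analogue of `par_map`: one scoped thread and one fork per input.
fn par_map<T, R>(db: &dyn Database, inputs: Vec<T>, op: fn(&dyn Database, T) -> R) -> Vec<R>
where
    T: Send,
    R: Send,
{
    let root = ParallelDb::Ref(db);
    thread::scope(|scope| {
        let handles: Vec<_> = inputs
            .into_iter()
            .map(|input| {
                let forked = root.clone();
                scope.spawn(move || op(forked.as_dyn(), input))
            })
            .collect();
        // Joining in spawn order keeps the results in input order.
        handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect::<Vec<R>>()
    })
}
```

With `Db { step: 1 }`, calling `par_map(&db, vec![1, 2, 3], |db, x| db.bump(x))` yields `[2, 3, 4]`. Unlike this sketch, rayon's `map_with` clones its init value only when it splits work, so the real implementation can fork far less often than once per element.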
6 changes: 5 additions & 1 deletion src/storage.rs
@@ -15,7 +15,7 @@ use crate::{
///
/// The `storage` and `storage_mut` fields must both return a reference to the same
/// storage field which must be owned by `self`.
-pub unsafe trait HasStorage: Database + Sized {
+pub unsafe trait HasStorage: Database + Clone + Sized {
fn storage(&self) -> &Storage<Self>;
fn storage_mut(&mut self) -> &mut Storage<Self>;
}
@@ -108,6 +108,10 @@ unsafe impl<T: HasStorage> ZalsaDatabase for T {
fn zalsa_local(&self) -> &ZalsaLocal {
&self.storage().zalsa_local
}
+
+fn fork_db(&self) -> Box<dyn Database> {
Contributor

Could we return T instead of a Box<dyn Database>?

Contributor Author

Micha: I'm going off the design specified in the section titled "Clone alternative". Happy to revisit this, if needed!

Member

We can't readily do that because of dyn safety, but there are other ways we can structure it to expose a variation that returns a Self.

+Box::new(self.clone())
+}
}
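
The dyn-safety point raised in the thread above can be illustrated with a small standalone sketch. A method that returns `Self` cannot be called through a trait object, so the object-safe method has to return a box; a second method bounded by `Self: Sized` is one possible way to also offer a variant that returns the concrete type. This is only an illustration of the language rule, not the design salsa adopted. The `Database` trait, `MyDb`, `name`, `fork`, and `fork_via_dyn` are hypothetical.

```rust
/// Hypothetical object-safe trait, loosely modelled on the `fork_db` addition.
trait Database {
    fn name(&self) -> String;

    /// Object-safe: callable through `&dyn Database`.
    fn fork_db(&self) -> Box<dyn Database>;

    /// Returns the concrete type. The `Self: Sized` bound excludes this
    /// method from `dyn Database`, which keeps the trait object-safe.
    fn fork(&self) -> Self
    where
        Self: Sized + Clone,
    {
        self.clone()
    }
}

#[derive(Clone)]
struct MyDb {
    label: String,
}

impl Database for MyDb {
    fn name(&self) -> String {
        self.label.clone()
    }

    fn fork_db(&self) -> Box<dyn Database> {
        Box::new(self.clone())
    }
}

/// Code that only holds a trait object can still fork, via the boxed variant.
fn fork_via_dyn(db: &dyn Database) -> Box<dyn Database> {
    db.fork_db()
}
```

Callers that know the concrete type can write `let copy: MyDb = db.fork();`, while query code that only sees `&dyn Database` goes through `fork_via_dyn(db)` and pays for one allocation.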

impl<Db: Database> RefUnwindSafe for Storage<Db> {}
4 changes: 4 additions & 0 deletions src/zalsa.rs
@@ -50,6 +50,10 @@ pub unsafe trait ZalsaDatabase: Any {
/// Access the thread-local state associated with this database
#[doc(hidden)]
fn zalsa_local(&self) -> &ZalsaLocal;
+
+/// Clone the database.
+#[doc(hidden)]
+fn fork_db(&self) -> Box<dyn Database>;
}

pub fn views<Db: ?Sized + Database>(db: &Db) -> &Views {
14 changes: 8 additions & 6 deletions tests/common/mod.rs
@@ -2,15 +2,17 @@

#![allow(dead_code)]

+use std::sync::{Arc, Mutex};
+
use salsa::{Database, Storage};

/// Logging userdata: provides [`LogDatabase`][] trait.
///
/// If you wish to use it along with other userdata,
/// you can also embed it in another struct and implement [`HasLogger`][] for that struct.
-#[derive(Default)]
+#[derive(Clone, Default)]
pub struct Logger {
-logs: std::sync::Mutex<Vec<String>>,
+logs: Arc<Mutex<Vec<String>>>,
}

/// Trait implemented by databases that lets them log events.
@@ -48,7 +50,7 @@ impl<Db: HasLogger + Database> LogDatabase for Db {}

/// Database that provides logging but does not log salsa event.
#[salsa::db]
-#[derive(Default)]
+#[derive(Clone, Default)]
pub struct LoggerDatabase {
storage: Storage<Self>,
logger: Logger,
@@ -67,7 +69,7 @@ impl Database for LoggerDatabase {

/// Database that provides logging and logs salsa events.
#[salsa::db]
-#[derive(Default)]
+#[derive(Clone, Default)]
pub struct EventLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
@@ -87,7 +89,7 @@ impl HasLogger for EventLoggerDatabase {
}

#[salsa::db]
-#[derive(Default)]
+#[derive(Clone, Default)]
pub struct DiscardLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
@@ -114,7 +116,7 @@ impl HasLogger for DiscardLoggerDatabase {
}

#[salsa::db]
-#[derive(Default)]
+#[derive(Clone, Default)]
pub struct ExecuteValidateLoggerDatabase {
storage: Storage<Self>,
logger: Logger,
1 change: 1 addition & 0 deletions tests/parallel/main.rs
@@ -5,4 +5,5 @@ mod parallel_cycle_all_recover;
mod parallel_cycle_mid_recover;
mod parallel_cycle_none_recover;
mod parallel_cycle_one_recover;
+mod parallel_map;
mod signal;
98 changes: 98 additions & 0 deletions tests/parallel/parallel_map.rs
@@ -0,0 +1,98 @@
// test for rayon interactions.

use salsa::Cancelled;
use salsa::Setter;

use crate::setup::Knobs;
use crate::setup::KnobsDatabase;

#[salsa::input]
struct ParallelInput {
field: Vec<u32>,
}

#[salsa::tracked]
fn tracked_fn(db: &dyn salsa::Database, input: ParallelInput) -> Vec<u32> {
salsa::par_map(db, input.field(db), |_db, field| field + 1)
}

#[test]
fn execute() {
let db = salsa::DatabaseImpl::new();

let counts = (1..=10).collect::<Vec<u32>>();
let input = ParallelInput::new(&db, counts);

tracked_fn(&db, input);
}

#[salsa::tracked]
fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> Vec<u32> {
db.signal(1);
salsa::par_map(db, input.field(db), |db, field| {
db.wait_for(2);
field + 1
})
}

#[salsa::tracked]
fn dummy(_db: &dyn KnobsDatabase, _input: ParallelInput) -> ParallelInput {
panic!("should never get here!")
}

// we expect this to panic, as `salsa::par_map` needs to be called from a query.
#[test]
#[should_panic]
fn direct_calls_panic() {
let db = salsa::DatabaseImpl::new();

let counts = (1..=10).collect::<Vec<u32>>();
let input = ParallelInput::new(&db, counts);
let _: Vec<u32> = salsa::par_map(&db, input.field(&db), |_db, field| field + 1);
}

// Cancellation signalling test
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1
// | wait for stage 1
// signal stage 1 set input, triggers cancellation
// wait for stage 2 (blocks) triggering cancellation sends stage 2
// |
// (unblocked)
// dummy
// panics

#[test]
fn execute_cancellation() {
let mut db = Knobs::default();

let counts = (1..=10).collect::<Vec<u32>>();
let input = ParallelInput::new(&db, counts);

let thread_a = std::thread::spawn({
let db = db.clone();
move || a1(&db, input)
});

let counts = (2..=20).collect::<Vec<u32>>();

db.signal_on_did_cancel.store(2);
input.set_field(&mut db).to(counts);

// Assert that thread A was cancelled
let cancelled = thread_a
.join()
.unwrap_err()
.downcast::<Cancelled>()
.unwrap();

// and inspect the output
expect_test::expect![[r#"
PendingWrite
"#]]
.assert_debug_eq(&cancelled);
}
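
The two-thread handshake described in the diagram above relies on `signal` and `wait_for` helpers from the test suite's `Knobs` setup, whose implementation is not shown in this diff. As a rough illustration of how such a stage counter could work, here is a std-only sketch built on `Mutex` and `Condvar`. The `Signal` type and `run_handshake` function are hypothetical, and the sketch leaves out cancellation entirely; it only shows the ordering guarantee.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

/// Hypothetical stage counter in the spirit of the test's `signal`/`wait_for`.
#[derive(Default)]
struct Signal {
    stage: Mutex<usize>,
    changed: Condvar,
}

impl Signal {
    /// Advance to `stage` (never moving backwards) and wake all waiters.
    fn signal(&self, stage: usize) {
        let mut current = self.stage.lock().unwrap();
        if stage > *current {
            *current = stage;
            self.changed.notify_all();
        }
    }

    /// Block until the counter has reached at least `stage`.
    fn wait_for(&self, stage: usize) {
        let mut current = self.stage.lock().unwrap();
        while *current < stage {
            current = self.changed.wait(current).unwrap();
        }
    }
}

/// Thread A signals stage 1 and blocks on stage 2; the caller (thread B)
/// waits for stage 1, records an event, and then releases A with stage 2.
fn run_handshake() -> Vec<&'static str> {
    let signal = Arc::new(Signal::default());
    let events = Arc::new(Mutex::new(Vec::new()));

    let thread_a = thread::spawn({
        let (signal, events) = (Arc::clone(&signal), Arc::clone(&events));
        move || {
            signal.signal(1);
            signal.wait_for(2);
            events.lock().unwrap().push("a: unblocked");
        }
    });

    signal.wait_for(1);
    events.lock().unwrap().push("b: saw stage 1");
    signal.signal(2);
    thread_a.join().unwrap();

    let result = events.lock().unwrap().clone();
    result
}
```

Because thread B records its event before signalling stage 2, and thread A can only proceed after stage 2, `run_handshake()` always returns `["b: saw stage 1", "a: unblocked"]`. In the actual test, the stage 2 signal is sent as a side effect of cancellation (`signal_on_did_cancel`) rather than by an explicit call.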
2 changes: 1 addition & 1 deletion tests/tracked_struct_durability.rs
@@ -83,7 +83,7 @@ fn check<'db>(db: &'db dyn Db, file: File) -> Inference<'db> {
#[test]
fn execute() {
#[salsa::db]
-#[derive(Default)]
+#[derive(Default, Clone)]
struct Database {
storage: salsa::Storage<Self>,
files: Vec<File>,