Refactor accumulators #575

Merged (7 commits) on Sep 22, 2024
118 changes: 31 additions & 87 deletions src/accumulator.rs
@@ -1,22 +1,30 @@
//! Basic test of accumulator functionality.

use std::{
any::Any,
fmt::{self, Debug},
marker::PhantomData,
};

use accumulated::Accumulated;
use accumulated::AnyAccumulated;
use accumulated_map::AccumulatedMap;

use crate::{
cycle::CycleRecoveryStrategy,
hash::FxDashMap,
ingredient::{fmt_index, Ingredient, Jar},
key::DependencyIndex,
plumbing::JarAux,
zalsa::IngredientIndex,
zalsa_local::{QueryOrigin, ZalsaLocal},
Database, DatabaseKeyIndex, Event, EventKind, Id, Revision,
zalsa_local::QueryOrigin,
Database, DatabaseKeyIndex, Id, Revision,
};

pub trait Accumulator: Clone + Debug + Send + Sync + 'static + Sized {
mod accumulated;
pub(crate) mod accumulated_map;

/// Trait implemented on the struct that user annotated with `#[salsa::accumulator]`.
/// The `Self` type is therefore the types to be accumulated.
pub trait Accumulator: Clone + Debug + Send + Sync + Any + Sized {
const DEBUG_NAME: &'static str;

/// Accumulate an instance of this in the database for later retrieval.
@@ -49,12 +57,7 @@ impl<A: Accumulator> Jar for JarImpl<A> {

pub struct IngredientImpl<A: Accumulator> {
index: IngredientIndex,
map: FxDashMap<DatabaseKeyIndex, AccumulatedValues<A>>,
}

struct AccumulatedValues<A> {
produced_at: Revision,
values: Vec<A>,
phantom: PhantomData<Accumulated<A>>,
}

impl<A: Accumulator> IngredientImpl<A> {
@@ -72,67 +75,20 @@ impl<A: Accumulator> IngredientImpl<A> {

pub fn new(index: IngredientIndex) -> Self {
Self {
map: FxDashMap::default(),
index,
phantom: PhantomData,
}
}

fn dependency_index(&self) -> DependencyIndex {
DependencyIndex {
ingredient_index: self.index,
key_index: None,
}
}

pub fn push(&self, db: &dyn crate::Database, value: A) {
let state = db.zalsa_local();
let current_revision = db.zalsa().current_revision();
let (active_query, _) = match state.active_query() {
Some(pair) => pair,
None => {
panic!("cannot accumulate values outside of an active query")
}
};

let mut accumulated_values = self.map.entry(active_query).or_insert(AccumulatedValues {
values: vec![],
produced_at: current_revision,
});

// When we call `push' in a query, we will add the accumulator to the output of the query.
// If we find here that this accumulator is not the output of the query,
// we can say that the accumulated values we stored for this query is out of date.
if !state.is_output_of_active_query(self.dependency_index()) {
accumulated_values.values.truncate(0);
accumulated_values.produced_at = current_revision;
pub fn push(&self, db: &dyn Database, value: A) {
let zalsa_local = db.zalsa_local();
if let Err(()) = zalsa_local.accumulate(self.index, value) {
panic!("cannot accumulate values outside of an active tracked function");
}

state.add_output(self.dependency_index());
accumulated_values.values.push(value);
}

pub(crate) fn produced_by(
&self,
current_revision: Revision,
local_state: &ZalsaLocal,
query: DatabaseKeyIndex,
output: &mut Vec<A>,
) {
if let Some(v) = self.map.get(&query) {
// FIXME: We don't currently have a good way to identify the value that was read.
// You can't report is as a tracked read of `query`, because the return value of query is not being read here --
// instead it is the set of values accumuated by `query`.
local_state.report_untracked_read(current_revision);

let AccumulatedValues {
values,
produced_at,
} = v.value();

if *produced_at == current_revision {
output.extend(values.iter().cloned());
}
}
pub fn index(&self) -> IngredientIndex {
self.index
}
}

@@ -160,34 +116,18 @@ impl<A: Accumulator> Ingredient for IngredientImpl<A> {

fn mark_validated_output(
&self,
db: &dyn Database,
executor: DatabaseKeyIndex,
output_key: Option<crate::Id>,
_db: &dyn Database,
_executor: DatabaseKeyIndex,
_output_key: Option<crate::Id>,
) {
assert!(output_key.is_none());
let current_revision = db.zalsa().current_revision();
if let Some(mut v) = self.map.get_mut(&executor) {
// The value is still valid in the new revision.
v.produced_at = current_revision;
}
}

fn remove_stale_output(
&self,
db: &dyn Database,
executor: DatabaseKeyIndex,
stale_output_key: Option<crate::Id>,
_db: &dyn Database,
_executor: DatabaseKeyIndex,
_stale_output_key: Option<crate::Id>,
) {
assert!(stale_output_key.is_none());
if self.map.remove(&executor).is_some() {
db.salsa_event(&|| Event {
thread_id: std::thread::current().id(),
kind: EventKind::DidDiscardAccumulated {
executor_key: executor,
accumulator: self.dependency_index(),
},
})
}
}

fn requires_reset_for_new_revision(&self) -> bool {
@@ -205,6 +145,10 @@ impl<A: Accumulator> Ingredient for IngredientImpl<A> {
fn debug_name(&self) -> &'static str {
A::DEBUG_NAME
}

fn accumulated(&self, _db: &dyn Database, _key_index: Id) -> Option<&AccumulatedMap> {
None
}
}

impl<A> std::fmt::Debug for IngredientImpl<A>
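The net change in this file: `IngredientImpl` no longer owns a `FxDashMap` keyed by `DatabaseKeyIndex` with its own revision bookkeeping; `push` just forwards the value to `ZalsaLocal::accumulate` and panics when no tracked function is executing. A minimal standalone sketch of that delegation pattern, where a plain `thread_local!`, `usize`, and `String` stand in for salsa's local query state, `IngredientIndex`, and the accumulated type (none of these names are salsa's real API):

```rust
use std::cell::RefCell;
use std::collections::HashMap;

// Stand-in for the active-query state that salsa's local state tracks:
// `usize` plays the role of an ingredient index, `String` the accumulated type.
thread_local! {
    static ACTIVE_QUERY: RefCell<Option<HashMap<usize, Vec<String>>>> =
        RefCell::new(None);
}

// Rough analogue of the new `ZalsaLocal::accumulate`: Err(()) when no query
// is currently executing.
fn accumulate(index: usize, value: String) -> Result<(), ()> {
    ACTIVE_QUERY.with(|q| match q.borrow_mut().as_mut() {
        Some(map) => {
            map.entry(index).or_default().push(value);
            Ok(())
        }
        None => Err(()),
    })
}

// Rough analogue of the slimmed-down `IngredientImpl::push`.
fn push(index: usize, value: String) {
    if accumulate(index, value).is_err() {
        panic!("cannot accumulate values outside of an active tracked function");
    }
}

fn main() {
    // Pretend a query is active, then accumulate into it.
    ACTIVE_QUERY.with(|q| *q.borrow_mut() = Some(HashMap::new()));
    push(0, "example diagnostic".to_string());
}
```

Keeping `accumulate` fallible and confining the panic (and its message) to the thin `push` wrapper matches the shape of the new code above.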
60 changes: 60 additions & 0 deletions src/accumulator/accumulated.rs
@@ -0,0 +1,60 @@
use std::any::Any;
use std::fmt::Debug;

use super::Accumulator;

#[derive(Clone, Debug)]
pub(crate) struct Accumulated<A: Accumulator> {
values: Vec<A>,
}

pub(crate) trait AnyAccumulated: Any + Debug + Send + Sync {
fn as_dyn_any(&self) -> &dyn Any;
fn as_dyn_any_mut(&mut self) -> &mut dyn Any;
fn cloned(&self) -> Box<dyn AnyAccumulated>;
}

impl<A: Accumulator> Accumulated<A> {
pub fn push(&mut self, value: A) {
self.values.push(value);
}

pub fn extend_with_accumulated(&self, values: &mut Vec<A>) {
values.extend_from_slice(&self.values);
}
}

impl<A: Accumulator> Default for Accumulated<A> {
fn default() -> Self {
Self {
values: Default::default(),
}
}
}

impl<A> AnyAccumulated for Accumulated<A>
where
A: Accumulator,
{
fn as_dyn_any(&self) -> &dyn Any {
self
}

fn as_dyn_any_mut(&mut self) -> &mut dyn Any {
self
}

fn cloned(&self) -> Box<dyn AnyAccumulated> {
let this: Self = self.clone();
Box::new(this)
}
}

impl dyn AnyAccumulated {
pub fn accumulate<A: Accumulator>(&mut self, value: A) {
self.as_dyn_any_mut()
.downcast_mut::<Accumulated<A>>()
.unwrap()
.push(value);
}
}
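`Clone` is not object-safe, so `Box<dyn AnyAccumulated>` cannot be cloned directly; the explicit `cloned()` method fills that gap, and the manual `Clone` impl for `AccumulatedMap` in the next file builds on it. The same trick in isolation, with illustrative names that are not salsa's:

```rust
// Object-safe substitute for `Clone`, mirroring `AnyAccumulated::cloned`.
trait AnyValue: std::fmt::Debug {
    fn cloned(&self) -> Box<dyn AnyValue>;
}

// Typed storage behind the erased trait, analogous to `Accumulated<A>`.
#[derive(Clone, Debug)]
struct Typed<T>(Vec<T>);

impl<T: Clone + std::fmt::Debug + 'static> AnyValue for Typed<T> {
    fn cloned(&self) -> Box<dyn AnyValue> {
        Box::new(self.clone())
    }
}

// A container of type-erased values that is itself clonable, the way
// `AccumulatedMap` clones its boxed `dyn AnyAccumulated` entries.
#[derive(Debug)]
struct Bag(Vec<Box<dyn AnyValue>>);

impl Clone for Bag {
    fn clone(&self) -> Self {
        Bag(self.0.iter().map(|entry| entry.cloned()).collect())
    }
}

fn main() {
    let values: Vec<Box<dyn AnyValue>> = vec![
        Box::new(Typed(vec![1_u32, 2, 3])),
        Box::new(Typed(vec!["two".to_string()])),
    ];
    let copy = Bag(values).clone();
    println!("{copy:?}");
}
```

Crates such as `dyn-clone` package the same pattern; a single hand-written method is enough here.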
46 changes: 46 additions & 0 deletions src/accumulator/accumulated_map.rs
@@ -0,0 +1,46 @@
use rustc_hash::FxHashMap;

use crate::IngredientIndex;

use super::{accumulated::Accumulated, Accumulator, AnyAccumulated};

#[derive(Default, Debug)]
pub struct AccumulatedMap {
map: FxHashMap<IngredientIndex, Box<dyn AnyAccumulated>>,
}

impl AccumulatedMap {
pub fn accumulate<A: Accumulator>(&mut self, index: IngredientIndex, value: A) {
self.map
.entry(index)
.or_insert_with(|| <Box<Accumulated<A>>>::default())
.accumulate(value);
}

pub fn extend_with_accumulated<A: Accumulator>(
&self,
index: IngredientIndex,
output: &mut Vec<A>,
) {
let Some(a) = self.map.get(&index) else {
return;
};

a.as_dyn_any()
.downcast_ref::<Accumulated<A>>()
.unwrap()
.extend_with_accumulated(output);
}
}

impl Clone for AccumulatedMap {
fn clone(&self) -> Self {
Self {
map: self
.map
.iter()
.map(|(&key, value)| (key, value.cloned()))
.collect(),
}
}
}
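Combined with the previous file, `AccumulatedMap` is a heterogeneous map: each `IngredientIndex` owns a bucket with its own element type, erased on insert and recovered by downcast on read. A self-contained sketch of that shape using only std types, with `HashMap`, `u32`, and `dyn Any` standing in for `FxHashMap`, `IngredientIndex`, and `dyn AnyAccumulated` (illustrative, not salsa's code):

```rust
use std::any::Any;
use std::collections::HashMap;

// Heterogeneous map: one typed Vec per key, erased behind `dyn Any`.
#[derive(Default)]
struct ErasedMap {
    map: HashMap<u32, Box<dyn Any>>,
}

impl ErasedMap {
    fn accumulate<T: 'static>(&mut self, index: u32, value: T) {
        self.map
            .entry(index)
            // Lazily create an empty, correctly typed bucket for this key.
            .or_insert_with(|| Box::<Vec<T>>::default())
            .downcast_mut::<Vec<T>>()
            .expect("index reused with a different element type")
            .push(value);
    }

    fn extend_with_accumulated<T: Clone + 'static>(&self, index: u32, output: &mut Vec<T>) {
        let Some(bucket) = self.map.get(&index) else {
            return;
        };
        output.extend_from_slice(
            bucket
                .downcast_ref::<Vec<T>>()
                .expect("index reused with a different element type"),
        );
    }
}

fn main() {
    let mut m = ErasedMap::default();
    m.accumulate(0, "warning: unused".to_string());
    m.accumulate(1, 42_u32);

    let mut diagnostics: Vec<String> = Vec::new();
    m.extend_with_accumulated(0, &mut diagnostics);
    assert_eq!(diagnostics.len(), 1);
}
```

The PR's `AnyAccumulated` additionally requires `Debug + Send + Sync` and carries the `cloned()` hook from the previous file, which plain `dyn Any` does not.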
11 changes: 9 additions & 2 deletions src/active_query.rs
@@ -1,8 +1,9 @@
use rustc_hash::FxHashMap;

use crate::{
accumulator::accumulated_map::AccumulatedMap,
durability::Durability,
hash::{FxIndexMap, FxIndexSet},
hash::FxIndexSet,
key::{DatabaseKeyIndex, DependencyIndex},
tracked_struct::{Disambiguator, KeyStruct},
zalsa_local::EMPTY_DEPENDENCIES,
@@ -44,11 +45,15 @@ pub(crate) struct ActiveQuery {
/// This table starts empty as the query begins and is gradually populated.
/// Note that if a query executes in 2 different revisions but creates the same
/// set of tracked structs, they will get the same disambiguator values.
disambiguator_map: FxIndexMap<u64, Disambiguator>,
disambiguator_map: FxHashMap<u64, Disambiguator>,

/// Map from tracked struct keys (which include the hash + disambiguator) to their
/// final id.
pub(crate) tracked_struct_ids: FxHashMap<KeyStruct, Id>,

/// Stores the values accumulated to the given ingredient.
/// The type of accumulated value is erased but known to the ingredient.
pub(crate) accumulated: AccumulatedMap,
}

impl ActiveQuery {
@@ -62,6 +67,7 @@ impl ActiveQuery {
cycle: None,
disambiguator_map: Default::default(),
tracked_struct_ids: Default::default(),
accumulated: Default::default(),
}
}

@@ -118,6 +124,7 @@ impl ActiveQuery {
origin,
durability: self.durability,
tracked_struct_ids: self.tracked_struct_ids,
accumulated: self.accumulated,
}
}

25 changes: 17 additions & 8 deletions src/function.rs
@@ -1,6 +1,7 @@
use std::{any::Any, fmt, sync::Arc};

use crate::{
accumulator::accumulated_map::AccumulatedMap,
cycle::CycleRecoveryStrategy,
ingredient::fmt_index,
key::DatabaseKeyIndex,
@@ -152,22 +153,21 @@
/// when this function is called and (b) ensuring that any entries
/// removed from the memo-map are added to `deleted_entries`, which is
/// only cleared with `&mut self`.
unsafe fn extend_memo_lifetime<'this, 'memo>(
unsafe fn extend_memo_lifetime<'this>(
&'this self,
memo: &'memo memo::Memo<C::Output<'this>>,
) -> Option<&'this C::Output<'this>> {
let memo_value: Option<&'memo C::Output<'this>> = memo.value.as_ref();
std::mem::transmute(memo_value)
memo: &memo::Memo<C::Output<'this>>,
) -> &'this memo::Memo<C::Output<'this>> {
std::mem::transmute(memo)
}

fn insert_memo<'db>(
&'db self,
zalsa: &'db Zalsa,
id: Id,
memo: memo::Memo<C::Output<'db>>,
) -> Option<&C::Output<'db>> {
) -> &'db memo::Memo<C::Output<'db>> {
let memo = Arc::new(memo);
let value = unsafe {
let db_memo = unsafe {
// Unsafety conditions: memo must be in the map (it's not yet, but it will be by the time this
// value is returned) and anything removed from map is added to deleted entries (ensured elsewhere).
self.extend_memo_lifetime(&memo)
@@ -177,7 +177,7 @@
// in the deleted entries. This will get cleared when a new revision starts.
self.deleted_entries.push(old_value);
}
value
db_memo
}
}

@@ -244,6 +244,15 @@
fn debug_name(&self) -> &'static str {
C::DEBUG_NAME
}

fn accumulated<'db>(
&'db self,
db: &'db dyn Database,
key_index: Id,
) -> Option<&'db AccumulatedMap> {
let db = db.as_view::<C::DbView>();
self.accumulated_map(db, key_index)
}
}

impl<C> std::fmt::Debug for IngredientImpl<C>
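`extend_memo_lifetime` now returns the whole `Memo` rather than just its value, and `insert_memo` passes that `&'db Memo` back to callers. The safety argument is the one stated in the comment: the memo stays owned by the map (or by `deleted_entries` once displaced) for as long as `&'db self` is borrowed, so its address outlives the returned reference. A minimal standalone illustration of that append-only pattern, using a hypothetical `AppendOnly` type rather than salsa's memo table:

```rust
use std::cell::RefCell;

// Append-only store: values are boxed so they keep a stable heap address, are
// never removed while `self` is alive, and so a reference taken at insertion
// time can soundly be handed back with the lifetime of `&self`.
struct AppendOnly<T> {
    slots: RefCell<Vec<Box<T>>>,
}

impl<T> AppendOnly<T> {
    fn new() -> Self {
        Self { slots: RefCell::new(Vec::new()) }
    }

    fn push(&self, value: T) -> &T {
        let boxed = Box::new(value);
        // Take the address of the heap allocation before moving the Box into
        // the Vec; pushing (and any later Vec reallocation) moves the Box
        // pointer, not the pointee.
        let ptr: *const T = &*boxed;
        self.slots.borrow_mut().push(boxed);
        // SAFETY: the pointee is never moved or dropped before `self` is, and
        // the returned borrow cannot outlive `self`.
        unsafe { &*ptr }
    }
}

fn main() {
    let store = AppendOnly::new();
    let first = store.push(String::from("memo 1"));
    let second = store.push(String::from("memo 2"));
    assert_eq!(first, "memo 1");
    assert_eq!(second, "memo 2");
}
```

The `Box` indirection is what gives each value a stable address across `Vec` reallocations; salsa gets the same stability from wrapping each memo in an `Arc`, as the diff above shows.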