diff --git a/src/accumulator.rs b/src/accumulator.rs index 01df30fd..b9355419 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -1,22 +1,30 @@ //! Basic test of accumulator functionality. use std::{ + any::Any, fmt::{self, Debug}, marker::PhantomData, }; +use accumulated::Accumulated; +use accumulated::AnyAccumulated; +use accumulated_map::AccumulatedMap; + use crate::{ cycle::CycleRecoveryStrategy, - hash::FxDashMap, ingredient::{fmt_index, Ingredient, Jar}, - key::DependencyIndex, plumbing::JarAux, zalsa::IngredientIndex, - zalsa_local::{QueryOrigin, ZalsaLocal}, - Database, DatabaseKeyIndex, Event, EventKind, Id, Revision, + zalsa_local::QueryOrigin, + Database, DatabaseKeyIndex, Id, Revision, }; -pub trait Accumulator: Clone + Debug + Send + Sync + 'static + Sized { +mod accumulated; +pub(crate) mod accumulated_map; + +/// Trait implemented on the struct that user annotated with `#[salsa::accumulator]`. +/// The `Self` type is therefore the types to be accumulated. +pub trait Accumulator: Clone + Debug + Send + Sync + Any + Sized { const DEBUG_NAME: &'static str; /// Accumulate an instance of this in the database for later retrieval. 
@@ -49,12 +57,7 @@ impl Jar for JarImpl { pub struct IngredientImpl { index: IngredientIndex, - map: FxDashMap>, -} - -struct AccumulatedValues { - produced_at: Revision, - values: Vec, + phantom: PhantomData>, } impl IngredientImpl { @@ -72,67 +75,20 @@ impl IngredientImpl { pub fn new(index: IngredientIndex) -> Self { Self { - map: FxDashMap::default(), index, + phantom: PhantomData, } } - fn dependency_index(&self) -> DependencyIndex { - DependencyIndex { - ingredient_index: self.index, - key_index: None, - } - } - - pub fn push(&self, db: &dyn crate::Database, value: A) { - let state = db.zalsa_local(); - let current_revision = db.zalsa().current_revision(); - let (active_query, _) = match state.active_query() { - Some(pair) => pair, - None => { - panic!("cannot accumulate values outside of an active query") - } - }; - - let mut accumulated_values = self.map.entry(active_query).or_insert(AccumulatedValues { - values: vec![], - produced_at: current_revision, - }); - - // When we call `push' in a query, we will add the accumulator to the output of the query. - // If we find here that this accumulator is not the output of the query, - // we can say that the accumulated values we stored for this query is out of date. - if !state.is_output_of_active_query(self.dependency_index()) { - accumulated_values.values.truncate(0); - accumulated_values.produced_at = current_revision; + pub fn push(&self, db: &dyn Database, value: A) { + let zalsa_local = db.zalsa_local(); + if let Err(()) = zalsa_local.accumulate(self.index, value) { + panic!("cannot accumulate values outside of an active tracked function"); } - - state.add_output(self.dependency_index()); - accumulated_values.values.push(value); } - pub(crate) fn produced_by( - &self, - current_revision: Revision, - local_state: &ZalsaLocal, - query: DatabaseKeyIndex, - output: &mut Vec, - ) { - if let Some(v) = self.map.get(&query) { - // FIXME: We don't currently have a good way to identify the value that was read. 
- // You can't report is as a tracked read of `query`, because the return value of query is not being read here -- - // instead it is the set of values accumuated by `query`. - local_state.report_untracked_read(current_revision); - - let AccumulatedValues { - values, - produced_at, - } = v.value(); - - if *produced_at == current_revision { - output.extend(values.iter().cloned()); - } - } + pub fn index(&self) -> IngredientIndex { + self.index } } @@ -160,34 +116,18 @@ impl Ingredient for IngredientImpl { fn mark_validated_output( &self, - db: &dyn Database, - executor: DatabaseKeyIndex, - output_key: Option, + _db: &dyn Database, + _executor: DatabaseKeyIndex, + _output_key: Option, ) { - assert!(output_key.is_none()); - let current_revision = db.zalsa().current_revision(); - if let Some(mut v) = self.map.get_mut(&executor) { - // The value is still valid in the new revision. - v.produced_at = current_revision; - } } fn remove_stale_output( &self, - db: &dyn Database, - executor: DatabaseKeyIndex, - stale_output_key: Option, + _db: &dyn Database, + _executor: DatabaseKeyIndex, + _stale_output_key: Option, ) { - assert!(stale_output_key.is_none()); - if self.map.remove(&executor).is_some() { - db.salsa_event(&|| Event { - thread_id: std::thread::current().id(), - kind: EventKind::DidDiscardAccumulated { - executor_key: executor, - accumulator: self.dependency_index(), - }, - }) - } } fn requires_reset_for_new_revision(&self) -> bool { @@ -205,6 +145,10 @@ impl Ingredient for IngredientImpl { fn debug_name(&self) -> &'static str { A::DEBUG_NAME } + + fn accumulated(&self, _db: &dyn Database, _key_index: Id) -> Option<&AccumulatedMap> { + None + } } impl std::fmt::Debug for IngredientImpl diff --git a/src/accumulator/accumulated.rs b/src/accumulator/accumulated.rs new file mode 100644 index 00000000..365b2f46 --- /dev/null +++ b/src/accumulator/accumulated.rs @@ -0,0 +1,60 @@ +use std::any::Any; +use std::fmt::Debug; + +use super::Accumulator; + +#[derive(Clone, 
Debug)] +pub(crate) struct Accumulated { + values: Vec, +} + +pub(crate) trait AnyAccumulated: Any + Debug + Send + Sync { + fn as_dyn_any(&self) -> &dyn Any; + fn as_dyn_any_mut(&mut self) -> &mut dyn Any; + fn cloned(&self) -> Box; +} + +impl Accumulated { + pub fn push(&mut self, value: A) { + self.values.push(value); + } + + pub fn extend_with_accumulated(&self, values: &mut Vec) { + values.extend_from_slice(&self.values); + } +} + +impl Default for Accumulated { + fn default() -> Self { + Self { + values: Default::default(), + } + } +} + +impl AnyAccumulated for Accumulated +where + A: Accumulator, +{ + fn as_dyn_any(&self) -> &dyn Any { + self + } + + fn as_dyn_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn cloned(&self) -> Box { + let this: Self = self.clone(); + Box::new(this) + } +} + +impl dyn AnyAccumulated { + pub fn accumulate(&mut self, value: A) { + self.as_dyn_any_mut() + .downcast_mut::>() + .unwrap() + .push(value); + } +} diff --git a/src/accumulator/accumulated_map.rs b/src/accumulator/accumulated_map.rs new file mode 100644 index 00000000..bda989a7 --- /dev/null +++ b/src/accumulator/accumulated_map.rs @@ -0,0 +1,46 @@ +use rustc_hash::FxHashMap; + +use crate::IngredientIndex; + +use super::{accumulated::Accumulated, Accumulator, AnyAccumulated}; + +#[derive(Default, Debug)] +pub struct AccumulatedMap { + map: FxHashMap>, +} + +impl AccumulatedMap { + pub fn accumulate(&mut self, index: IngredientIndex, value: A) { + self.map + .entry(index) + .or_insert_with(|| >>::default()) + .accumulate(value); + } + + pub fn extend_with_accumulated( + &self, + index: IngredientIndex, + output: &mut Vec, + ) { + let Some(a) = self.map.get(&index) else { + return; + }; + + a.as_dyn_any() + .downcast_ref::>() + .unwrap() + .extend_with_accumulated(output); + } +} + +impl Clone for AccumulatedMap { + fn clone(&self) -> Self { + Self { + map: self + .map + .iter() + .map(|(&key, value)| (key, value.cloned())) + .collect(), + } + } +} diff --git 
a/src/active_query.rs b/src/active_query.rs index 87b834ba..45508ac9 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -1,8 +1,9 @@ use rustc_hash::FxHashMap; use crate::{ + accumulator::accumulated_map::AccumulatedMap, durability::Durability, - hash::{FxIndexMap, FxIndexSet}, + hash::FxIndexSet, key::{DatabaseKeyIndex, DependencyIndex}, tracked_struct::{Disambiguator, KeyStruct}, zalsa_local::EMPTY_DEPENDENCIES, @@ -44,11 +45,15 @@ pub(crate) struct ActiveQuery { /// This table starts empty as the query begins and is gradually populated. /// Note that if a query executes in 2 different revisions but creates the same /// set of tracked structs, they will get the same disambiguator values. - disambiguator_map: FxIndexMap, + disambiguator_map: FxHashMap, /// Map from tracked struct keys (which include the hash + disambiguator) to their /// final id. pub(crate) tracked_struct_ids: FxHashMap, + + /// Stores the values accumulated to the given ingredient. + /// The type of accumulated value is erased but known to the ingredient. + pub(crate) accumulated: AccumulatedMap, } impl ActiveQuery { @@ -62,6 +67,7 @@ impl ActiveQuery { cycle: None, disambiguator_map: Default::default(), tracked_struct_ids: Default::default(), + accumulated: Default::default(), } } @@ -118,6 +124,7 @@ impl ActiveQuery { origin, durability: self.durability, tracked_struct_ids: self.tracked_struct_ids, + accumulated: self.accumulated, } } diff --git a/src/function.rs b/src/function.rs index 24cbb130..07f13d49 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,6 +1,7 @@ use std::{any::Any, fmt, sync::Arc}; use crate::{ + accumulator::accumulated_map::AccumulatedMap, cycle::CycleRecoveryStrategy, ingredient::fmt_index, key::DatabaseKeyIndex, @@ -152,12 +153,11 @@ where /// when this function is called and (b) ensuring that any entries /// removed from the memo-map are added to `deleted_entries`, which is /// only cleared with `&mut self`. 
- unsafe fn extend_memo_lifetime<'this, 'memo>( + unsafe fn extend_memo_lifetime<'this>( &'this self, - memo: &'memo memo::Memo>, - ) -> Option<&'this C::Output<'this>> { - let memo_value: Option<&'memo C::Output<'this>> = memo.value.as_ref(); - std::mem::transmute(memo_value) + memo: &memo::Memo>, + ) -> &'this memo::Memo> { + std::mem::transmute(memo) } fn insert_memo<'db>( @@ -165,9 +165,9 @@ where zalsa: &'db Zalsa, id: Id, memo: memo::Memo>, - ) -> Option<&C::Output<'db>> { + ) -> &'db memo::Memo> { let memo = Arc::new(memo); - let value = unsafe { + let db_memo = unsafe { // Unsafety conditions: memo must be in the map (it's not yet, but it will be by the time this // value is returned) and anything removed from map is added to deleted entries (ensured elsewhere). self.extend_memo_lifetime(&memo) @@ -177,7 +177,7 @@ where // in the deleted entries. This will get cleared when a new revision starts. self.deleted_entries.push(old_value); } - value + db_memo } } @@ -244,6 +244,15 @@ where fn debug_name(&self) -> &'static str { C::DEBUG_NAME } + + fn accumulated<'db>( + &'db self, + db: &'db dyn Database, + key_index: Id, + ) -> Option<&'db AccumulatedMap> { + let db = db.as_view::(); + self.accumulated_map(db, key_index) + } } impl std::fmt::Debug for IngredientImpl diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index 5f9ccc2a..21e017a5 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -1,5 +1,8 @@ use crate::{ - accumulator, hash::FxHashSet, zalsa::ZalsaDatabase, AsDynDatabase, DatabaseKeyIndex, Id, + accumulator::{self, accumulated_map::AccumulatedMap}, + hash::FxHashSet, + zalsa::ZalsaDatabase, + AsDynDatabase, DatabaseKeyIndex, Id, }; use super::{Configuration, IngredientImpl}; @@ -14,9 +17,20 @@ where where A: accumulator::Accumulator, { - let zalsa = db.zalsa(); - let zalsa_local = db.zalsa_local(); - let current_revision = zalsa.current_revision(); + let (zalsa, zalsa_local) = db.zalsas(); + + // NOTE: We 
don't have a precise way to track accumulated values at present, + // so we report any read of them as an untracked read. + // + // Like tracked struct fields, accumulated values are essentially a "side channel output" + // from a tracked function, hence we can't report this as a read of the tracked function(s) + // whose accumulated values we are probing, since the accumulated values may have changed + // even when the main return value of the function has not changed. + // + // Unlike tracked struct fields, we don't have a distinct id or ingredient to represent + // "the values of type A accumulated by tracked function X". Typically accumulated values + // are read from outside of salsa anyway so this is not a big deal. + zalsa_local.report_untracked_read(zalsa.current_revision()); let Some(accumulator) = >::from_db(db) else { return vec![]; @@ -31,25 +45,46 @@ let mut visited: FxHashSet = FxHashSet::default(); let mut stack: Vec = vec![db_key]; + // Do a depth-first search across the dependencies of `key`, reading the values accumulated by + // each dependency. while let Some(k) = stack.pop() { - if visited.insert(k) { - accumulator.produced_by(current_revision, zalsa_local, k, &mut output); - - let origin = zalsa - .lookup_ingredient(k.ingredient_index) - .origin(db, k.key_index); - let inputs = origin.iter().flat_map(|origin| origin.inputs()); - // Careful: we want to push in execution order, so reverse order to - // ensure the first child that was executed will be the first child popped - // from the stack. - stack.extend( - inputs - .flat_map(|input| TryInto::::try_into(input).into_iter()) - .rev(), - ); + // Already visited `k`? + if !visited.insert(k) { + continue; + } + + // Extend `output` with any values accumulated by `k`. + if let Some(accumulated_map) = k.accumulated(db) { + accumulated_map.extend_with_accumulated(accumulator.index(), &mut output); } + + // Find the inputs of `k` and push them onto the stack. 
+ // + // Careful: to ensure the user gets a consistent ordering in their + // output vector, we want to push in execution order, so reverse order to + // ensure the first child that was executed will be the first child popped + // from the stack. + let origin = zalsa + .lookup_ingredient(k.ingredient_index) + .origin(db, k.key_index); + let inputs = origin.iter().flat_map(|origin| origin.inputs()); + stack.extend( + inputs + .flat_map(|input| TryInto::::try_into(input).into_iter()) + .rev(), + ); } output } + + pub(super) fn accumulated_map<'db>( + &'db self, + db: &'db C::DbView, + key: Id, + ) -> Option<&'db AccumulatedMap> { + // NEXT STEP: stash and refactor `fetch` to return an `&Memo` so we can make this work + let memo = self.refresh_memo(db, key); + Some(&memo.revisions.accumulated) + } } diff --git a/src/function/execute.rs b/src/function/execute.rs index 0f1876b6..8d8beca0 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,8 +1,7 @@ use std::sync::Arc; use crate::{ - runtime::StampedValue, zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle, Database, - Event, EventKind, + zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle, Database, Event, EventKind, }; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -25,7 +24,7 @@ where db: &'db C::DbView, active_query: ActiveQueryGuard<'_>, opt_old_memo: Option>>>, - ) -> StampedValue<&C::Output<'db>> { + ) -> &'db Memo> { let zalsa = db.zalsa(); let revision_now = zalsa.current_revision(); let database_key_index = active_query.database_key_index; @@ -86,11 +85,6 @@ where tracing::debug!("{database_key_index:?}: read_upgrade: result.revisions = {revisions:#?}"); - let stamp_template = revisions.stamp_template(); - let value = self - .insert_memo(zalsa, id, Memo::new(Some(value), revision_now, revisions)) - .unwrap(); - - stamp_template.stamp(value) + self.insert_memo(zalsa, id, Memo::new(Some(value), revision_now, revisions)) } } diff --git a/src/function/fetch.rs 
b/src/function/fetch.rs index ff77b95c..e16ac3f0 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,6 +1,6 @@ use crate::{runtime::StampedValue, zalsa::ZalsaDatabase, AsDynDatabase as _, Id}; -use super::{Configuration, IngredientImpl}; +use super::{memo::Memo, Configuration, IngredientImpl}; impl IngredientImpl where @@ -10,11 +10,12 @@ where let (zalsa, zalsa_local) = db.zalsas(); zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database()); + let memo = self.refresh_memo(db, id); let StampedValue { value, durability, changed_at, - } = self.compute_value(db, id); + } = memo.revisions.stamped_value(memo.value.as_ref().unwrap()); if let Some(evicted) = self.lru.record_use(id) { self.evict_value_from_memo_for(zalsa, evicted); @@ -26,45 +27,36 @@ where } #[inline] - fn compute_value<'db>( + pub(super) fn refresh_memo<'db>( &'db self, db: &'db C::DbView, id: Id, - ) -> StampedValue<&'db C::Output<'db>> { + ) -> &'db Memo> { loop { - if let Some(value) = self.fetch_hot(db, id).or_else(|| self.fetch_cold(db, id)) { - return value; + if let Some(memo) = self.fetch_hot(db, id).or_else(|| self.fetch_cold(db, id)) { + return memo; } } } #[inline] - fn fetch_hot<'db>( - &'db self, - db: &'db C::DbView, - id: Id, - ) -> Option>> { + fn fetch_hot<'db>(&'db self, db: &'db C::DbView, id: Id) -> Option<&'db Memo>> { let zalsa = db.zalsa(); let memo_guard = self.get_memo_from_table_for(zalsa, id); if let Some(memo) = &memo_guard { if memo.value.is_some() && self.shallow_verify_memo(db, zalsa, self.database_key_index(id), memo) { - let value = unsafe { - // Unsafety invariant: memo is present in memo_map - self.extend_memo_lifetime(memo).unwrap() - }; - return Some(memo.revisions.stamped_value(value)); + // Unsafety invariant: memo is present in memo_map + unsafe { + return Some(self.extend_memo_lifetime(memo)); + } } } None } - fn fetch_cold<'db>( - &'db self, - db: &'db C::DbView, - id: Id, - ) -> Option>> { + fn fetch_cold<'db>(&'db self, db: &'db C::DbView, 
id: Id) -> Option<&'db Memo>> { let (zalsa, zalsa_local) = db.zalsas(); let database_key_index = self.database_key_index(id); @@ -84,11 +76,10 @@ where let opt_old_memo = self.get_memo_from_table_for(zalsa, id); if let Some(old_memo) = &opt_old_memo { if old_memo.value.is_some() && self.deep_verify_memo(db, old_memo, &active_query) { - let value = unsafe { - // Unsafety invariant: memo is present in memo_map. - self.extend_memo_lifetime(old_memo).unwrap() - }; - return Some(old_memo.revisions.stamped_value(value)); + // Unsafety invariant: memo is present in memo_map. + unsafe { + return Some(self.extend_memo_lifetime(old_memo)); + } } } diff --git a/src/function/input_outputs.rs b/src/function/input_outputs.rs new file mode 100644 index 00000000..dab66fdf --- /dev/null +++ b/src/function/input_outputs.rs @@ -0,0 +1,21 @@ +use crate::{ + accumulator::accumulated_map::AccumulatedMap, zalsa::Zalsa, zalsa_local::QueryOrigin, Id, +}; + +use super::{Configuration, IngredientImpl}; + +impl IngredientImpl +where + C: Configuration, +{ + pub(super) fn origin(&self, zalsa: &Zalsa, key: Id) -> Option { + self.get_memo_from_table_for(zalsa, key) + .map(|m| m.revisions.origin.clone()) + } + + pub(super) fn accumulated(&self, zalsa: &Zalsa, key: Id) -> Option<&AccumulatedMap> { + // NEXT STEP: stash and refactor `fetch` to return an `&Memo` so we can make this work + self.get_memo_from_table_for(zalsa, key) + .map(|m| &m.revisions.accumulated) + } +} diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 741a34dd..b1d671a3 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -1,6 +1,5 @@ use crate::{ key::DatabaseKeyIndex, - runtime::StampedValue, zalsa::{Zalsa, ZalsaDatabase}, zalsa_local::{ActiveQueryGuard, EdgeKind, QueryOrigin}, AsDynDatabase as _, Id, Revision, @@ -83,7 +82,8 @@ where // backdated. In that case, although we will have computed a new memo, // the value has not logically changed. 
if old_memo.value.is_some() { - let StampedValue { changed_at, .. } = self.execute(db, active_query, Some(old_memo)); + let memo = self.execute(db, active_query, Some(old_memo)); + let changed_at = memo.revisions.changed_at; return Some(changed_at > revision); } diff --git a/src/function/specify.rs b/src/function/specify.rs index 14b6c3db..73f3d8c3 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -69,6 +69,7 @@ where durability: current_deps.durability, origin: QueryOrigin::Assigned(active_query_key), tracked_struct_ids: Default::default(), + accumulated: Default::default(), }; if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key) { diff --git a/src/hash.rs b/src/hash.rs index 61055b75..d1bb0cf4 100644 --- a/src/hash.rs +++ b/src/hash.rs @@ -2,7 +2,6 @@ use std::hash::{BuildHasher, Hash}; pub(crate) type FxHasher = std::hash::BuildHasherDefault; pub(crate) type FxIndexSet = indexmap::IndexSet; -pub(crate) type FxIndexMap = indexmap::IndexMap; pub(crate) type FxDashMap = dashmap::DashMap; pub(crate) type FxLinkedHashSet = hashlink::LinkedHashSet; pub(crate) type FxHashSet = std::collections::HashSet; diff --git a/src/ingredient.rs b/src/ingredient.rs index 2d606715..383fdc6b 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -4,6 +4,7 @@ use std::{ }; use crate::{ + accumulator::accumulated_map::AccumulatedMap, cycle::CycleRecoveryStrategy, zalsa::{IngredientIndex, MemoIngredientIndex}, zalsa_local::QueryOrigin, @@ -42,6 +43,16 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// What were the inputs (if any) that were used to create the value at `key_index`. fn origin(&self, db: &dyn Database, key_index: Id) -> Option; + /// What values were accumulated during the creation of the value at `key_index` + /// (if any). + /// + /// In practice, returns `Some` only for tracked function ingredients. 
+ fn accumulated<'db>( + &'db self, + db: &'db dyn Database, + key_index: Id, + ) -> Option<&'db AccumulatedMap>; + /// Invoked when the value `output_key` should be marked as valid in the current revision. /// This occurs because the value for `executor`, which generated it, was marked as valid /// in the current revision. diff --git a/src/input.rs b/src/input.rs index ce35e1b9..fdad27ac 100644 --- a/src/input.rs +++ b/src/input.rs @@ -108,16 +108,12 @@ impl IngredientImpl { None }; - let id = zalsa_local.allocate( - zalsa.table(), - self.ingredient_index, - Value:: { - fields, - stamps, - memos: Default::default(), - syncs: Default::default(), - }, - ); + let id = zalsa_local.allocate(zalsa.table(), self.ingredient_index, || Value:: { + fields, + stamps, + memos: Default::default(), + syncs: Default::default(), + }); if C::IS_SINGLETON { self.singleton_index.store(Some(id)); @@ -269,6 +265,14 @@ impl Ingredient for IngredientImpl { fn debug_name(&self) -> &'static str { C::DEBUG_NAME } + + fn accumulated<'db>( + &'db self, + _db: &'db dyn Database, + _key_index: Id, + ) -> Option<&'db crate::accumulator::accumulated_map::AccumulatedMap> { + None + } } impl std::fmt::Debug for IngredientImpl { diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 505aaf4f..fd308225 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -96,6 +96,14 @@ where fn debug_name(&self) -> &'static str { C::FIELD_DEBUG_NAMES[self.field_index] } + + fn accumulated<'db>( + &'db self, + _db: &'db dyn Database, + _key_index: Id, + ) -> Option<&'db crate::accumulator::accumulated_map::AccumulatedMap> { + None + } } impl std::fmt::Debug for FieldIngredientImpl diff --git a/src/interned.rs b/src/interned.rs index 9ff667a3..2003520c 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -156,15 +156,11 @@ where dashmap::mapref::entry::Entry::Vacant(entry) => { let zalsa = db.zalsa(); let table = zalsa.table(); - let next_id = zalsa_local.allocate( - table, - 
self.ingredient_index, - Value:: { - data: internal_data, - memos: Default::default(), - syncs: Default::default(), - }, - ); + let next_id = zalsa_local.allocate(table, self.ingredient_index, || Value:: { + data: internal_data, + memos: Default::default(), + syncs: Default::default(), + }); entry.insert(next_id); C::struct_from_id(next_id) } @@ -259,6 +255,14 @@ where fn debug_name(&self) -> &'static str { C::DEBUG_NAME } + + fn accumulated<'db>( + &'db self, + _db: &'db dyn Database, + _key_index: Id, + ) -> Option<&'db crate::accumulator::accumulated_map::AccumulatedMap> { + None + } } impl std::fmt::Debug for IngredientImpl diff --git a/src/key.rs b/src/key.rs index df49d047..92e63541 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,4 +1,7 @@ -use crate::{cycle::CycleRecoveryStrategy, zalsa::IngredientIndex, Database, Id}; +use crate::{ + accumulator::accumulated_map::AccumulatedMap, cycle::CycleRecoveryStrategy, + zalsa::IngredientIndex, Database, Id, +}; /// An integer that uniquely identifies a particular query instance within the /// database. Used to track dependencies between queries. 
Fully ordered and @@ -93,9 +96,15 @@ impl DatabaseKeyIndex { self.key_index } - pub(crate) fn cycle_recovery_strategy(&self, db: &dyn Database) -> CycleRecoveryStrategy { + pub(crate) fn cycle_recovery_strategy(self, db: &dyn Database) -> CycleRecoveryStrategy { self.ingredient_index.cycle_recovery_strategy(db) } + + pub(crate) fn accumulated(self, db: &dyn Database) -> Option<&AccumulatedMap> { + db.zalsa() + .lookup_ingredient(self.ingredient_index) + .accumulated(db, self.key_index) + } } impl std::fmt::Debug for DatabaseKeyIndex { diff --git a/src/table.rs b/src/table.rs index 199d2c2a..56a37696 100644 --- a/src/table.rs +++ b/src/table.rs @@ -214,7 +214,10 @@ impl Page { self.data[slot.0].get() } - pub(crate) fn allocate(&self, page: PageIndex, value: T) -> Result { + pub(crate) fn allocate(&self, page: PageIndex, value: V) -> Result + where + V: FnOnce() -> T, + { let guard = self.allocation_lock.lock(); let index = self.allocated.load(); if index == PAGE_LEN { @@ -223,7 +226,7 @@ impl Page { // Initialize entry `index` let data = &self.data[index]; - unsafe { std::ptr::write(data.get(), value) }; + unsafe { std::ptr::write(data.get(), value()) }; // Update the length (this must be done after initialization!) 
self.allocated.store(index + 1); diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 6b83d0ae..3532a611 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -257,8 +257,7 @@ where let data_hash = crate::hash::hash(&C::id_fields(&fields)); - let (current_deps, disambiguator) = - zalsa_local.disambiguate(self.ingredient_index, Revision::start(), data_hash); + let (current_deps, disambiguator) = zalsa_local.disambiguate(data_hash); let key_struct = KeyStruct { disambiguator, @@ -316,7 +315,7 @@ where id } else { - zalsa_local.allocate::>(zalsa.table(), self.ingredient_index, value()) + zalsa_local.allocate::>(zalsa.table(), self.ingredient_index, value) } } @@ -606,6 +605,14 @@ where } fn reset_for_new_revision(&mut self) {} + + fn accumulated<'db>( + &'db self, + _db: &'db dyn Database, + _key_index: Id, + ) -> Option<&'db crate::accumulator::accumulated_map::AccumulatedMap> { + None + } } impl std::fmt::Debug for IngredientImpl diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index d5c214ad..ff190939 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -112,6 +112,14 @@ where fn debug_name(&self) -> &'static str { C::FIELD_DEBUG_NAMES[self.field_index] } + + fn accumulated<'db>( + &'db self, + _db: &'db dyn Database, + _key_index: Id, + ) -> Option<&'db crate::accumulator::accumulated_map::AccumulatedMap> { + None + } } impl std::fmt::Debug for FieldIngredientImpl diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index e70cd38f..fb4e4048 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -1,6 +1,7 @@ use rustc_hash::FxHashMap; use tracing::debug; +use crate::accumulator::accumulated_map::AccumulatedMap; use crate::active_query::ActiveQuery; use crate::durability::Durability; use crate::key::DatabaseKeyIndex; @@ -12,6 +13,7 @@ use crate::table::Table; use crate::tracked_struct::Disambiguator; use crate::tracked_struct::KeyStruct; use 
crate::zalsa::IngredientIndex; +use crate::Accumulator; use crate::Cancelled; use crate::Cycle; use crate::Database; @@ -58,7 +60,7 @@ impl ZalsaLocal { &self, table: &Table, ingredient: IngredientIndex, - mut value: T, + mut value: impl FnOnce() -> T, ) -> Id { // Find the most recent page, pushing a page if needed let mut page = *self @@ -125,6 +127,24 @@ impl ZalsaLocal { }) } + /// Add an output to the current query's list of dependencies + /// + /// Returns `Err` if not in a query. + pub(crate) fn accumulate( + &self, + index: IngredientIndex, + value: A, + ) -> Result<(), ()> { + self.with_query_stack(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.accumulated.accumulate(index, value); + Ok(()) + } else { + Err(()) + } + }) + } + /// Add an output to the current query's list of dependencies pub(crate) fn add_output(&self, entity: DependencyIndex) { self.with_query_stack(|stack| { @@ -242,23 +262,12 @@ impl ZalsaLocal { /// * the current dependencies (durability, changed_at) of current query /// * the disambiguator index #[track_caller] - pub(crate) fn disambiguate( - &self, - entity_index: IngredientIndex, - reset_at: Revision, - data_hash: u64, - ) -> (StampedValue<()>, Disambiguator) { + pub(crate) fn disambiguate(&self, data_hash: u64) -> (StampedValue<()>, Disambiguator) { assert!( self.query_in_progress(), "cannot create a tracked struct disambiguator outside of a tracked function" ); - self.report_tracked_read( - DependencyIndex::for_table(entity_index), - Durability::MAX, - reset_at, - ); - self.with_query_stack(|stack| { let top_query = stack.last_mut().unwrap(); let disambiguator = top_query.disambiguate(data_hash); @@ -352,6 +361,8 @@ pub(crate) struct QueryRevisions { /// This is used to seed the next round if the query is /// re-executed. pub(super) tracked_struct_ids: FxHashMap, + + pub(super) accumulated: AccumulatedMap, } impl QueryRevisions {