Skip to content

Commit

Permalink
rework accumulator to store values in memos
Browse files Browse the repository at this point in the history
We used to store values in a central map,
but now each memo has an `AccumulatorMap`
that maps accumulated values (if any).

The primary goals of this change are

* forward compatible with speculative execution
  because it puts more data into tables;
* step towards a refactoring to stop tracking
  outputs in the same array as inputs and thus
  to simplify how we do versioning. We will no
  longer need to walk the function's outputs
  and refresh their versions and so forth because
  they are stored in the function memo and so
  they get refreshed automatically when the memo
  is refreshed.
  • Loading branch information
nikomatsakis committed Sep 21, 2024
1 parent 2caa5cc commit 3687e48
Show file tree
Hide file tree
Showing 17 changed files with 319 additions and 109 deletions.
118 changes: 31 additions & 87 deletions src/accumulator.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
//! Basic test of accumulator functionality.

use std::{
any::Any,
fmt::{self, Debug},
marker::PhantomData,
};

use accumulated::Accumulated;
use accumulated::AnyAccumulated;
use accumulated_map::AccumulatedMap;

use crate::{
cycle::CycleRecoveryStrategy,
hash::FxDashMap,
ingredient::{fmt_index, Ingredient, Jar},
key::DependencyIndex,
plumbing::JarAux,
zalsa::IngredientIndex,
zalsa_local::{QueryOrigin, ZalsaLocal},
Database, DatabaseKeyIndex, Event, EventKind, Id, Revision,
zalsa_local::QueryOrigin,
Database, DatabaseKeyIndex, Id, Revision,
};

pub trait Accumulator: Clone + Debug + Send + Sync + 'static + Sized {
mod accumulated;
pub(crate) mod accumulated_map;

/// Trait implemented on the struct that user annotated with `#[salsa::accumulator]`.
/// The `Self` type is therefore the types to be accumulated.
pub trait Accumulator: Clone + Debug + Send + Sync + Any + Sized {
const DEBUG_NAME: &'static str;

/// Accumulate an instance of this in the database for later retrieval.
Expand Down Expand Up @@ -49,12 +57,7 @@ impl<A: Accumulator> Jar for JarImpl<A> {

pub struct IngredientImpl<A: Accumulator> {
index: IngredientIndex,
map: FxDashMap<DatabaseKeyIndex, AccumulatedValues<A>>,
}

struct AccumulatedValues<A> {
produced_at: Revision,
values: Vec<A>,
phantom: PhantomData<Accumulated<A>>,
}

impl<A: Accumulator> IngredientImpl<A> {
Expand All @@ -72,67 +75,20 @@ impl<A: Accumulator> IngredientImpl<A> {

pub fn new(index: IngredientIndex) -> Self {
Self {
map: FxDashMap::default(),
index,
phantom: PhantomData,
}
}

fn dependency_index(&self) -> DependencyIndex {
DependencyIndex {
ingredient_index: self.index,
key_index: None,
}
}

pub fn push(&self, db: &dyn crate::Database, value: A) {
let state = db.zalsa_local();
let current_revision = db.zalsa().current_revision();
let (active_query, _) = match state.active_query() {
Some(pair) => pair,
None => {
panic!("cannot accumulate values outside of an active query")
}
};

let mut accumulated_values = self.map.entry(active_query).or_insert(AccumulatedValues {
values: vec![],
produced_at: current_revision,
});

// When we call `push' in a query, we will add the accumulator to the output of the query.
// If we find here that this accumulator is not the output of the query,
// we can say that the accumulated values we stored for this query is out of date.
if !state.is_output_of_active_query(self.dependency_index()) {
accumulated_values.values.truncate(0);
accumulated_values.produced_at = current_revision;
pub fn push(&self, db: &dyn Database, value: A) {
let zalsa_local = db.zalsa_local();
if let Err(()) = zalsa_local.accumulate(self.index, value) {
panic!("cannot accumulate values outside of an active tracked function");
}

state.add_output(self.dependency_index());
accumulated_values.values.push(value);
}

pub(crate) fn produced_by(
&self,
current_revision: Revision,
local_state: &ZalsaLocal,
query: DatabaseKeyIndex,
output: &mut Vec<A>,
) {
if let Some(v) = self.map.get(&query) {
// FIXME: We don't currently have a good way to identify the value that was read.
// You can't report is as a tracked read of `query`, because the return value of query is not being read here --
// instead it is the set of values accumuated by `query`.
local_state.report_untracked_read(current_revision);

let AccumulatedValues {
values,
produced_at,
} = v.value();

if *produced_at == current_revision {
output.extend(values.iter().cloned());
}
}
pub fn index(&self) -> IngredientIndex {
self.index
}
}

Expand Down Expand Up @@ -160,34 +116,18 @@ impl<A: Accumulator> Ingredient for IngredientImpl<A> {

fn mark_validated_output(
&self,
db: &dyn Database,
executor: DatabaseKeyIndex,
output_key: Option<crate::Id>,
_db: &dyn Database,
_executor: DatabaseKeyIndex,
_output_key: Option<crate::Id>,
) {
assert!(output_key.is_none());
let current_revision = db.zalsa().current_revision();
if let Some(mut v) = self.map.get_mut(&executor) {
// The value is still valid in the new revision.
v.produced_at = current_revision;
}
}

fn remove_stale_output(
&self,
db: &dyn Database,
executor: DatabaseKeyIndex,
stale_output_key: Option<crate::Id>,
_db: &dyn Database,
_executor: DatabaseKeyIndex,
_stale_output_key: Option<crate::Id>,
) {
assert!(stale_output_key.is_none());
if self.map.remove(&executor).is_some() {
db.salsa_event(&|| Event {
thread_id: std::thread::current().id(),
kind: EventKind::DidDiscardAccumulated {
executor_key: executor,
accumulator: self.dependency_index(),
},
})
}
}

fn requires_reset_for_new_revision(&self) -> bool {
Expand All @@ -205,6 +145,10 @@ impl<A: Accumulator> Ingredient for IngredientImpl<A> {
fn debug_name(&self) -> &'static str {
A::DEBUG_NAME
}

fn accumulated(&self, _db: &dyn Database, _key_index: Id) -> Option<&AccumulatedMap> {
None
}
}

impl<A> std::fmt::Debug for IngredientImpl<A>
Expand Down
60 changes: 60 additions & 0 deletions src/accumulator/accumulated.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use std::any::Any;
use std::fmt::Debug;

use super::Accumulator;

#[derive(Clone, Debug)]
pub(crate) struct Accumulated<A: Accumulator> {
values: Vec<A>,
}

pub(crate) trait AnyAccumulated: Any + Debug + Send + Sync {
fn as_dyn_any(&self) -> &dyn Any;
fn as_dyn_any_mut(&mut self) -> &mut dyn Any;
fn cloned(&self) -> Box<dyn AnyAccumulated>;
}

impl<A: Accumulator> Accumulated<A> {
pub fn push(&mut self, value: A) {
self.values.push(value);
}

pub fn extend_with_accumulated(&self, values: &mut Vec<A>) {
values.extend_from_slice(&self.values);
}
}

impl<A: Accumulator> Default for Accumulated<A> {
fn default() -> Self {
Self {
values: Default::default(),
}
}
}

impl<A> AnyAccumulated for Accumulated<A>
where
A: Accumulator,
{
fn as_dyn_any(&self) -> &dyn Any {
self
}

fn as_dyn_any_mut(&mut self) -> &mut dyn Any {
self
}

fn cloned(&self) -> Box<dyn AnyAccumulated> {
let this: Self = self.clone();
Box::new(this)
}
}

impl dyn AnyAccumulated {
pub fn accumulate<A: Accumulator>(&mut self, value: A) {
self.as_dyn_any_mut()
.downcast_mut::<Accumulated<A>>()
.unwrap()
.push(value);
}
}
46 changes: 46 additions & 0 deletions src/accumulator/accumulated_map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use rustc_hash::FxHashMap;

use crate::IngredientIndex;

use super::{accumulated::Accumulated, Accumulator, AnyAccumulated};

#[derive(Default, Debug)]
pub struct AccumulatedMap {
map: FxHashMap<IngredientIndex, Box<dyn AnyAccumulated>>,
}

impl AccumulatedMap {
pub fn accumulate<A: Accumulator>(&mut self, index: IngredientIndex, value: A) {
self.map
.entry(index)
.or_insert_with(|| <Box<Accumulated<A>>>::default())
.accumulate(value);
}

pub fn extend_with_accumulated<A: Accumulator>(
&self,
index: IngredientIndex,
output: &mut Vec<A>,
) {
let Some(a) = self.map.get(&index) else {
return;
};

a.as_dyn_any()
.downcast_ref::<Accumulated<A>>()
.unwrap()
.extend_with_accumulated(output);
}
}

impl Clone for AccumulatedMap {
fn clone(&self) -> Self {
Self {
map: self
.map
.iter()
.map(|(&key, value)| (key, value.cloned()))
.collect(),
}
}
}
7 changes: 7 additions & 0 deletions src/active_query.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use rustc_hash::FxHashMap;

use crate::{
accumulator::accumulated_map::AccumulatedMap,
durability::Durability,
hash::FxIndexSet,
key::{DatabaseKeyIndex, DependencyIndex},
Expand Down Expand Up @@ -49,6 +50,10 @@ pub(crate) struct ActiveQuery {
/// Map from tracked struct keys (which include the hash + disambiguator) to their
/// final id.
pub(crate) tracked_struct_ids: FxHashMap<KeyStruct, Id>,

/// Stores the values accumulated to the given ingredient.
/// The type of accumulated value is erased but known to the ingredient.
pub(crate) accumulated: AccumulatedMap,
}

impl ActiveQuery {
Expand All @@ -62,6 +67,7 @@ impl ActiveQuery {
cycle: None,
disambiguator_map: Default::default(),
tracked_struct_ids: Default::default(),
accumulated: Default::default(),
}
}

Expand Down Expand Up @@ -118,6 +124,7 @@ impl ActiveQuery {
origin,
durability: self.durability,
tracked_struct_ids: self.tracked_struct_ids,
accumulated: self.accumulated,
}
}

Expand Down
10 changes: 10 additions & 0 deletions src/function.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{any::Any, fmt, sync::Arc};

use crate::{
accumulator::accumulated_map::AccumulatedMap,
cycle::CycleRecoveryStrategy,
ingredient::fmt_index,
key::DatabaseKeyIndex,
Expand Down Expand Up @@ -243,6 +244,15 @@ where
fn debug_name(&self) -> &'static str {
C::DEBUG_NAME
}

fn accumulated<'db>(
&'db self,
db: &'db dyn Database,
key_index: Id,
) -> Option<&'db AccumulatedMap> {
let db = db.as_view::<C::DbView>();
self.accumulated_map(db, key_index)
}
}

impl<C> std::fmt::Debug for IngredientImpl<C>
Expand Down
Loading

0 comments on commit 3687e48

Please sign in to comment.