Skip to content

Commit

Permalink
feat(trie): reimplement in-memory trie cursors (#9305)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkrasiuk authored Jul 12, 2024
1 parent a617bd0 commit da0efbe
Show file tree
Hide file tree
Showing 5 changed files with 250 additions and 24 deletions.
6 changes: 3 additions & 3 deletions crates/trie/trie/src/forward_cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ impl<'a, K, V> ForwardInMemoryCursor<'a, K, V> {

impl<'a, K, V> ForwardInMemoryCursor<'a, K, V>
where
K: PartialOrd + Copy,
V: Copy,
K: PartialOrd + Clone,
V: Clone,
{
/// Advances the cursor forward while `comparator` returns `true` or until the collection is
/// exhausted. Returns the first entry for which `comparator` returns `false` or `None`.
Expand All @@ -34,7 +34,7 @@ where
self.index += 1;
entry = self.entries.get(self.index);
}
entry.copied()
entry.cloned()
}

/// Returns the first entry from the current cursor position that's greater or equal to the
Expand Down
14 changes: 12 additions & 2 deletions crates/trie/trie/src/trie_cursor/database_cursors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl<'a, TX: DbTx> TrieCursorFactory for &'a TX {

/// A cursor over the account trie.
#[derive(Debug)]
pub struct DatabaseAccountTrieCursor<C>(C);
pub struct DatabaseAccountTrieCursor<C>(pub(crate) C);

impl<C> DatabaseAccountTrieCursor<C> {
/// Create a new account trie cursor.
Expand Down Expand Up @@ -59,6 +59,11 @@ where
Ok(self.0.seek(StoredNibbles(key))?.map(|value| (value.0 .0, value.1 .0)))
}

/// Move the cursor to the next entry and return it.
fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
Ok(self.0.next()?.map(|value| (value.0 .0, value.1 .0)))
}

/// Retrieves the current key in the cursor.
fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
Ok(self.0.current()?.map(|(k, _)| k.0))
Expand All @@ -83,7 +88,7 @@ impl<C> DatabaseStorageTrieCursor<C> {

impl<C> TrieCursor for DatabaseStorageTrieCursor<C>
where
C: DbDupCursorRO<tables::StoragesTrie> + DbCursorRO<tables::StoragesTrie> + Send + Sync,
C: DbCursorRO<tables::StoragesTrie> + DbDupCursorRO<tables::StoragesTrie> + Send + Sync,
{
/// Seeks an exact match for the given key in the storage trie.
fn seek_exact(
Expand All @@ -108,6 +113,11 @@ where
.map(|value| (value.nibbles.0, value.node)))
}

/// Move the cursor to the next entry and return it.
fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
Ok(self.cursor.next_dup()?.map(|(_, v)| (v.nibbles.0, v.node)))
}

/// Retrieves the current value in the storage trie cursor.
fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
Ok(self.cursor.current()?.map(|(_, v)| v.nibbles.0))
Expand Down
233 changes: 222 additions & 11 deletions crates/trie/trie/src/trie_cursor/in_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub struct InMemoryAccountTrieCursor<'a, C> {
last_key: Option<Nibbles>,
}

impl<'a, C> InMemoryAccountTrieCursor<'a, C> {
impl<'a, C: TrieCursor> InMemoryAccountTrieCursor<'a, C> {
const fn new(cursor: C, trie_updates: &'a TrieUpdatesSorted) -> Self {
let in_memory_cursor = ForwardInMemoryCursor::new(&trie_updates.account_nodes);
Self {
Expand All @@ -71,25 +71,86 @@ impl<'a, C> InMemoryAccountTrieCursor<'a, C> {
last_key: None,
}
}

fn seek_inner(
&mut self,
key: Nibbles,
exact: bool,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let in_memory = self.in_memory_cursor.seek(&key);
if exact && in_memory.as_ref().map_or(false, |entry| entry.0 == key) {
return Ok(in_memory)
}

// Reposition the cursor to the first greater or equal node that wasn't removed.
let mut db_entry = self.cursor.seek(key.clone())?;
while db_entry.as_ref().map_or(false, |entry| self.removed_nodes.contains(&entry.0)) {
db_entry = self.cursor.next()?;
}

// Compare two entries and return the lowest.
// If seek is exact, filter the entry for exact key match.
Ok(compare_trie_node_entries(in_memory, db_entry)
.filter(|(nibbles, _)| !exact || nibbles == &key))
}

fn next_inner(
&mut self,
last: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let in_memory = self.in_memory_cursor.first_after(&last);

// Reposition the cursor to the first greater or equal node that wasn't removed.
let mut db_entry = self.cursor.seek(last.clone())?;
while db_entry
.as_ref()
.map_or(false, |entry| entry.0 < last || self.removed_nodes.contains(&entry.0))
{
db_entry = self.cursor.next()?;
}

// Compare two entries and return the lowest.
Ok(compare_trie_node_entries(in_memory, db_entry))
}
}

impl<'a, C: TrieCursor> TrieCursor for InMemoryAccountTrieCursor<'a, C> {
fn seek_exact(
&mut self,
_key: Nibbles,
key: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
unimplemented!()
let entry = self.seek_inner(key, true)?;
self.last_key = entry.as_ref().map(|(nibbles, _)| nibbles.clone());
Ok(entry)
}

fn seek(
&mut self,
_key: Nibbles,
key: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
unimplemented!()
let entry = self.seek_inner(key, false)?;
self.last_key = entry.as_ref().map(|(nibbles, _)| nibbles.clone());
Ok(entry)
}

fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let next = match &self.last_key {
Some(last) => {
let entry = self.next_inner(last.clone())?;
self.last_key = entry.as_ref().map(|entry| entry.0.clone());
entry
}
// no previous entry was found
None => None,
};
Ok(next)
}

fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
unimplemented!()
match &self.last_key {
Some(key) => Ok(Some(key.clone())),
None => self.cursor.current(),
}
}
}

Expand Down Expand Up @@ -128,22 +189,172 @@ impl<'a, C> InMemoryStorageTrieCursor<'a, C> {
}
}

impl<'a, C: TrieCursor> InMemoryStorageTrieCursor<'a, C> {
fn seek_inner(
&mut self,
key: Nibbles,
exact: bool,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let in_memory = self.in_memory_cursor.as_mut().and_then(|c| c.seek(&key));
if self.storage_trie_cleared ||
(exact && in_memory.as_ref().map_or(false, |entry| entry.0 == key))
{
return Ok(in_memory)
}

// Reposition the cursor to the first greater or equal node that wasn't removed.
let mut db_entry = self.cursor.seek(key.clone())?;
while db_entry.as_ref().map_or(false, |entry| {
self.removed_nodes.as_ref().map_or(false, |r| r.contains(&entry.0))
}) {
db_entry = self.cursor.next()?;
}

// Compare two entries and return the lowest.
// If seek is exact, filter the entry for exact key match.
Ok(compare_trie_node_entries(in_memory, db_entry)
.filter(|(nibbles, _)| !exact || nibbles == &key))
}

fn next_inner(
&mut self,
last: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let in_memory = self.in_memory_cursor.as_mut().and_then(|c| c.first_after(&last));

// Reposition the cursor to the first greater or equal node that wasn't removed.
let mut db_entry = self.cursor.seek(last.clone())?;
while db_entry.as_ref().map_or(false, |entry| {
entry.0 < last || self.removed_nodes.as_ref().map_or(false, |r| r.contains(&entry.0))
}) {
db_entry = self.cursor.next()?;
}

// Compare two entries and return the lowest.
Ok(compare_trie_node_entries(in_memory, db_entry))
}
}

impl<'a, C: TrieCursor> TrieCursor for InMemoryStorageTrieCursor<'a, C> {
fn seek_exact(
&mut self,
_key: Nibbles,
key: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
unimplemented!()
let entry = self.seek_inner(key, true)?;
self.last_key = entry.as_ref().map(|(nibbles, _)| nibbles.clone());
Ok(entry)
}

fn seek(
&mut self,
_key: Nibbles,
key: Nibbles,
) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
unimplemented!()
let entry = self.seek_inner(key, false)?;
self.last_key = entry.as_ref().map(|(nibbles, _)| nibbles.clone());
Ok(entry)
}

fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
let next = match &self.last_key {
Some(last) => {
let entry = self.next_inner(last.clone())?;
self.last_key = entry.as_ref().map(|entry| entry.0.clone());
entry
}
// no previous entry was found
None => None,
};
Ok(next)
}

fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError> {
unimplemented!()
match &self.last_key {
Some(key) => Ok(Some(key.clone())),
None => self.cursor.current(),
}
}
}

/// Return the node with the lowest nibbles.
///
/// Given the next in-memory and database entries, return the smallest of the two.
/// If the node keys are the same, the in-memory entry is given precedence.
fn compare_trie_node_entries(
mut in_memory_item: Option<(Nibbles, BranchNodeCompact)>,
mut db_item: Option<(Nibbles, BranchNodeCompact)>,
) -> Option<(Nibbles, BranchNodeCompact)> {
if let Some((in_memory_entry, db_entry)) = in_memory_item.as_ref().zip(db_item.as_ref()) {
// If both are not empty, return the smallest of the two
// In-memory is given precedence if keys are equal
if in_memory_entry.0 <= db_entry.0 {
in_memory_item.take()
} else {
db_item.take()
}
} else {
// Return either non-empty entry
db_item.or(in_memory_item)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
prefix_set::{PrefixSetMut, TriePrefixSets},
test_utils::state_root_prehashed,
StateRoot,
};
use proptest::prelude::*;
use reth_db::{cursor::DbCursorRW, tables, transaction::DbTxMut};
use reth_primitives::{Account, U256};
use reth_provider::test_utils::create_test_provider_factory;
use std::collections::BTreeMap;

proptest! {
#![proptest_config(ProptestConfig {
cases: 128, ..ProptestConfig::default()
})]

#[test]
fn fuzz_in_memory_nodes(mut init_state: BTreeMap<B256, U256>, mut updated_state: BTreeMap<B256, U256>) {
let factory = create_test_provider_factory();
let provider = factory.provider_rw().unwrap();
let mut hashed_account_cursor = provider.tx_ref().cursor_write::<tables::HashedAccounts>().unwrap();

// Insert init state into database
for (hashed_address, balance) in init_state.clone() {
hashed_account_cursor.upsert(hashed_address, Account { balance, ..Default::default() }).unwrap();
}

// Compute initial root and updates
let (_, trie_updates) = StateRoot::from_tx(provider.tx_ref())
.root_with_updates()
.unwrap();

// Insert state updates into database
let mut changes = PrefixSetMut::default();
for (hashed_address, balance) in updated_state.clone() {
hashed_account_cursor.upsert(hashed_address, Account { balance, ..Default::default() }).unwrap();
changes.insert(Nibbles::unpack(hashed_address));
}

// Compute root with in-memory trie nodes overlay
let (state_root, _) = StateRoot::from_tx(provider.tx_ref())
.with_prefix_sets(TriePrefixSets { account_prefix_set: changes.freeze(), ..Default::default() })
.with_trie_cursor_factory(InMemoryTrieCursorFactory::new(provider.tx_ref(), &trie_updates.into_sorted()))
.root_with_updates()
.unwrap();

// Verify the result
let mut state = BTreeMap::default();
state.append(&mut init_state);
state.append(&mut updated_state);
let expected_root = state_root_prehashed(
state.iter().map(|(&key, &balance)| (key, (Account { balance, ..Default::default() }, std::iter::empty())))
);
assert_eq!(expected_root, state_root);

}
}
}
3 changes: 3 additions & 0 deletions crates/trie/trie/src/trie_cursor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ pub trait TrieCursor: Send + Sync {
fn seek(&mut self, key: Nibbles)
-> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError>;

/// Move the cursor to the next key.
fn next(&mut self) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError>;

/// Get the current entry.
fn current(&mut self) -> Result<Option<Nibbles>, DatabaseError>;
}
Loading

0 comments on commit da0efbe

Please sign in to comment.