diff --git a/crates/trie/parallel/benches/root.rs b/crates/trie/parallel/benches/root.rs index f5dc2566fdc8..744683b9a586 100644 --- a/crates/trie/parallel/benches/root.rs +++ b/crates/trie/parallel/benches/root.rs @@ -77,7 +77,15 @@ pub fn calculate_state_root(c: &mut Criterion) { // async root group.bench_function(BenchmarkId::new("async root", size), |b| { b.to_async(&runtime).iter_with_setup( - || AsyncStateRoot::new(view.clone(), blocking_pool.clone(), updated_state.clone()), + || { + AsyncStateRoot::new( + view.clone(), + blocking_pool.clone(), + Default::default(), + updated_state.clone(), + updated_state.construct_prefix_sets().freeze(), + ) + }, |calculator| calculator.incremental_root(), ); }); diff --git a/crates/trie/parallel/src/async_root.rs b/crates/trie/parallel/src/async_root.rs index 179c7dabadc2..600a3b9a2b96 100644 --- a/crates/trie/parallel/src/async_root.rs +++ b/crates/trie/parallel/src/async_root.rs @@ -12,7 +12,8 @@ use reth_tasks::pool::BlockingTaskPool; use reth_trie::{ hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory}, node_iter::{TrieElement, TrieNodeIter}, - trie_cursor::TrieCursorFactory, + prefix_set::TriePrefixSets, + trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory}, updates::TrieUpdates, walker::TrieWalker, HashBuilder, HashedPostState, Nibbles, StorageRoot, TrieAccount, @@ -41,8 +42,12 @@ pub struct AsyncStateRoot { view: ConsistentDbView, /// Blocking task pool. blocking_pool: BlockingTaskPool, + /// Cached trie nodes. + trie_nodes: TrieUpdates, /// Changed hashed state. hashed_state: HashedPostState, + /// A set of prefix sets that have changed. + prefix_sets: TriePrefixSets, /// Parallel state root metrics. #[cfg(feature = "metrics")] metrics: ParallelStateRootMetrics, @@ -53,12 +58,16 @@ impl AsyncStateRoot { pub fn new( view: ConsistentDbView, blocking_pool: BlockingTaskPool, + trie_nodes: TrieUpdates, hashed_state: HashedPostState, + prefix_sets: TriePrefixSets, ) -> Self { Self { view, blocking_pool, + trie_nodes, hashed_state, + prefix_sets, #[cfg(feature = "metrics")] metrics: ParallelStateRootMetrics::default(), } @@ -86,12 +95,15 @@ where retain_updates: bool, ) -> Result<(B256, TrieUpdates), AsyncStateRootError> { let mut tracker = ParallelTrieTracker::default(); - let prefix_sets = self.hashed_state.construct_prefix_sets().freeze(); + let trie_nodes_sorted = Arc::new(self.trie_nodes.into_sorted()); + let hashed_state_sorted = Arc::new(self.hashed_state.into_sorted()); let storage_root_targets = StorageRootTargets::new( - self.hashed_state.accounts.keys().copied(), - prefix_sets.storage_prefix_sets, + self.prefix_sets + .account_prefix_set + .iter() + .map(|nibbles| B256::from_slice(&nibbles.pack())), + self.prefix_sets.storage_prefix_sets, ); - let hashed_state_sorted = Arc::new(self.hashed_state.into_sorted()); // Pre-calculate storage roots async for accounts which were changed. tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64); @@ -102,14 +114,18 @@ where { let view = self.view.clone(); let hashed_state_sorted = hashed_state_sorted.clone(); + let trie_nodes_sorted = trie_nodes_sorted.clone(); #[cfg(feature = "metrics")] let metrics = self.metrics.storage_trie.clone(); let handle = self.blocking_pool.spawn_fifo(move || -> Result<_, AsyncStateRootError> { - let provider = view.provider_ro()?; - let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider.tx_ref()); + let provider_ro = view.provider_ro()?; + let trie_cursor_factory = InMemoryTrieCursorFactory::new( + DatabaseTrieCursorFactory::new(provider_ro.tx_ref()), + &trie_nodes_sorted, + ); let hashed_state = HashedPostStateCursorFactory::new( - DatabaseHashedCursorFactory::new(provider.tx_ref()), + DatabaseHashedCursorFactory::new(provider_ro.tx_ref()), &hashed_state_sorted, ); Ok(StorageRoot::new_hashed( @@ -129,16 +145,18 @@ where let mut trie_updates = TrieUpdates::default(); let provider_ro = self.view.provider_ro()?; - let tx = provider_ro.tx_ref(); - let trie_cursor_factory = DatabaseTrieCursorFactory::new(tx); + let trie_cursor_factory = InMemoryTrieCursorFactory::new( + DatabaseTrieCursorFactory::new(provider_ro.tx_ref()), + &trie_nodes_sorted, + ); let hashed_cursor_factory = HashedPostStateCursorFactory::new( - DatabaseHashedCursorFactory::new(tx), + DatabaseHashedCursorFactory::new(provider_ro.tx_ref()), &hashed_state_sorted, ); let walker = TrieWalker::new( trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?, - prefix_sets.account_prefix_set, + self.prefix_sets.account_prefix_set, ) .with_deletions_retained(retain_updates); let mut account_node_iter = TrieNodeIter::new( @@ -190,7 +208,7 @@ where trie_updates.finalize( account_node_iter.walker, hash_builder, - prefix_sets.destroyed_accounts, + self.prefix_sets.destroyed_accounts, ); let stats = tracker.finish(); @@ -290,7 +308,9 @@ mod tests { AsyncStateRoot::new( consistent_view.clone(), blocking_pool.clone(), - HashedPostState::default() + Default::default(), + HashedPostState::default(), + Default::default(), ) .incremental_root() .await @@ -323,11 +343,18 @@ mod tests { } } + let prefix_sets = hashed_state.construct_prefix_sets().freeze(); assert_eq!( - AsyncStateRoot::new(consistent_view.clone(), blocking_pool.clone(), hashed_state) - .incremental_root() - .await - .unwrap(), + AsyncStateRoot::new( + consistent_view.clone(), + blocking_pool.clone(), + Default::default(), + hashed_state, + prefix_sets + ) + .incremental_root() + .await + .unwrap(), test_utils::state_root(state) ); }