Skip to content

Commit 48fee88

Browse files
authored
feat(engine): integrate state root task and comment it (#13265)
1 parent e663f95 commit 48fee88

File tree

3 files changed

+142
-56
lines changed

3 files changed

+142
-56
lines changed

crates/engine/tree/benches/state_root_task.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ use revm_primitives::{
2222
Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot, HashMap,
2323
B256, KECCAK_EMPTY, U256,
2424
};
25-
use std::sync::Arc;
2625

2726
#[derive(Debug, Clone)]
2827
struct BenchParams {
@@ -137,16 +136,15 @@ fn bench_state_root(c: &mut Criterion) {
137136
let state_updates = create_bench_state_updates(params);
138137
setup_provider(&factory, &state_updates).expect("failed to setup provider");
139138

140-
let trie_input = Arc::new(TrieInput::from_state(Default::default()));
141-
142-
let config = StateRootConfig {
143-
consistent_view: ConsistentDbView::new(factory, None),
144-
input: trie_input,
145-
};
139+
let trie_input = TrieInput::from_state(Default::default());
140+
let config = StateRootConfig::new_from_input(
141+
ConsistentDbView::new(factory, None),
142+
trie_input,
143+
);
146144
let provider = config.consistent_view.provider_ro().unwrap();
147-
let nodes_sorted = config.input.nodes.clone().into_sorted();
148-
let state_sorted = config.input.state.clone().into_sorted();
149-
let prefix_sets = Arc::new(config.input.prefix_sets.clone());
145+
let nodes_sorted = config.nodes_sorted.clone();
146+
let state_sorted = config.state_sorted.clone();
147+
let prefix_sets = config.prefix_sets.clone();
150148

151149
(config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets)
152150
},

crates/engine/tree/src/tree/mod.rs

Lines changed: 82 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2224,13 +2224,47 @@ where
22242224

22252225
let exec_time = Instant::now();
22262226

2227-
// TODO: create StateRootTask with the receiving end of a channel and
2228-
// pass the sending end of the channel to the state hook.
2229-
let noop_state_hook = |_state: &EvmState| {};
2227+
let persistence_not_in_progress = !self.persistence_state.in_progress();
2228+
2229+
// TODO: uncomment to use StateRootTask
2230+
2231+
// let (state_root_handle, state_hook) = if persistence_not_in_progress {
2232+
// let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;
2233+
//
2234+
// let state_root_config = StateRootConfig::new_from_input(
2235+
// consistent_view.clone(),
2236+
// self.compute_trie_input(consistent_view, block.header().parent_hash())
2237+
// .map_err(ParallelStateRootError::into)?,
2238+
// );
2239+
//
2240+
// let provider_ro = consistent_view.provider_ro()?;
2241+
// let nodes_sorted = state_root_config.nodes_sorted.clone();
2242+
// let state_sorted = state_root_config.state_sorted.clone();
2243+
// let prefix_sets = state_root_config.prefix_sets.clone();
2244+
// let blinded_provider_factory = ProofBlindedProviderFactory::new(
2245+
// InMemoryTrieCursorFactory::new(
2246+
// DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
2247+
// &nodes_sorted,
2248+
// ),
2249+
// HashedPostStateCursorFactory::new(
2250+
// DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
2251+
// &state_sorted,
2252+
// ),
2253+
// prefix_sets,
2254+
// );
2255+
//
2256+
// let state_root_task = StateRootTask::new(state_root_config,
2257+
// blinded_provider_factory); let state_hook = state_root_task.state_hook();
2258+
// (Some(state_root_task.spawn(scope)), Box::new(state_hook) as Box<dyn OnStateHook>)
2259+
// } else {
2260+
// (None, Box::new(|_state: &EvmState| {}) as Box<dyn OnStateHook>)
2261+
// };
2262+
let state_hook = Box::new(|_state: &EvmState| {});
2263+
22302264
let output = self.metrics.executor.execute_metered(
22312265
executor,
22322266
(&block, U256::MAX).into(),
2233-
Box::new(noop_state_hook),
2267+
state_hook,
22342268
)?;
22352269

22362270
trace!(target: "engine::tree", elapsed=?exec_time.elapsed(), ?block_number, "Executed block");
@@ -2253,33 +2287,47 @@ where
22532287

22542288
trace!(target: "engine::tree", block=?sealed_block.num_hash(), "Calculating block state root");
22552289
let root_time = Instant::now();
2256-
let mut state_root_result = None;
2257-
2258-
// TODO: switch to calculate state root using `StateRootTask`.
22592290

22602291
// We attempt to compute state root in parallel if we are currently not persisting anything
22612292
// to database. This is safe, because the database state cannot change until we
22622293
// finish parallel computation. It is important that nothing is being persisted as
22632294
// we are computing in parallel, because we initialize a different database transaction
22642295
// per thread and it might end up with a different view of the database.
2265-
let persistence_in_progress = self.persistence_state.in_progress();
2266-
if !persistence_in_progress {
2267-
state_root_result = match self
2268-
.compute_state_root_parallel(block.header().parent_hash(), &hashed_state)
2269-
{
2270-
Ok((state_root, trie_output)) => Some((state_root, trie_output)),
2296+
let state_root_result = if persistence_not_in_progress {
2297+
// TODO: uncomment to use StateRootTask
2298+
2299+
// if let Some(state_root_handle) = state_root_handle {
2300+
// match state_root_handle.wait_for_result() {
2301+
// Ok((task_state_root, task_trie_updates)) => {
2302+
// info!(
2303+
// target: "engine::tree",
2304+
// block = ?sealed_block.num_hash(),
2305+
// ?task_state_root,
2306+
// "State root task finished"
2307+
// );
2308+
// }
2309+
// Err(error) => {
2310+
// info!(target: "engine::tree", ?error, "Failed to wait for state root task
2311+
// result"); }
2312+
// }
2313+
// }
2314+
2315+
match self.compute_state_root_parallel(block.header().parent_hash(), &hashed_state) {
2316+
Ok(result) => Some(result),
22712317
Err(ParallelStateRootError::Provider(ProviderError::ConsistentView(error))) => {
22722318
debug!(target: "engine", %error, "Parallel state root computation failed consistency check, falling back");
22732319
None
22742320
}
22752321
Err(error) => return Err(InsertBlockErrorKindTwo::Other(Box::new(error))),
2276-
};
2277-
}
2322+
}
2323+
} else {
2324+
None
2325+
};
22782326

22792327
let (state_root, trie_output) = if let Some(result) = state_root_result {
22802328
result
22812329
} else {
2282-
debug!(target: "engine::tree", block=?sealed_block.num_hash(), persistence_in_progress, "Failed to compute state root in parallel");
2330+
debug!(target: "engine::tree", block=?sealed_block.num_hash(), ?persistence_not_in_progress, "Failed to compute state root in parallel");
22832331
state_provider.state_root_with_updates(hashed_state.clone())?
22842332
};
22852333

@@ -2344,14 +2392,25 @@ where
23442392
parent_hash: B256,
23452393
hashed_state: &HashedPostState,
23462394
) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
2347-
// TODO: when we switch to calculate state root using `StateRootTask` this
2348-
// method can be still useful to calculate the required `TrieInput` to
2349-
// create the task.
23502395
let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;
2396+
2397+
let mut input = self.compute_trie_input(consistent_view.clone(), parent_hash)?;
2398+
// Extend with block we are validating root for.
2399+
input.append_ref(hashed_state);
2400+
2401+
ParallelStateRoot::new(consistent_view, input).incremental_root_with_updates()
2402+
}
2403+
2404+
/// Computes the trie input at the provided parent hash.
2405+
fn compute_trie_input(
2406+
&self,
2407+
consistent_view: ConsistentDbView<P>,
2408+
parent_hash: B256,
2409+
) -> Result<TrieInput, ParallelStateRootError> {
23512410
let mut input = TrieInput::default();
23522411

23532412
if let Some((historical, blocks)) = self.state.tree_state.blocks_by_hash(parent_hash) {
2354-
debug!(target: "engine::tree", %parent_hash, %historical, "Calculating state root in parallel, parent found in memory");
2413+
debug!(target: "engine::tree", %parent_hash, %historical, "Parent found in memory");
23552414
// Retrieve revert state for historical block.
23562415
let revert_state = consistent_view.revert_state(historical)?;
23572416
input.append(revert_state);
@@ -2362,15 +2421,12 @@ where
23622421
}
23632422
} else {
23642423
// The block attaches to canonical persisted parent.
2365-
debug!(target: "engine::tree", %parent_hash, "Calculating state root in parallel, parent found in disk");
2424+
debug!(target: "engine::tree", %parent_hash, "Parent found on disk");
23662425
let revert_state = consistent_view.revert_state(parent_hash)?;
23672426
input.append(revert_state);
23682427
}
23692428

2370-
// Extend with block we are validating root for.
2371-
input.append_ref(hashed_state);
2372-
2373-
ParallelStateRoot::new(consistent_view, input).incremental_root_with_updates()
2429+
Ok(input)
23742430
}
23752431

23762432
/// Handles an error that occurred while inserting a block.
@@ -2648,7 +2704,7 @@ mod tests {
26482704
use reth_primitives::{Block, BlockExt, EthPrimitives};
26492705
use reth_provider::test_utils::MockEthProvider;
26502706
use reth_rpc_types_compat::engine::{block_to_payload_v1, payload::block_to_payload_v3};
2651-
use reth_trie::updates::TrieUpdates;
2707+
use reth_trie::{updates::TrieUpdates, HashedPostState};
26522708
use std::{
26532709
str::FromStr,
26542710
sync::mpsc::{channel, Sender},

crates/engine/tree/src/tree/root.rs

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@ use reth_provider::{
1010
StateCommitmentProvider,
1111
};
1212
use reth_trie::{
13-
proof::Proof, updates::TrieUpdates, HashedPostState, HashedStorage, MultiProof,
14-
MultiProofTargets, Nibbles, TrieInput,
13+
hashed_cursor::HashedPostStateCursorFactory,
14+
prefix_set::TriePrefixSetsMut,
15+
proof::Proof,
16+
trie_cursor::InMemoryTrieCursorFactory,
17+
updates::{TrieUpdates, TrieUpdatesSorted},
18+
HashedPostState, HashedPostStateSorted, HashedStorage, MultiProof, MultiProofTargets, Nibbles,
19+
TrieInput,
1520
};
16-
use reth_trie_db::DatabaseProof;
21+
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseProof, DatabaseTrieCursorFactory};
1722
use reth_trie_parallel::root::ParallelStateRootError;
1823
use reth_trie_sparse::{
1924
blinded::{BlindedProvider, BlindedProviderFactory},
@@ -72,12 +77,31 @@ impl StateRootHandle {
7277
}
7378

7479
/// Common configuration for state root tasks
75-
#[derive(Debug)]
80+
#[derive(Debug, Clone)]
7681
pub struct StateRootConfig<Factory> {
7782
/// View over the state in the database.
7883
pub consistent_view: ConsistentDbView<Factory>,
79-
/// Latest trie input.
80-
pub input: Arc<TrieInput>,
84+
/// The sorted collection of cached in-memory intermediate trie nodes that
85+
/// can be reused for computation.
86+
pub nodes_sorted: Arc<TrieUpdatesSorted>,
87+
/// The sorted in-memory overlay hashed state.
88+
pub state_sorted: Arc<HashedPostStateSorted>,
89+
/// The collection of prefix sets for the computation. Since the prefix sets _always_
90+
/// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
91+
/// if we have cached nodes for them.
92+
pub prefix_sets: Arc<TriePrefixSetsMut>,
93+
}
94+
95+
impl<Factory> StateRootConfig<Factory> {
96+
/// Creates a new state root config from the consistent view and the trie input.
97+
pub fn new_from_input(consistent_view: ConsistentDbView<Factory>, input: TrieInput) -> Self {
98+
Self {
99+
consistent_view,
100+
nodes_sorted: Arc::new(input.nodes.into_sorted()),
101+
state_sorted: Arc::new(input.state.into_sorted()),
102+
prefix_sets: Arc::new(input.prefix_sets),
103+
}
104+
}
81105
}
82106

83107
/// Messages used internally by the state root task
@@ -321,8 +345,7 @@ where
321345
/// Returns proof targets derived from the state update.
322346
fn on_state_update(
323347
scope: &rayon::Scope<'env>,
324-
view: ConsistentDbView<Factory>,
325-
input: Arc<TrieInput>,
348+
config: StateRootConfig<Factory>,
326349
update: EvmState,
327350
fetched_proof_targets: &mut MultiProofTargets,
328351
proof_sequence_number: u64,
@@ -335,7 +358,7 @@ where
335358

336359
// Dispatch proof gathering for this state update
337360
scope.spawn(move |_| {
338-
let provider = match view.provider_ro() {
361+
let provider = match config.consistent_view.provider_ro() {
339362
Ok(provider) => provider,
340363
Err(error) => {
341364
error!(target: "engine::root", ?error, "Could not get provider");
@@ -346,11 +369,18 @@ where
346369
};
347370

348371
// TODO: replace with parallel proof
349-
let result = Proof::overlay_multiproof(
350-
provider.tx_ref(),
351-
input.as_ref().clone(),
352-
proof_targets.clone(),
353-
);
372+
let result = Proof::from_tx(provider.tx_ref())
373+
.with_trie_cursor_factory(InMemoryTrieCursorFactory::new(
374+
DatabaseTrieCursorFactory::new(provider.tx_ref()),
375+
&config.nodes_sorted,
376+
))
377+
.with_hashed_cursor_factory(HashedPostStateCursorFactory::new(
378+
DatabaseHashedCursorFactory::new(provider.tx_ref()),
379+
&config.state_sorted,
380+
))
381+
.with_prefix_sets_mut(config.prefix_sets.as_ref().clone())
382+
.with_branch_node_hash_masks(true)
383+
.multiproof(proof_targets.clone());
354384
match result {
355385
Ok(proof) => {
356386
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
@@ -472,8 +502,7 @@ where
472502
);
473503
Self::on_state_update(
474504
scope,
475-
self.config.consistent_view.clone(),
476-
self.config.input.clone(),
505+
self.config.clone(),
477506
update,
478507
&mut self.fetched_proof_targets,
479508
self.proof_sequencer.next_sequence(),
@@ -859,13 +888,16 @@ mod tests {
859888
}
860889
}
861890

891+
let input = TrieInput::from_state(hashed_state);
892+
let nodes_sorted = Arc::new(input.nodes.clone().into_sorted());
893+
let state_sorted = Arc::new(input.state.clone().into_sorted());
862894
let config = StateRootConfig {
863895
consistent_view: ConsistentDbView::new(factory, None),
864-
input: Arc::new(TrieInput::from_state(hashed_state)),
896+
nodes_sorted: nodes_sorted.clone(),
897+
state_sorted: state_sorted.clone(),
898+
prefix_sets: Arc::new(input.prefix_sets),
865899
};
866900
let provider = config.consistent_view.provider_ro().unwrap();
867-
let nodes_sorted = config.input.nodes.clone().into_sorted();
868-
let state_sorted = config.input.state.clone().into_sorted();
869901
let blinded_provider_factory = ProofBlindedProviderFactory::new(
870902
InMemoryTrieCursorFactory::new(
871903
DatabaseTrieCursorFactory::new(provider.tx_ref()),
@@ -875,7 +907,7 @@ mod tests {
875907
DatabaseHashedCursorFactory::new(provider.tx_ref()),
876908
&state_sorted,
877909
),
878-
Arc::new(config.input.prefix_sets.clone()),
910+
config.prefix_sets.clone(),
879911
);
880912
let (root_from_task, _) = std::thread::scope(|std_scope| {
881913
let task = StateRootTask::new(config, blinded_provider_factory);

0 commit comments

Comments
 (0)