diff --git a/crates/blockifier/bench/blockifier_bench.rs b/crates/blockifier/bench/blockifier_bench.rs index 2097bb1cec..a6c08c740a 100644 --- a/crates/blockifier/bench/blockifier_bench.rs +++ b/crates/blockifier/bench/blockifier_bench.rs @@ -2,17 +2,42 @@ //! various aspects related to transferring between accounts, including preparation //! and execution of transfers. //! -//! The main benchmark function is `transfers_benchmark`, which measures the performance -//! of transfers between randomly created accounts, which are iterated over round-robin. +//! The benchmark function `transfers_benchmark` measures the performance of transfers between +//! randomly created accounts, which are iterated over round-robin. +//! +//! The benchmark function `execution_benchmark` measures the performance of the method +//! [`blockifier::transactions::transaction::ExecutableTransaction::execute`] by executing the entry +//! point `advance_counter` of the test contract. +//! +//! The benchmark function `cached_state_benchmark` measures the performance of +//! [`blockifier::state::cached_state::CachedState::add_visited_pcs`] method using a realistic size +//! of data. //! //! Run the benchmarks using `cargo bench --bench blockifier_bench`. +use std::time::Duration; + +use blockifier::context::BlockContext; +use blockifier::state::cached_state::{CachedState, TransactionalState}; +use blockifier::state::state_api::State; +use blockifier::state::visited_pcs::VisitedPcsSet; +use blockifier::test_utils::contracts::FeatureContract; +use blockifier::test_utils::dict_state_reader::DictStateReader; +use blockifier::test_utils::initial_test_state::test_state; use blockifier::test_utils::transfers_generator::{ RecipientGeneratorType, TransfersGenerator, TransfersGeneratorConfig, }; -use criterion::{criterion_group, criterion_main, Criterion}; +use blockifier::test_utils::{create_calldata, CairoVersion, BALANCE}; +use blockifier::transaction::account_transaction::AccountTransaction; +use blockifier::transaction::test_utils::{account_invoke_tx, block_context, max_resource_bounds}; +use blockifier::transaction::transactions::ExecutableTransaction; +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use starknet_api::core::ClassHash; +use starknet_api::hash::StarkHash; +use starknet_api::test_utils::NonceManager; +use starknet_api::{felt, invoke_tx_args}; pub fn transfers_benchmark(c: &mut Criterion) { let transfers_generator_config = TransfersGeneratorConfig { @@ -29,5 +54,201 @@ pub fn transfers_benchmark(c: &mut Criterion) { }); } -criterion_group!(benches, transfers_benchmark); +pub fn cached_state_benchmark(c: &mut Criterion) { + fn get_random_array(size: usize) -> Vec { + let mut vec: Vec = Vec::with_capacity(size); + for _ in 0..vec.capacity() { + vec.push(rand::random()); + } + vec + } + + fn create_class_hash(class_hash: &str) -> ClassHash { + ClassHash(StarkHash::from_hex_unchecked(class_hash)) + } + + // The state shared across all iterations. + let mut cached_state: CachedState = CachedState::default(); + + c.bench_function("cached_state", move |benchmark| { + benchmark.iter_batched( + || { + // This anonymous function creates the simulated visited program counters to add in + // `cached_state`. + // The numbers are taken from tx hash + // 0x0177C9365875CAA840EA8F03F97B0E3A8EE8851A8B952BF157B5DBD4FECCB060. This + // transaction has been chosen randomly, but it may not be representative of the + // average transaction on Starknet. + + let mut class_hashes = Vec::new(); + let mut random_arrays = Vec::new(); + + let class_hash = create_class_hash("a"); + let random_array = get_random_array(11393); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("a"); + let random_array = get_random_array(453); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("a"); + let random_array = get_random_array(604); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("a"); + let random_array = get_random_array(806); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("b"); + let random_array = get_random_array(1327); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("b"); + let random_array = get_random_array(1135); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("b"); + let random_array = get_random_array(213); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("b"); + let random_array = get_random_array(135); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("c"); + let random_array = get_random_array(348); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("c"); + let random_array = get_random_array(88); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("c"); + let random_array = get_random_array(348); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("c"); + let random_array = get_random_array(348); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("d"); + let random_array = get_random_array(875); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("d"); + let random_array = get_random_array(450); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("d"); + let random_array = get_random_array(255); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("d"); + let random_array = get_random_array(210); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("d"); + let random_array = get_random_array(1403); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("d"); + let random_array = get_random_array(210); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("d"); + let random_array = get_random_array(210); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("e"); + let random_array = get_random_array(2386); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + let class_hash = create_class_hash("e"); + let random_array = get_random_array(3602); + class_hashes.push(class_hash); + random_arrays.push(random_array); + + (class_hashes, random_arrays) + }, + |input_data| { + let mut transactional_state = + TransactionalState::create_transactional(&mut cached_state); + for (class_hash, random_array) in input_data.0.into_iter().zip(input_data.1) { + transactional_state.add_visited_pcs(class_hash, &random_array); + } + transactional_state.commit(); + }, + BatchSize::SmallInput, + ) + }); +} + +pub fn execution_benchmark(c: &mut Criterion) { + /// This function sets up and returns all the objects required to execute an invoke transaction. + fn prepare_account_tx() + -> (AccountTransaction, CachedState, BlockContext) { + let block_context = block_context(); + let max_resource_bounds = max_resource_bounds(); + let cairo_version = CairoVersion::Cairo1; + let account = FeatureContract::AccountWithoutValidations(cairo_version); + let test_contract = FeatureContract::TestContract(cairo_version); + let state = + test_state(block_context.chain_info(), BALANCE, &[(account, 1), (test_contract, 1)]); + let account_address = account.get_instance_address(0); + let contract_address = test_contract.get_instance_address(0); + let index = felt!(123_u32); + let base_tx_args = invoke_tx_args! { + resource_bounds: max_resource_bounds, + sender_address: account_address, + }; + + let mut nonce_manager = NonceManager::default(); + let counter_diffs = [101_u32, 102_u32]; + let initial_counters = [felt!(counter_diffs[0]), felt!(counter_diffs[1])]; + let calldata_args = vec![index, initial_counters[0], initial_counters[1]]; + + let account_tx = account_invoke_tx(invoke_tx_args! { + nonce: nonce_manager.next(account_address), + calldata: + create_calldata(contract_address, "advance_counter", &calldata_args), + ..base_tx_args + }); + (account_tx, state, block_context) + } + c.bench_function("execution", move |benchmark| { + benchmark.iter_batched( + prepare_account_tx, + |(account_tx, mut state, block_context)| { + account_tx.execute(&mut state, &block_context, true, true).unwrap() + }, + BatchSize::SmallInput, + ) + }); +} + +criterion_group! { + name = benches; + config = Criterion::default().measurement_time(Duration::from_secs(20)); + targets = transfers_benchmark, execution_benchmark, cached_state_benchmark +} criterion_main!(benches); diff --git a/crates/blockifier/src/blockifier/stateful_validator.rs b/crates/blockifier/src/blockifier/stateful_validator.rs index dd515fb591..7ed3dd1cb7 100644 --- a/crates/blockifier/src/blockifier/stateful_validator.rs +++ b/crates/blockifier/src/blockifier/stateful_validator.rs @@ -17,6 +17,7 @@ use crate::fee::receipt::TransactionReceipt; use crate::state::cached_state::CachedState; use crate::state::errors::StateError; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::VisitedPcs; use crate::transaction::account_transaction::AccountTransaction; use crate::transaction::errors::{TransactionExecutionError, TransactionPreValidationError}; use crate::transaction::transaction_execution::Transaction; @@ -41,12 +42,12 @@ pub enum StatefulValidatorError { pub type StatefulValidatorResult = Result; /// Manages state related transaction validations for pre-execution flows. -pub struct StatefulValidator { - tx_executor: TransactionExecutor, +pub struct StatefulValidator { + tx_executor: TransactionExecutor, } -impl StatefulValidator { - pub fn create(state: CachedState, block_context: BlockContext) -> Self { +impl StatefulValidator { + pub fn create(state: CachedState, block_context: BlockContext) -> Self { let tx_executor = TransactionExecutor::new(state, block_context, TransactionExecutorConfig::default()); Self { tx_executor } diff --git a/crates/blockifier/src/blockifier/transaction_executor.rs b/crates/blockifier/src/blockifier/transaction_executor.rs index 14c13ed932..37df27a5e3 100644 --- a/crates/blockifier/src/blockifier/transaction_executor.rs +++ b/crates/blockifier/src/blockifier/transaction_executor.rs @@ -1,6 +1,4 @@ #[cfg(feature = "concurrency")] -use std::collections::{HashMap, HashSet}; -#[cfg(feature = "concurrency")] use std::panic::{self, catch_unwind, AssertUnwindSafe}; #[cfg(feature = "concurrency")] use std::sync::Arc; @@ -20,6 +18,7 @@ use crate::context::BlockContext; use crate::state::cached_state::{CachedState, CommitmentStateDiff, TransactionalState}; use crate::state::errors::StateError; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::VisitedPcs; use crate::transaction::errors::TransactionExecutionError; use crate::transaction::objects::TransactionExecutionInfo; use crate::transaction::transaction_execution::Transaction; @@ -45,7 +44,7 @@ pub type TransactionExecutorResult = Result; pub type VisitedSegmentsMapping = Vec<(ClassHash, Vec)>; // TODO(Gilad): make this hold TransactionContext instead of BlockContext. -pub struct TransactionExecutor { +pub struct TransactionExecutor { pub block_context: BlockContext, pub bouncer: Bouncer, // Note: this config must not affect the execution result (e.g. state diff and traces). @@ -56,12 +55,12 @@ pub struct TransactionExecutor { // block state to the worker executor - operating at the chunk level - and gets it back after // committing the chunk. The block state is wrapped with an Option<_> to allow setting it to // `None` while it is moved to the worker executor. - pub block_state: Option>, + pub block_state: Option>, } -impl TransactionExecutor { +impl TransactionExecutor { pub fn new( - block_state: CachedState, + block_state: CachedState, block_context: BlockContext, config: TransactionExecutorConfig, ) -> Self { @@ -159,7 +158,8 @@ impl TransactionExecutor { .as_ref() .expect(BLOCK_STATE_ACCESS_ERR) .get_compiled_contract_class(*class_hash)?; - Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?)) + let class_visited_pcs = V::to_set(class_visited_pcs.clone()); + Ok((*class_hash, contract_class.get_visited_segments(&class_visited_pcs)?)) }) .collect::>()?; @@ -172,7 +172,11 @@ impl TransactionExecutor { } } -impl TransactionExecutor { +impl TransactionExecutor +where + S: StateReader + Send + Sync, + V: VisitedPcs + Send + Sync, +{ /// Executes the given transactions on the state maintained by the executor. /// Stops if and when there is no more room in the block, and returns the executed transactions' /// results. @@ -221,6 +225,7 @@ impl TransactionExecutor { chunk: &[Transaction], ) -> Vec> { use crate::concurrency::utils::AbortIfPanic; + use crate::concurrency::worker_logic::ExecutionTaskOutput; let block_state = self.block_state.take().expect("The block state should be `Some`."); @@ -264,20 +269,20 @@ impl TransactionExecutor { let n_committed_txs = worker_executor.scheduler.get_n_committed_txs(); let mut tx_execution_results = Vec::new(); - let mut visited_pcs: HashMap> = HashMap::new(); + let mut visited_pcs: V = V::new(); for execution_output in worker_executor.execution_outputs.iter() { if tx_execution_results.len() >= n_committed_txs { break; } - let locked_execution_output = execution_output + let locked_execution_output: ExecutionTaskOutput = execution_output .lock() .expect("Failed to lock execution output.") .take() .expect("Output must be ready."); tx_execution_results .push(locked_execution_output.result.map_err(TransactionExecutorError::from)); - for (class_hash, class_visited_pcs) in locked_execution_output.visited_pcs { - visited_pcs.entry(class_hash).or_default().extend(class_visited_pcs); + for (class_hash, class_visited_pcs) in locked_execution_output.visited_pcs.iter() { + visited_pcs.extend(class_hash, class_visited_pcs); } } diff --git a/crates/blockifier/src/blockifier/transaction_executor_test.rs b/crates/blockifier/src/blockifier/transaction_executor_test.rs index 7be2bf7ab4..0c403c6a81 100644 --- a/crates/blockifier/src/blockifier/transaction_executor_test.rs +++ b/crates/blockifier/src/blockifier/transaction_executor_test.rs @@ -16,6 +16,7 @@ use crate::bouncer::{Bouncer, BouncerWeights}; use crate::context::BlockContext; use crate::state::cached_state::CachedState; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::VisitedPcs; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; use crate::test_utils::deploy_account::deploy_account_tx; @@ -35,8 +36,8 @@ use crate::transaction::test_utils::{ use crate::transaction::transaction_execution::Transaction; use crate::transaction::transactions::L1HandlerTransaction; -fn tx_executor_test_body( - state: CachedState, +fn tx_executor_test_body( + state: CachedState, block_context: BlockContext, tx: Transaction, expected_bouncer_weights: BouncerWeights, diff --git a/crates/blockifier/src/bouncer_test.rs b/crates/blockifier/src/bouncer_test.rs index 726c0fe533..cff045ddb0 100644 --- a/crates/blockifier/src/bouncer_test.rs +++ b/crates/blockifier/src/bouncer_test.rs @@ -13,7 +13,7 @@ use crate::blockifier::transaction_executor::{ use crate::bouncer::{verify_tx_weights_in_bounds, Bouncer, BouncerWeights, BuiltinCount}; use crate::context::BlockContext; use crate::execution::call_info::ExecutionSummary; -use crate::state::cached_state::{StateChangesKeys, TransactionalState}; +use crate::state::cached_state::{StateChangesKeys}; use crate::test_utils::initial_test_state::test_state; use crate::transaction::errors::TransactionExecutionError; @@ -184,10 +184,11 @@ fn test_bouncer_try_update( ) { use cairo_vm::vm::runners::cairo_runner::ExecutionResources; + use crate::state::cached_state::TransactionalState; use crate::transaction::objects::TransactionResources; let state = &mut test_state(&BlockContext::create_for_account_testing().chain_info, 0, &[]); - let mut transactional_state = TransactionalState::create_transactional(state); + let mut transactional_state = TransactionalState::create_transactional_for_testing(state); // Setup the bouncer. let block_max_capacity = BouncerWeights { diff --git a/crates/blockifier/src/concurrency/fee_utils.rs b/crates/blockifier/src/concurrency/fee_utils.rs index e40f492319..61a436ffdc 100644 --- a/crates/blockifier/src/concurrency/fee_utils.rs +++ b/crates/blockifier/src/concurrency/fee_utils.rs @@ -10,6 +10,7 @@ use crate::execution::call_info::CallInfo; use crate::fee::fee_utils::get_sequencer_balance_keys; use crate::state::cached_state::{ContractClassMapping, StateMaps}; use crate::state::state_api::UpdatableState; +use crate::state::visited_pcs::VisitedPcs; use crate::transaction::objects::TransactionExecutionInfo; #[cfg(test)] @@ -22,10 +23,10 @@ mod test; pub(crate) const STORAGE_READ_SEQUENCER_BALANCE_INDICES: (usize, usize) = (2, 3); // Completes the fee transfer flow if needed (if the transfer was made in concurrent mode). -pub fn complete_fee_transfer_flow( +pub fn complete_fee_transfer_flow>( tx_context: &TransactionContext, tx_execution_info: &mut TransactionExecutionInfo, - state: &mut impl UpdatableState, + state: &mut U, ) { if tx_context.is_sequencer_the_sender() { // When the sequencer is the sender, we use the sequential (full) fee transfer. @@ -93,9 +94,9 @@ pub fn fill_sequencer_balance_reads( storage_read_values[high_index] = high; } -pub fn add_fee_to_sequencer_balance( +pub fn add_fee_to_sequencer_balance>( fee_token_address: ContractAddress, - state: &mut impl UpdatableState, + state: &mut U, actual_fee: Fee, block_context: &BlockContext, sequencer_balance: (Felt, Felt), @@ -120,5 +121,5 @@ pub fn add_fee_to_sequencer_balance( ]), ..StateMaps::default() }; - state.apply_writes(&writes, &ContractClassMapping::default(), &HashMap::default()); + state.apply_writes(&writes, &ContractClassMapping::default(), &V::default()); } diff --git a/crates/blockifier/src/concurrency/flow_test.rs b/crates/blockifier/src/concurrency/flow_test.rs index 684dbd49a5..d570e92cd6 100644 --- a/crates/blockifier/src/concurrency/flow_test.rs +++ b/crates/blockifier/src/concurrency/flow_test.rs @@ -12,6 +12,7 @@ use crate::concurrency::test_utils::{safe_versioned_state_for_testing, DEFAULT_C use crate::concurrency::versioned_state::ThreadSafeVersionedState; use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps}; use crate::state::state_api::UpdatableState; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::dict_state_reader::DictStateReader; const CONTRACT_ADDRESS: &str = "0x18031991"; @@ -26,6 +27,8 @@ fn scheduler_flow_test( // transactions with multiple threads, where every transaction depends on its predecessor. Each // transaction sequentially advances a counter by reading the previous value and bumping it by // 1. + + use crate::state::visited_pcs::VisitedPcsSet; let scheduler = Arc::new(Scheduler::new(DEFAULT_CHUNK_SIZE)); let versioned_state = safe_versioned_state_for_testing(CachedState::from(DictStateReader::default())); @@ -52,7 +55,7 @@ fn scheduler_flow_test( state_proxy.apply_writes( &new_writes, &ContractClassMapping::default(), - &HashMap::default(), + &VisitedPcsSet::default(), ); scheduler.finish_execution_during_commit(tx_index); } @@ -65,13 +68,13 @@ fn scheduler_flow_test( versioned_state.pin_version(tx_index).apply_writes( &writes, &ContractClassMapping::default(), - &HashMap::default(), + &VisitedPcsSet::default(), ); scheduler.finish_execution(tx_index); Task::AskForTask } Task::ValidationTask(tx_index) => { - let state_proxy = versioned_state.pin_version(tx_index); + let state_proxy = versioned_state.pin_version_for_testing(tx_index); let (reads, writes) = get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state); let read_set_valid = state_proxy.validate_reads(&reads); @@ -119,7 +122,7 @@ fn scheduler_flow_test( fn get_reads_writes_for( task: Task, - versioned_state: &ThreadSafeVersionedState>, + versioned_state: &ThreadSafeVersionedState>, ) -> (StateMaps, StateMaps) { match task { Task::ExecutionTask(tx_index) => { @@ -130,7 +133,7 @@ fn get_reads_writes_for( state_maps_with_single_storage_entry(1), ); } - _ => versioned_state.pin_version(tx_index - 1), + _ => versioned_state.pin_version_for_testing(tx_index - 1), }; let tx_written_value = SierraU128::from_storage( &state_proxy, @@ -145,7 +148,7 @@ fn get_reads_writes_for( ) } Task::ValidationTask(tx_index) => { - let state_proxy = versioned_state.pin_version(tx_index); + let state_proxy = versioned_state.pin_version_for_testing(tx_index); let tx_written_value = SierraU128::from_storage( &state_proxy, &contract_address!(CONTRACT_ADDRESS), diff --git a/crates/blockifier/src/concurrency/test_utils.rs b/crates/blockifier/src/concurrency/test_utils.rs index 87722b1171..c75ee7ca7a 100644 --- a/crates/blockifier/src/concurrency/test_utils.rs +++ b/crates/blockifier/src/concurrency/test_utils.rs @@ -7,6 +7,7 @@ use crate::context::BlockContext; use crate::execution::call_info::CallInfo; use crate::state::cached_state::{CachedState, TransactionalState}; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::{VisitedPcs, VisitedPcsSet}; use crate::test_utils::dict_state_reader::DictStateReader; use crate::transaction::account_transaction::AccountTransaction; use crate::transaction::transactions::{ExecutableTransaction, ExecutionFlags}; @@ -61,16 +62,16 @@ macro_rules! default_scheduler { // TODO(meshi, 01/06/2024): Consider making this a macro. pub fn safe_versioned_state_for_testing( - block_state: CachedState, -) -> ThreadSafeVersionedState> { + block_state: CachedState, +) -> ThreadSafeVersionedState> { ThreadSafeVersionedState::new(VersionedState::new(block_state)) } // Utils. // Note: this function does not mutate the state. -pub fn create_fee_transfer_call_info( - state: &mut CachedState, +pub fn create_fee_transfer_call_info( + state: &mut CachedState, account_tx: &AccountTransaction, concurrency_mode: bool, ) -> CallInfo { diff --git a/crates/blockifier/src/concurrency/versioned_state.rs b/crates/blockifier/src/concurrency/versioned_state.rs index a6edb590ee..4a80f76938 100644 --- a/crates/blockifier/src/concurrency/versioned_state.rs +++ b/crates/blockifier/src/concurrency/versioned_state.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::marker::PhantomData; use std::sync::{Arc, Mutex, MutexGuard}; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; @@ -11,6 +11,7 @@ use crate::execution::contract_class::ContractClass; use crate::state::cached_state::{ContractClassMapping, StateMaps}; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult, UpdatableState}; +use crate::state::visited_pcs::VisitedPcs; #[cfg(test)] #[path = "versioned_state_test.rs"] @@ -198,11 +199,11 @@ impl VersionedState { } } -impl VersionedState { +impl> VersionedState { pub fn commit_chunk_and_recover_block_state( mut self, n_committed_txs: usize, - visited_pcs: HashMap>, + visited_pcs: V, ) -> U { if n_committed_txs == 0 { return self.into_initial_state(); @@ -229,8 +230,16 @@ impl ThreadSafeVersionedState { ThreadSafeVersionedState(Arc::new(Mutex::new(versioned_state))) } - pub fn pin_version(&self, tx_index: TxIndex) -> VersionedStateProxy { - VersionedStateProxy { tx_index, state: self.0.clone() } + pub fn pin_version(&self, tx_index: TxIndex) -> VersionedStateProxy { + VersionedStateProxy { tx_index, state: self.0.clone(), _marker: PhantomData } + } + + #[cfg(test)] + pub fn pin_version_for_testing( + &self, + tx_index: TxIndex, + ) -> VersionedStateProxy { + VersionedStateProxy { tx_index, state: self.0.clone(), _marker: PhantomData } } pub fn into_inner_state(self) -> VersionedState { @@ -252,12 +261,13 @@ impl Clone for ThreadSafeVersionedState { } } -pub struct VersionedStateProxy { +pub struct VersionedStateProxy { pub tx_index: TxIndex, pub state: Arc>>, + _marker: PhantomData, } -impl VersionedStateProxy { +impl VersionedStateProxy { fn state(&self) -> LockedVersionedState<'_, S> { self.state.lock().expect("Failed to acquire state lock.") } @@ -272,18 +282,20 @@ impl VersionedStateProxy { } // TODO(Noa, 15/5/24): Consider using visited_pcs. -impl UpdatableState for VersionedStateProxy { +impl UpdatableState for VersionedStateProxy { + type Pcs = V; + fn apply_writes( &mut self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping, - _visited_pcs: &HashMap>, + _visited_pcs: &V, ) { self.state().apply_writes(self.tx_index, writes, class_hash_to_class) } } -impl StateReader for VersionedStateProxy { +impl StateReader for VersionedStateProxy { fn get_storage_at( &self, contract_address: ContractAddress, diff --git a/crates/blockifier/src/concurrency/versioned_state_test.rs b/crates/blockifier/src/concurrency/versioned_state_test.rs index 74f4b7e259..697d8b9391 100644 --- a/crates/blockifier/src/concurrency/versioned_state_test.rs +++ b/crates/blockifier/src/concurrency/versioned_state_test.rs @@ -31,11 +31,7 @@ use crate::concurrency::test_utils::{ contract_address, safe_versioned_state_for_testing, }; -use crate::concurrency::versioned_state::{ - ThreadSafeVersionedState, - VersionedState, - VersionedStateProxy, -}; +use crate::concurrency::versioned_state::{ThreadSafeVersionedState, VersionedState}; use crate::concurrency::TxIndex; use crate::context::BlockContext; use crate::state::cached_state::{ @@ -46,6 +42,7 @@ use crate::state::cached_state::{ }; use crate::state::errors::StateError; use crate::state::state_api::{State, StateReader, UpdatableState}; +use crate::state::visited_pcs::{VisitedPcs, VisitedPcsSet}; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::deploy_account::deploy_account_tx; use crate::test_utils::dict_state_reader::DictStateReader; @@ -60,7 +57,7 @@ use crate::transaction::transactions::ExecutableTransaction; pub fn safe_versioned_state( contract_address: ContractAddress, class_hash: ClassHash, -) -> ThreadSafeVersionedState> { +) -> ThreadSafeVersionedState> { let init_state = DictStateReader { address_to_class_hash: HashMap::from([(contract_address, class_hash)]), ..Default::default() @@ -93,8 +90,8 @@ fn test_versioned_state_proxy() { let versioned_state = Arc::new(Mutex::new(VersionedState::new(cached_state))); let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state)); - let versioned_state_proxys: Vec>> = - (0..20).map(|i| safe_versioned_state.pin_version(i)).collect(); + let versioned_state_proxys: Vec<_> = + (0..20).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect(); // Read initial data assert_eq!(versioned_state_proxys[5].get_nonce_at(contract_address).unwrap(), nonce); @@ -229,10 +226,12 @@ fn test_run_parallel_txs(max_resource_bounds: ValidResourceBounds) { )))); let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state)); - let mut versioned_state_proxy_1 = safe_versioned_state.pin_version(1); - let mut state_1 = TransactionalState::create_transactional(&mut versioned_state_proxy_1); - let mut versioned_state_proxy_2 = safe_versioned_state.pin_version(2); - let mut state_2 = TransactionalState::create_transactional(&mut versioned_state_proxy_2); + let mut versioned_state_proxy_1 = safe_versioned_state.pin_version_for_testing(1); + let mut state_1 = + TransactionalState::create_transactional_for_testing(&mut versioned_state_proxy_1); + let mut versioned_state_proxy_2 = safe_versioned_state.pin_version_for_testing(2); + let mut state_2 = + TransactionalState::create_transactional_for_testing(&mut versioned_state_proxy_2); // Prepare transactions let deploy_account_tx_1 = deploy_account_tx( @@ -297,15 +296,16 @@ fn test_run_parallel_txs(max_resource_bounds: ValidResourceBounds) { fn test_validate_reads( contract_address: ContractAddress, class_hash: ClassHash, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let storage_key = storage_key!(0x10_u8); - let mut version_state_proxy = safe_versioned_state.pin_version(1); - let transactional_state = TransactionalState::create_transactional(&mut version_state_proxy); + let mut version_state_proxy = safe_versioned_state.pin_version_for_testing(1); + let transactional_state = + TransactionalState::create_transactional_for_testing(&mut version_state_proxy); // Validating tx index 0 always succeeds. - assert!(safe_versioned_state.pin_version(0).validate_reads(&StateMaps::default())); + assert!(safe_versioned_state.pin_version_for_testing(0).validate_reads(&StateMaps::default())); assert!(transactional_state.cache.borrow().initial_reads.storage.is_empty()); transactional_state.get_storage_at(contract_address, storage_key).unwrap(); @@ -334,7 +334,7 @@ fn test_validate_reads( assert!( safe_versioned_state - .pin_version(1) + .pin_version_for_testing(1) .validate_reads(&transactional_state.cache.borrow().initial_reads) ); } @@ -387,16 +387,16 @@ fn test_validate_reads( fn test_false_validate_reads( #[case] tx_1_reads: StateMaps, #[case] tx_0_writes: StateMaps, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { - let version_state_proxy = safe_versioned_state.pin_version(0); + let version_state_proxy = safe_versioned_state.pin_version_for_testing(0); version_state_proxy.state().apply_writes(0, &tx_0_writes, &HashMap::default()); - assert!(!safe_versioned_state.pin_version(1).validate_reads(&tx_1_reads)); + assert!(!safe_versioned_state.pin_version_for_testing(1).validate_reads(&tx_1_reads)); } #[rstest] fn test_false_validate_reads_declared_contracts( - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let tx_1_reads = StateMaps { declared_contracts: HashMap::from([(class_hash!(1_u8), false)]), @@ -406,24 +406,25 @@ fn test_false_validate_reads_declared_contracts( declared_contracts: HashMap::from([(class_hash!(1_u8), true)]), ..Default::default() }; - let version_state_proxy = safe_versioned_state.pin_version(0); + let version_state_proxy = safe_versioned_state.pin_version_for_testing(0); let compiled_contract_calss = FeatureContract::TestContract(CairoVersion::Cairo1).get_class(); let class_hash_to_class = HashMap::from([(class_hash!(1_u8), compiled_contract_calss)]); version_state_proxy.state().apply_writes(0, &tx_0_writes, &class_hash_to_class); - assert!(!safe_versioned_state.pin_version(1).validate_reads(&tx_1_reads)); + assert!(!safe_versioned_state.pin_version_for_testing(1).validate_reads(&tx_1_reads)); } #[rstest] fn test_apply_writes( contract_address: ContractAddress, class_hash: ClassHash, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { - let mut versioned_proxy_states: Vec>> = - (0..2).map(|i| safe_versioned_state.pin_version(i)).collect(); - let mut transactional_states: Vec< - TransactionalState<'_, VersionedStateProxy>>, - > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); + let mut versioned_proxy_states: Vec<_> = + (0..2).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect(); + let mut transactional_states: Vec<_> = versioned_proxy_states + .iter_mut() + .map(TransactionalState::create_transactional_for_testing) + .collect(); // Transaction 0 class hash. let class_hash_0 = class_hash!(76_u8); @@ -437,10 +438,10 @@ fn test_apply_writes( transactional_states[0].set_contract_class(class_hash, contract_class_0.clone()).unwrap(); assert_eq!(transactional_states[0].class_hash_to_class.borrow().len(), 1); - safe_versioned_state.pin_version(0).apply_writes( + safe_versioned_state.pin_version_for_testing(0).apply_writes( &transactional_states[0].cache.borrow().writes, &transactional_states[0].class_hash_to_class.borrow().clone(), - &HashMap::default(), + &VisitedPcsSet::default(), ); assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0); assert!( @@ -453,13 +454,14 @@ fn test_apply_writes( fn test_apply_writes_reexecute_scenario( contract_address: ContractAddress, class_hash: ClassHash, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { - let mut versioned_proxy_states: Vec>> = - (0..2).map(|i| safe_versioned_state.pin_version(i)).collect(); - let mut transactional_states: Vec< - TransactionalState<'_, VersionedStateProxy>>, - > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); + let mut versioned_proxy_states: Vec<_> = + (0..2).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect(); + let mut transactional_states: Vec<_> = versioned_proxy_states + .iter_mut() + .map(TransactionalState::create_transactional_for_testing) + .collect(); // Transaction 0 class hash. let class_hash_0 = class_hash!(76_u8); @@ -469,10 +471,10 @@ fn test_apply_writes_reexecute_scenario( // updated. assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash); - safe_versioned_state.pin_version(0).apply_writes( + safe_versioned_state.pin_version_for_testing(0).apply_writes( &transactional_states[0].cache.borrow().writes, &transactional_states[0].class_hash_to_class.borrow().clone(), - &HashMap::default(), + &VisitedPcsSet::default(), ); // Although transaction 0 wrote to the shared state, version 1 needs to be re-executed to see // the new value (its read value has already been cached). @@ -480,7 +482,7 @@ fn test_apply_writes_reexecute_scenario( // TODO: Use re-execution native util once it's ready. // "Re-execute" the transaction. - let mut versioned_state_proxy = safe_versioned_state.pin_version(1); + let mut versioned_state_proxy = safe_versioned_state.pin_version_for_testing(1); transactional_states[1] = TransactionalState::create_transactional(&mut versioned_state_proxy); // The class hash should be updated. assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0); @@ -489,14 +491,15 @@ fn test_apply_writes_reexecute_scenario( #[rstest] fn test_delete_writes( #[values(0, 1, 2)] tx_index_to_delete_writes: TxIndex, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let num_of_txs = 3; - let mut versioned_proxy_states: Vec>> = - (0..num_of_txs).map(|i| safe_versioned_state.pin_version(i)).collect(); - let mut transactional_states: Vec< - TransactionalState<'_, VersionedStateProxy>>, - > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); + let mut versioned_proxy_states: Vec<_> = + (0..num_of_txs).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect(); + let mut transactional_states: Vec<_> = versioned_proxy_states + .iter_mut() + .map(TransactionalState::create_transactional_for_testing) + .collect(); // Setting 2 instances of the contract to ensure `delete_writes` removes information from // multiple keys. Class hash values are not checked in this test. @@ -514,14 +517,14 @@ fn test_delete_writes( tx_state .set_contract_class(feature_contract.get_class_hash(), feature_contract.get_class()) .unwrap(); - safe_versioned_state.pin_version(i).apply_writes( + safe_versioned_state.pin_version_for_testing(i).apply_writes( &tx_state.cache.borrow().writes, &tx_state.class_hash_to_class.borrow(), - &HashMap::default(), + &VisitedPcsSet::default(), ); } - safe_versioned_state.pin_version(tx_index_to_delete_writes).delete_writes( + safe_versioned_state.pin_version_for_testing(tx_index_to_delete_writes).delete_writes( &transactional_states[tx_index_to_delete_writes].cache.borrow().writes, &transactional_states[tx_index_to_delete_writes].class_hash_to_class.borrow(), ); @@ -554,7 +557,7 @@ fn test_delete_writes( #[rstest] fn test_delete_writes_completeness( - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let feature_contract = FeatureContract::TestContract(CairoVersion::Cairo1); let state_maps_writes = StateMaps { @@ -574,12 +577,12 @@ fn test_delete_writes_completeness( HashMap::from([(feature_contract.get_class_hash(), feature_contract.get_class())]); let tx_index = 0; - let mut versioned_state_proxy = safe_versioned_state.pin_version(tx_index); + let mut versioned_state_proxy = safe_versioned_state.pin_version_for_testing(tx_index); versioned_state_proxy.apply_writes( &state_maps_writes, &class_hash_to_class_writes, - &HashMap::default(), + &VisitedPcsSet::default(), ); assert_eq!( safe_versioned_state.0.lock().unwrap().get_writes_of_index(tx_index), @@ -613,17 +616,18 @@ fn test_delete_writes_completeness( #[rstest] fn test_versioned_proxy_state_flow( - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let contract_address = contract_address!("0x1"); let class_hash = ClassHash(felt!(27_u8)); - let mut versioned_proxy_states: Vec>> = - (0..4).map(|i| safe_versioned_state.pin_version(i)).collect(); + let mut versioned_proxy_states: Vec<_> = + (0..4).map(|i| safe_versioned_state.pin_version_for_testing(i)).collect(); let mut transactional_states = Vec::with_capacity(4); for proxy_state in &mut versioned_proxy_states { - transactional_states.push(TransactionalState::create_transactional(proxy_state)); + transactional_states + .push(TransactionalState::create_transactional_for_testing(proxy_state)); } // Clients class hash values. @@ -656,7 +660,7 @@ fn test_versioned_proxy_state_flow( } let modified_block_state = safe_versioned_state .into_inner_state() - .commit_chunk_and_recover_block_state(4, HashMap::new()); + .commit_chunk_and_recover_block_state(4, VisitedPcsSet::new()); assert!(modified_block_state.get_class_hash_at(contract_address).unwrap() == class_hash_3); assert!( diff --git a/crates/blockifier/src/concurrency/worker_logic.rs b/crates/blockifier/src/concurrency/worker_logic.rs index e84960f350..854df1ce4b 100644 --- a/crates/blockifier/src/concurrency/worker_logic.rs +++ b/crates/blockifier/src/concurrency/worker_logic.rs @@ -1,10 +1,9 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::fmt::Debug; use std::sync::Mutex; use std::thread; use std::time::Duration; -use starknet_api::core::ClassHash; use super::versioned_state::VersionedState; use crate::blockifier::transaction_executor::TransactionExecutorError; @@ -22,6 +21,7 @@ use crate::state::cached_state::{ TransactionalState, }; use crate::state::state_api::{StateReader, UpdatableState}; +use crate::state::visited_pcs::{VisitedPcs, VisitedPcsSet}; use crate::transaction::objects::{TransactionExecutionInfo, TransactionExecutionResult}; use crate::transaction::transaction_execution::Transaction; use crate::transaction::transactions::{ExecutableTransaction, ExecutionFlags}; @@ -33,23 +33,34 @@ pub mod test; const EXECUTION_OUTPUTS_UNWRAP_ERROR: &str = "Execution task outputs should not be None."; #[derive(Debug)] -pub struct ExecutionTaskOutput { +pub struct ExecutionTaskOutput { pub reads: StateMaps, pub writes: StateMaps, pub contract_classes: ContractClassMapping, - pub visited_pcs: HashMap>, + pub visited_pcs: V, pub result: TransactionExecutionResult, } -pub struct WorkerExecutor<'a, S: StateReader> { +pub struct WorkerExecutor<'a, S: StateReader, V: VisitedPcs> { pub scheduler: Scheduler, pub state: ThreadSafeVersionedState, pub chunk: &'a [Transaction], - pub execution_outputs: Box<[Mutex>]>, + pub execution_outputs: Box<[Mutex>>]>, pub block_context: &'a BlockContext, pub bouncer: Mutex<&'a mut Bouncer>, } -impl<'a, S: StateReader> WorkerExecutor<'a, S> { +impl<'a, S: StateReader> WorkerExecutor<'a, S, VisitedPcsSet> { + #[cfg(test)] + pub fn new_for_testing( + state: ThreadSafeVersionedState, + chunk: &'a [Transaction], + block_context: &'a BlockContext, + bouncer: Mutex<&'a mut Bouncer>, + ) -> WorkerExecutor<'a, S, VisitedPcsSet> { + WorkerExecutor::new(state, chunk, block_context, bouncer) + } +} +impl<'a, S: StateReader, V: VisitedPcs> WorkerExecutor<'a, S, V> { pub fn new( state: ThreadSafeVersionedState, chunk: &'a [Transaction], @@ -144,7 +155,7 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { &tx_reads_writes.writes, &class_hash_to_class, // The versioned state does not carry the visited PCs. - &HashMap::default(), + &V::default(), ); ExecutionTaskOutput { reads: tx_reads_writes.initial_reads, @@ -159,7 +170,7 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { // Failed transaction - ignore the writes and visited PCs. writes: StateMaps::default(), contract_classes: HashMap::default(), - visited_pcs: HashMap::default(), + visited_pcs: V::default(), result: execution_result, }, }; @@ -168,7 +179,7 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { } fn validate(&self, tx_index: TxIndex) -> Task { - let tx_versioned_state = self.state.pin_version(tx_index); + let tx_versioned_state = self.state.pin_version::(tx_index); let execution_output = lock_mutex_in_array(&self.execution_outputs, tx_index); let execution_output = execution_output.as_ref().expect(EXECUTION_OUTPUTS_UNWRAP_ERROR); let reads = &execution_output.reads; @@ -201,7 +212,7 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { let execution_output_ref = execution_output.as_ref().expect(EXECUTION_OUTPUTS_UNWRAP_ERROR); let reads = &execution_output_ref.reads; - let mut tx_versioned_state = self.state.pin_version(tx_index); + let mut tx_versioned_state = self.state.pin_version::(tx_index); let reads_valid = tx_versioned_state.validate_reads(reads); // First, re-validate the transaction. @@ -268,12 +279,8 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { } } -impl<'a, U: UpdatableState> WorkerExecutor<'a, U> { - pub fn commit_chunk_and_recover_block_state( - self, - n_committed_txs: usize, - visited_pcs: HashMap>, - ) -> U { +impl<'a, V: VisitedPcs, U: UpdatableState> WorkerExecutor<'a, U, V> { + pub fn commit_chunk_and_recover_block_state(self, n_committed_txs: usize, visited_pcs: V) -> U { self.state .into_inner_state() .commit_chunk_and_recover_block_state(n_committed_txs, visited_pcs) diff --git a/crates/blockifier/src/concurrency/worker_logic_test.rs b/crates/blockifier/src/concurrency/worker_logic_test.rs index b871c90f7a..e8b1996658 100644 --- a/crates/blockifier/src/concurrency/worker_logic_test.rs +++ b/crates/blockifier/src/concurrency/worker_logic_test.rs @@ -34,6 +34,7 @@ use crate::context::{BlockContext, TransactionContext}; use crate::fee::fee_utils::get_sequencer_balance_keys; use crate::state::cached_state::StateMaps; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; use crate::test_utils::initial_test_state::test_state; @@ -77,7 +78,7 @@ fn verify_sequencer_balance_update( // We assume the balance is at most 2^128, so the "low" value is sufficient. expected_sequencer_balance_low: u128, ) { - let tx_version_state = state.pin_version(tx_index); + let tx_version_state = state.pin_version_for_testing(tx_index); let (sequencer_balance_key_low, sequencer_balance_key_high) = get_sequencer_balance_keys(&tx_context.block_context); for (expected_balance, storage_key) in [ @@ -117,8 +118,12 @@ pub fn test_commit_tx() { let cached_state = test_state(&block_context.chain_info, BALANCE, &[(account, 1), (test_contract, 1)]); let versioned_state = safe_versioned_state_for_testing(cached_state); - let executor = - WorkerExecutor::new(versioned_state, &txs, &block_context, Mutex::new(&mut bouncer)); + let executor = WorkerExecutor::new_for_testing( + versioned_state, + &txs, + &block_context, + Mutex::new(&mut bouncer), + ); // Execute transactions. // Simulate a concurrent run by executing tx1 before tx0. @@ -217,14 +222,14 @@ fn test_commit_tx_when_sender_is_sequencer() { let state = test_state(&block_context.chain_info, BALANCE, &[(account, 1), (test_contract, 1)]); let versioned_state = safe_versioned_state_for_testing(state); - let executor = WorkerExecutor::new( + let executor = WorkerExecutor::new_for_testing( versioned_state, &sequencer_tx, &block_context, Mutex::new(&mut bouncer), ); let tx_index = 0; - let tx_versioned_state = executor.state.pin_version(tx_index); + let tx_versioned_state = executor.state.pin_version_for_testing(tx_index); // Execute and save the execution result. executor.execute_tx(tx_index); @@ -324,7 +329,7 @@ fn test_worker_execute(max_resource_bounds: ValidResourceBounds) { .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = WorkerExecutor::new( + let worker_executor = WorkerExecutor::new_for_testing( safe_versioned_state.clone(), &txs, &block_context, @@ -342,7 +347,7 @@ fn test_worker_execute(max_resource_bounds: ValidResourceBounds) { // Read a write made by the transaction. assert_eq!( safe_versioned_state - .pin_version(tx_index) + .pin_version_for_testing(tx_index) .get_storage_at(test_contract_address, storage_key) .unwrap(), storage_value @@ -395,14 +400,17 @@ fn test_worker_execute(max_resource_bounds: ValidResourceBounds) { assert_eq!(execution_output.writes, writes); assert_eq!(execution_output.reads, reads); - assert_ne!(execution_output.visited_pcs, HashMap::default()); + assert_ne!(execution_output.visited_pcs, VisitedPcsSet::default()); // Failed execution. let tx_index = 1; worker_executor.execute(tx_index); // No write was made by the transaction. assert_eq!( - safe_versioned_state.pin_version(tx_index).get_nonce_at(account_address).unwrap(), + safe_versioned_state + .pin_version_for_testing(tx_index) + .get_nonce_at(account_address) + .unwrap(), nonce!(1_u8) ); let execution_output = worker_executor.execution_outputs[tx_index].lock().unwrap(); @@ -414,21 +422,24 @@ fn test_worker_execute(max_resource_bounds: ValidResourceBounds) { }; assert_eq!(execution_output.reads, reads); assert_eq!(execution_output.writes, StateMaps::default()); - assert_eq!(execution_output.visited_pcs, HashMap::default()); + assert_eq!(execution_output.visited_pcs, VisitedPcsSet::default()); // Reverted execution. let tx_index = 2; worker_executor.execute(tx_index); // Read a write made by the transaction. assert_eq!( - safe_versioned_state.pin_version(tx_index).get_nonce_at(account_address).unwrap(), + safe_versioned_state + .pin_version_for_testing(tx_index) + .get_nonce_at(account_address) + .unwrap(), nonce!(2_u8) ); let execution_output = worker_executor.execution_outputs[tx_index].lock().unwrap(); let execution_output = execution_output.as_ref().unwrap(); assert!(execution_output.result.as_ref().unwrap().is_reverted()); assert_ne!(execution_output.writes, StateMaps::default()); - assert_ne!(execution_output.visited_pcs, HashMap::default()); + assert_ne!(execution_output.visited_pcs, VisitedPcsSet::default()); // Validate status change. for tx_index in 0..3 { @@ -486,7 +497,7 @@ fn test_worker_validate(max_resource_bounds: ValidResourceBounds) { .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = WorkerExecutor::new( + let worker_executor = WorkerExecutor::new_for_testing( safe_versioned_state.clone(), &txs, &block_context, @@ -512,7 +523,7 @@ fn test_worker_validate(max_resource_bounds: ValidResourceBounds) { // Verify writes exist in state. assert_eq!( safe_versioned_state - .pin_version(tx_index) + .pin_version_for_testing(tx_index) .get_storage_at(test_contract_address, storage_key) .unwrap(), storage_value0 @@ -527,7 +538,7 @@ fn test_worker_validate(max_resource_bounds: ValidResourceBounds) { // Verify writes were removed. assert_eq!( safe_versioned_state - .pin_version(tx_index) + .pin_version_for_testing(tx_index) .get_storage_at(test_contract_address, storage_key) .unwrap(), storage_value0 @@ -599,8 +610,12 @@ fn test_deploy_before_declare( .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = - WorkerExecutor::new(safe_versioned_state, &txs, &block_context, Mutex::new(&mut bouncer)); + let worker_executor = WorkerExecutor::new_for_testing( + safe_versioned_state, + &txs, + &block_context, + Mutex::new(&mut bouncer), + ); // Creates 2 active tasks. worker_executor.scheduler.next_task(); @@ -671,8 +686,12 @@ fn test_worker_commit_phase(max_resource_bounds: ValidResourceBounds) { .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = - WorkerExecutor::new(safe_versioned_state, &txs, &block_context, Mutex::new(&mut bouncer)); + let worker_executor = WorkerExecutor::new_for_testing( + safe_versioned_state, + &txs, + &block_context, + Mutex::new(&mut bouncer), + ); // Try to commit before any transaction is ready. worker_executor.commit_while_possible(); @@ -761,8 +780,12 @@ fn test_worker_commit_phase_with_halt() { .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = - WorkerExecutor::new(safe_versioned_state, &txs, &block_context, Mutex::new(&mut bouncer)); + let worker_executor = WorkerExecutor::new_for_testing( + safe_versioned_state, + &txs, + &block_context, + Mutex::new(&mut bouncer), + ); // Creates 2 active tasks. // Creating these tasks changes the status of both transactions to `Executing`. If we skip this diff --git a/crates/blockifier/src/execution/contract_address_test.rs b/crates/blockifier/src/execution/contract_address_test.rs index 405360da1e..55b65cc3c5 100644 --- a/crates/blockifier/src/execution/contract_address_test.rs +++ b/crates/blockifier/src/execution/contract_address_test.rs @@ -9,6 +9,7 @@ use crate::execution::call_info::{CallExecution, Retdata}; use crate::execution::entry_point::CallEntryPoint; use crate::retdata; use crate::state::cached_state::CachedState; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -27,7 +28,7 @@ fn test_calculate_contract_address() { constructor_calldata: &Calldata, calldata: Calldata, deployer_address: ContractAddress, - state: &mut CachedState, + state: &mut CachedState, ) { let versioned_constants = VersionedConstants::create_for_testing(); let entry_point_call = CallEntryPoint { diff --git a/crates/blockifier/src/execution/entry_point_execution.rs b/crates/blockifier/src/execution/entry_point_execution.rs index 1d21241a6b..a871146cff 100644 --- a/crates/blockifier/src/execution/entry_point_execution.rs +++ b/crates/blockifier/src/execution/entry_point_execution.rs @@ -1,5 +1,3 @@ -use std::collections::HashSet; - use cairo_vm::types::builtin_name::BuiltinName; use cairo_vm::types::layout_name::LayoutName; use cairo_vm::types::relocatable::{MaybeRelocatable, Relocatable}; @@ -121,7 +119,7 @@ fn register_visited_pcs( program_segment_size: usize, bytecode_length: usize, ) -> EntryPointExecutionResult<()> { - let mut class_visited_pcs = HashSet::new(); + let mut class_visited_pcs = Vec::new(); // Relocate the trace, putting the program segment at address 1 and the execution segment right // after it. // TODO(lior): Avoid unnecessary relocation once the VM has a non-relocated `get_trace()` @@ -138,7 +136,7 @@ fn register_visited_pcs( // Jumping to a PC that is not inside the bytecode is possible. For example, to obtain // the builtin costs. Filter out these values. if real_pc < bytecode_length { - class_visited_pcs.insert(real_pc); + class_visited_pcs.push(real_pc); } } state.add_visited_pcs(class_hash, &class_visited_pcs); diff --git a/crates/blockifier/src/execution/entry_point_test.rs b/crates/blockifier/src/execution/entry_point_test.rs index ec4591be3f..fca6de54f4 100644 --- a/crates/blockifier/src/execution/entry_point_test.rs +++ b/crates/blockifier/src/execution/entry_point_test.rs @@ -13,6 +13,7 @@ use crate::execution::call_info::{CallExecution, CallInfo, Retdata}; use crate::execution::entry_point::CallEntryPoint; use crate::retdata; use crate::state::cached_state::CachedState; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -187,7 +188,7 @@ fn test_storage_var() { /// Runs test scenarios that could fail the OS run and therefore must be caught in the Blockifier. fn run_security_test( - state: &mut CachedState, + state: &mut CachedState, security_contract: FeatureContract, expected_error: &str, entry_point_name: &str, diff --git a/crates/blockifier/src/execution/syscalls/syscall_tests/get_block_hash.rs b/crates/blockifier/src/execution/syscalls/syscall_tests/get_block_hash.rs index cb3bdc85db..1a94662d64 100644 --- a/crates/blockifier/src/execution/syscalls/syscall_tests/get_block_hash.rs +++ b/crates/blockifier/src/execution/syscalls/syscall_tests/get_block_hash.rs @@ -13,6 +13,7 @@ use crate::execution::call_info::{CallExecution, Retdata}; use crate::execution::entry_point::CallEntryPoint; use crate::state::cached_state::CachedState; use crate::state::state_api::State; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -24,7 +25,9 @@ use crate::test_utils::{ }; use crate::{check_entry_point_execution_error_for_custom_hint, retdata}; -fn initialize_state(test_contract: FeatureContract) -> (CachedState, Felt, Felt) { +fn initialize_state( + test_contract: FeatureContract, +) -> (CachedState, Felt, Felt) { let chain_info = &ChainInfo::create_for_testing(); let mut state = test_state(chain_info, BALANCE, &[(test_contract, 1)]); diff --git a/crates/blockifier/src/state.rs b/crates/blockifier/src/state.rs index e027d2b301..3bef337429 100644 --- a/crates/blockifier/src/state.rs +++ b/crates/blockifier/src/state.rs @@ -4,3 +4,4 @@ pub mod error_format_test; pub mod errors; pub mod global_cache; pub mod state_api; +pub mod visited_pcs; diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index deff9cc9c7..da478ee7be 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -7,6 +7,7 @@ use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; +use super::visited_pcs::{VisitedPcs, VisitedPcsSet}; use crate::abi::abi_utils::get_fee_token_var_address; use crate::context::TransactionContext; use crate::execution::contract_class::ContractClass; @@ -26,23 +27,23 @@ pub type ContractClassMapping = HashMap; /// Writer functionality is builtin, whereas Reader functionality is injected through /// initialization. #[derive(Debug)] -pub struct CachedState { +pub struct CachedState { pub state: S, // Invariant: read/write access is managed by CachedState. // Using interior mutability to update caches during `State`'s immutable getters. pub(crate) cache: RefCell, pub(crate) class_hash_to_class: RefCell, /// A map from class hash to the set of PC values that were visited in the class. - pub visited_pcs: HashMap>, + pub visited_pcs: V, } -impl CachedState { +impl CachedState { pub fn new(state: S) -> Self { Self { state, cache: RefCell::new(StateCache::default()), class_hash_to_class: RefCell::new(HashMap::default()), - visited_pcs: HashMap::default(), + visited_pcs: V::default(), } } @@ -73,9 +74,9 @@ impl CachedState { self.class_hash_to_class.get_mut().extend(local_contract_cache_updates); } - pub fn update_visited_pcs_cache(&mut self, visited_pcs: &HashMap>) { - for (class_hash, class_visited_pcs) in visited_pcs { - self.add_visited_pcs(*class_hash, class_visited_pcs); + pub fn update_visited_pcs_cache(&mut self, visited_pcs: &V) { + for (class_hash, class_visited_pcs) in visited_pcs.iter() { + V::add_visited_pcs(self, class_hash, class_visited_pcs.clone()) } } @@ -107,12 +108,13 @@ impl CachedState { } } -impl UpdatableState for CachedState { +impl UpdatableState for CachedState { + type Pcs = V; fn apply_writes( &mut self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping, - visited_pcs: &HashMap>, + visited_pcs: &V, ) { // TODO(Noa,15/5/24): Reconsider the clone. self.update_cache(writes, class_hash_to_class.clone()); @@ -121,13 +123,13 @@ impl UpdatableState for CachedState { } #[cfg(any(feature = "testing", test))] -impl From for CachedState { +impl From for CachedState { fn from(state_reader: S) -> Self { CachedState::new(state_reader) } } -impl StateReader for CachedState { +impl StateReader for CachedState { fn get_storage_at( &self, contract_address: ContractAddress, @@ -222,7 +224,7 @@ impl StateReader for CachedState { } } -impl State for CachedState { +impl State for CachedState { fn set_storage_at( &mut self, contract_address: ContractAddress, @@ -275,13 +277,18 @@ impl State for CachedState { Ok(()) } - fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet) { - self.visited_pcs.entry(class_hash).or_default().extend(pcs); + fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &[usize]) { + self.visited_pcs.insert(&class_hash, pcs); } } #[cfg(any(feature = "testing", test))] -impl Default for CachedState { +impl Default + for CachedState< + crate::test_utils::dict_state_reader::DictStateReader, + super::visited_pcs::VisitedPcsSet, + > +{ fn default() -> Self { Self { state: Default::default(), @@ -504,14 +511,21 @@ impl<'a, S: StateReader + ?Sized> StateReader for MutRefState<'a, S> { } } -pub type TransactionalState<'a, U> = CachedState>; - -impl<'a, S: StateReader> TransactionalState<'a, S> { +pub type TransactionalState<'a, U, V> = CachedState, V>; +impl<'a, S: StateReader> TransactionalState<'a, S, VisitedPcsSet> { + #[cfg(test)] + pub fn create_transactional_for_testing( + state: &mut S, + ) -> TransactionalState<'_, S, VisitedPcsSet> { + TransactionalState::create_transactional(state) + } +} +impl<'a, S: StateReader, V: VisitedPcs> TransactionalState<'a, S, V> { /// Creates a transactional instance from the given updatable state. /// It allows performing buffered modifying actions on the given state, which /// will either all happen (will be updated in the state and committed) /// or none of them (will be discarded). - pub fn create_transactional(state: &mut S) -> TransactionalState<'_, S> { + pub fn create_transactional(state: &mut S) -> TransactionalState<'_, S, V> { CachedState::new(MutRefState::new(state)) } @@ -520,7 +534,7 @@ impl<'a, S: StateReader> TransactionalState<'a, S> { } /// Adds the ability to perform a transactional execution. -impl<'a, U: UpdatableState> TransactionalState<'a, U> { +impl<'a, V: VisitedPcs, U: UpdatableState> TransactionalState<'a, U, V> { /// Commits changes in the child (wrapping) state to its parent. pub fn commit(self) { let state = self.state.0; diff --git a/crates/blockifier/src/state/cached_state_test.rs b/crates/blockifier/src/state/cached_state_test.rs index d6c450d9f7..69f118a06a 100644 --- a/crates/blockifier/src/state/cached_state_test.rs +++ b/crates/blockifier/src/state/cached_state_test.rs @@ -17,6 +17,7 @@ use starknet_api::{ use crate::context::{BlockContext, ChainInfo}; use crate::state::cached_state::*; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -24,7 +25,7 @@ use crate::test_utils::CairoVersion; const CONTRACT_ADDRESS: &str = "0x100"; fn set_initial_state_values( - state: &mut CachedState, + state: &mut CachedState, class_hash_to_class: ContractClassMapping, nonce_initial_values: HashMap, class_hash_initial_values: HashMap, @@ -40,7 +41,7 @@ fn set_initial_state_values( #[test] fn get_uninitialized_storage_value() { - let state: CachedState = CachedState::default(); + let state: CachedState = CachedState::default(); let contract_address = contract_address!("0x1"); let key = storage_key!(0x10_u16); @@ -56,7 +57,8 @@ fn get_and_set_storage_value() { let storage_val0: Felt = felt!("0x1"); let storage_val1: Felt = felt!("0x5"); - let mut state = CachedState::from(DictStateReader { + let mut state: CachedState = + CachedState::from(DictStateReader { storage_view: HashMap::from([ ((contract_address0, key0), storage_val0), ((contract_address1, key1), storage_val1), @@ -105,7 +107,7 @@ fn cast_between_storage_mapping_types() { #[test] fn get_uninitialized_value() { - let state: CachedState = CachedState::default(); + let state: CachedState = CachedState::default(); let contract_address = contract_address!("0x1"); assert_eq!(state.get_nonce_at(contract_address).unwrap(), Nonce::default()); @@ -113,7 +115,8 @@ fn get_uninitialized_value() { #[test] fn declare_contract() { - let mut state = CachedState::from(DictStateReader { ..Default::default() }); + let mut state: CachedState = + CachedState::from(DictStateReader { ..Default::default() }); let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let class_hash = test_contract.get_class_hash(); let contract_class = test_contract.get_class(); @@ -142,7 +145,8 @@ fn get_and_increment_nonce() { let contract_address2 = contract_address!("0x200"); let initial_nonce = Nonce(felt!(1_u8)); - let mut state = CachedState::from(DictStateReader { + let mut state: CachedState = + CachedState::from(DictStateReader { address_to_nonce: HashMap::from([ (contract_address1, initial_nonce), (contract_address2, initial_nonce), @@ -188,7 +192,7 @@ fn get_contract_class() { #[test] fn get_uninitialized_class_hash_value() { - let state: CachedState = CachedState::default(); + let state: CachedState = CachedState::default(); let valid_contract_address = contract_address!("0x1"); assert_eq!(state.get_class_hash_at(valid_contract_address).unwrap(), ClassHash::default()); @@ -197,7 +201,7 @@ fn get_uninitialized_class_hash_value() { #[test] fn set_and_get_contract_hash() { let contract_address = contract_address!("0x1"); - let mut state: CachedState = CachedState::default(); + let mut state: CachedState = CachedState::default(); let class_hash = class_hash!("0x10"); assert!(state.set_class_hash_at(contract_address, class_hash).is_ok()); @@ -206,7 +210,7 @@ fn set_and_get_contract_hash() { #[test] fn cannot_set_class_hash_to_uninitialized_contract() { - let mut state: CachedState = CachedState::default(); + let mut state: CachedState = CachedState::default(); let uninitialized_contract_address = ContractAddress::default(); let class_hash = class_hash!("0x100"); @@ -296,8 +300,8 @@ fn cached_state_state_diff_conversion() { assert_eq!(expected_state_diff, state.to_state_diff().unwrap().into()); } -fn create_state_changes_for_test( - state: &mut CachedState, +fn create_state_changes_for_test( + state: &mut CachedState, sender_address: Option, fee_token_address: ContractAddress, ) -> StateChanges { @@ -338,7 +342,7 @@ fn create_state_changes_for_test( fn test_from_state_changes_for_fee_charge( #[values(Some(contract_address!("0x102")), None)] sender_address: Option, ) { - let mut state: CachedState = CachedState::default(); + let mut state: CachedState = CachedState::default(); let fee_token_address = contract_address!("0x17"); let state_changes = create_state_changes_for_test(&mut state, sender_address, fee_token_address); @@ -359,7 +363,7 @@ fn test_state_changes_merge( ) { // Create a transactional state containing the `create_state_changes_for_test` logic, get the // state changes and then commit. - let mut state: CachedState = CachedState::default(); + let mut state: CachedState = CachedState::default(); let mut transactional_state = TransactionalState::create_transactional(&mut state); let block_context = BlockContext::create_for_testing(); let fee_token_address = block_context.chain_info.fee_token_addresses.eth_fee_token_address; @@ -429,7 +433,7 @@ fn test_contract_cache_is_used() { let contract_class = test_contract.get_class(); let mut reader = DictStateReader::default(); reader.class_hash_to_class.insert(class_hash, contract_class.clone()); - let state = CachedState::new(reader); + let state: CachedState = CachedState::new(reader); // Assert local cache is initialized empty. assert!(state.class_hash_to_class.borrow().get(&class_hash).is_none()); diff --git a/crates/blockifier/src/state/state_api.rs b/crates/blockifier/src/state/state_api.rs index 5d3c308e2f..199ad53416 100644 --- a/crates/blockifier/src/state/state_api.rs +++ b/crates/blockifier/src/state/state_api.rs @@ -1,5 +1,3 @@ -use std::collections::{HashMap, HashSet}; - use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; @@ -107,15 +105,17 @@ pub trait State: StateReader { /// Marks the given set of PC values as visited for the given class hash. // TODO(lior): Once we have a BlockResources object, move this logic there. Make sure reverted // entry points do not affect the final set of PCs. - fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet); + fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &[usize]); } /// A class defining the API for updating a state with transactions writes. pub trait UpdatableState: StateReader { + type Pcs; + fn apply_writes( &mut self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping, - visited_pcs: &HashMap>, + visited_pcs: &Self::Pcs, ); } diff --git a/crates/blockifier/src/state/visited_pcs.rs b/crates/blockifier/src/state/visited_pcs.rs new file mode 100644 index 0000000000..306660ad0c --- /dev/null +++ b/crates/blockifier/src/state/visited_pcs.rs @@ -0,0 +1,98 @@ +use std::collections::hash_map::Entry; +use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; + +use starknet_api::core::ClassHash; + +use super::state_api::State; + +/// This trait is used in `CachedState` to record visited pcs of an entry point call. This allows +/// flexible storage of program counters returned from cairo vm trace. +/// +/// # Object Safety +/// +/// This trait uses associated types instead of generics because only one implementation of the +/// trait is required. Also, using associated types reduces the number of parameters required to be +/// specified. +/// The use of associated types makes the trait implementation not [object safe](https://doc.rust-lang.org/reference/items/traits.html#object-safety). +/// +/// Self Bounds +/// +/// - [`Default`] is required to allow a default instantiation of `CachedState`. +/// - [`Debug`] is required for compatibility with other structs which derive `Debug`. +pub trait VisitedPcs +where + Self: Default + Debug, +{ + /// This is the type which contains visited program counters. + /// + /// [`Clone`] is required to allow ownership of data throught cloning when receiving references + /// from one of the trait methods. + type Pcs: Clone; + + /// Constructs a concrete implementation of the trait. + fn new() -> Self; + + /// This function records the program counters returned by the cairo vm trace. + /// + /// The elements of the vector `pcs` match the type of field `pc` in + /// [`cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry`] + fn insert(&mut self, class_hash: &ClassHash, pcs: &[usize]); + + /// This function extends the program counters in `self` with those from another instance. + /// + /// It is used to transfer the visited program counters from one object to another. + fn extend(&mut self, class_hash: &ClassHash, pcs: &Self::Pcs); + + /// This function returns an iterator of `VisitedPcs`. + /// + /// One tuple is returned for each class hash recorded in `self`. + fn iter(&self) -> impl Iterator; + + /// Get the recorded visited program counters for a specific `class_hash`. + fn entry(&mut self, class_hash: ClassHash) -> Entry<'_, ClassHash, Self::Pcs>; + + /// Marks the given `pcs` values as visited for the given class hash. + fn add_visited_pcs(state: &mut dyn State, class_hash: &ClassHash, pcs: Self::Pcs); + + /// This function transforms the internal representation of program counters into a set. + fn to_set(pcs: Self::Pcs) -> HashSet; +} + +/// [`VisitedPcsSet`] is the default implementation of the trait [`VisiedPcs`]. All visited program +/// counters are inserted in a set and grouped by class hash. +/// +/// This is also the structure used by the `native_blockifier`. +#[derive(Debug, Default, PartialEq, Eq)] +pub struct VisitedPcsSet(HashMap>); +impl VisitedPcs for VisitedPcsSet { + type Pcs = HashSet; + + fn new() -> Self { + VisitedPcsSet(HashMap::default()) + } + + fn insert(&mut self, class_hash: &ClassHash, pcs: &[usize]) { + self.0.entry(*class_hash).or_default().extend(pcs); + } + + fn extend(&mut self, class_hash: &ClassHash, pcs: &Self::Pcs) { + self.0.entry(*class_hash).or_default().extend(pcs); + } + + fn iter(&self) -> impl Iterator { + self.0.iter() + } + + fn entry(&mut self, class_hash: ClassHash) -> Entry<'_, ClassHash, HashSet> { + self.0.entry(class_hash) + } + + fn add_visited_pcs(state: &mut dyn State, class_hash: &ClassHash, pcs: Self::Pcs) { + state.add_visited_pcs(*class_hash, &Vec::from_iter(pcs)); + } + + fn to_set(pcs: Self::Pcs) -> HashSet { + pcs + } +} diff --git a/crates/blockifier/src/test_utils/initial_test_state.rs b/crates/blockifier/src/test_utils/initial_test_state.rs index 6e0268cb29..6bf2fc56b3 100644 --- a/crates/blockifier/src/test_utils/initial_test_state.rs +++ b/crates/blockifier/src/test_utils/initial_test_state.rs @@ -7,6 +7,7 @@ use strum::IntoEnumIterator; use crate::abi::abi_utils::get_fee_token_var_address; use crate::context::ChainInfo; use crate::state::cached_state::CachedState; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::CairoVersion; @@ -40,7 +41,7 @@ pub fn test_state_inner( initial_balances: u128, contract_instances: &[(FeatureContract, u16)], erc20_contract_version: CairoVersion, -) -> CachedState { +) -> CachedState { let mut class_hash_to_class = HashMap::new(); let mut address_to_class_hash = HashMap::new(); @@ -87,6 +88,6 @@ pub fn test_state( chain_info: &ChainInfo, initial_balances: u128, contract_instances: &[(FeatureContract, u16)], -) -> CachedState { +) -> CachedState { test_state_inner(chain_info, initial_balances, contract_instances, CairoVersion::Cairo0) } diff --git a/crates/blockifier/src/test_utils/transfers_generator.rs b/crates/blockifier/src/test_utils/transfers_generator.rs index 2018d2eb58..e1c8083fb2 100644 --- a/crates/blockifier/src/test_utils/transfers_generator.rs +++ b/crates/blockifier/src/test_utils/transfers_generator.rs @@ -10,6 +10,7 @@ use crate::abi::abi_utils::selector_from_name; use crate::blockifier::config::{ConcurrencyConfig, TransactionExecutorConfig}; use crate::blockifier::transaction_executor::TransactionExecutor; use crate::context::{BlockContext, ChainInfo}; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -73,7 +74,7 @@ pub enum RecipientGeneratorType { pub struct TransfersGenerator { account_addresses: Vec, chain_info: ChainInfo, - executor: TransactionExecutor, + executor: TransactionExecutor, nonce_manager: NonceManager, sender_index: usize, random_recipient_generator: Option, diff --git a/crates/blockifier/src/transaction/account_transaction.rs b/crates/blockifier/src/transaction/account_transaction.rs index 15d30dd554..2676e72d92 100644 --- a/crates/blockifier/src/transaction/account_transaction.rs +++ b/crates/blockifier/src/transaction/account_transaction.rs @@ -31,6 +31,7 @@ use crate::fee::receipt::TransactionReceipt; use crate::retdata; use crate::state::cached_state::{StateChanges, TransactionalState}; use crate::state::state_api::{State, StateReader, UpdatableState}; +use crate::state::visited_pcs::VisitedPcs; use crate::transaction::constants; use crate::transaction::errors::{ TransactionExecutionError, @@ -351,9 +352,9 @@ impl AccountTransaction { } } - fn handle_fee( + fn handle_fee( &self, - state: &mut TransactionalState<'_, S>, + state: &mut TransactionalState<'_, S, V>, tx_context: Arc, actual_fee: Fee, charge_fee: bool, @@ -418,8 +419,8 @@ impl AccountTransaction { /// manipulates the state to avoid that part. /// Note: the returned transfer call info is partial, and should be completed at the commit /// stage, as well as the actual sequencer balance. - fn concurrency_execute_fee_transfer( - state: &mut TransactionalState<'_, S>, + fn concurrency_execute_fee_transfer( + state: &mut TransactionalState<'_, S, V>, tx_context: Arc, actual_fee: Fee, ) -> TransactionExecutionResult { @@ -458,9 +459,9 @@ impl AccountTransaction { } } - fn run_non_revertible( + fn run_non_revertible( &self, - state: &mut TransactionalState<'_, S>, + state: &mut TransactionalState<'_, S, V>, tx_context: Arc, remaining_gas: &mut u64, validate: bool, @@ -521,9 +522,9 @@ impl AccountTransaction { } } - fn run_revertible( + fn run_revertible( &self, - state: &mut TransactionalState<'_, S>, + state: &mut TransactionalState<'_, S, V>, tx_context: Arc, remaining_gas: &mut u64, validate: bool, @@ -662,9 +663,9 @@ impl AccountTransaction { } /// Runs validation and execution. - fn run_or_revert( + fn run_or_revert( &self, - state: &mut TransactionalState<'_, S>, + state: &mut TransactionalState<'_, S, V>, remaining_gas: &mut u64, tx_context: Arc, validate: bool, @@ -678,10 +679,10 @@ impl AccountTransaction { } } -impl ExecutableTransaction for AccountTransaction { +impl> ExecutableTransaction for AccountTransaction { fn execute_raw( &self, - state: &mut TransactionalState<'_, U>, + state: &mut TransactionalState<'_, U, V>, block_context: &BlockContext, execution_flags: ExecutionFlags, ) -> TransactionExecutionResult { diff --git a/crates/blockifier/src/transaction/execution_flavors_test.rs b/crates/blockifier/src/transaction/execution_flavors_test.rs index 0c7f356275..cbe264f2fb 100644 --- a/crates/blockifier/src/transaction/execution_flavors_test.rs +++ b/crates/blockifier/src/transaction/execution_flavors_test.rs @@ -19,6 +19,7 @@ use crate::execution::syscalls::SyscallSelector; use crate::fee::fee_utils::get_fee_by_gas_vector; use crate::state::cached_state::CachedState; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::{VisitedPcs, VisitedPcsSet}; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -56,7 +57,7 @@ use crate::transaction::transactions::ExecutableTransaction; const VALIDATE_GAS_OVERHEAD: u64 = 21; struct FlavorTestInitialState { - pub state: CachedState, + pub state: CachedState, pub account_address: ContractAddress, pub faulty_account_address: ContractAddress, pub test_contract_address: ContractAddress, @@ -86,9 +87,9 @@ fn create_flavors_test_state( /// Checks that balance of the account decreased if and only if `charge_fee` is true. /// Returns the new balance. -fn check_balance( +fn check_balance( current_balance: Felt, - state: &CachedState, + state: &mut CachedState, account_address: ContractAddress, chain_info: &ChainInfo, fee_type: &FeeType, @@ -476,7 +477,7 @@ fn test_simulate_validate_charge_fee_mid_execution( ); let current_balance = check_balance( current_balance, - &state, + &mut state, account_address, &block_context.chain_info, &fee_type, @@ -525,8 +526,14 @@ fn test_simulate_validate_charge_fee_mid_execution( // charged final fee is shown in actual_fee. if charge_fee { limited_fee } else { unlimited_fee }, ); - let current_balance = - check_balance(current_balance, &state, account_address, chain_info, &fee_type, charge_fee); + let current_balance = check_balance( + current_balance, + &mut state, + account_address, + chain_info, + &fee_type, + charge_fee, + ); // Third scenario: only limit is block bounds. Expect resources consumed to be identical, // whether or not `charge_fee` is true. @@ -565,7 +572,7 @@ fn test_simulate_validate_charge_fee_mid_execution( block_limit_fee, block_limit_fee, ); - check_balance(current_balance, &state, account_address, chain_info, &fee_type, charge_fee); + check_balance(current_balance, &mut state, account_address, chain_info, &fee_type, charge_fee); } #[rstest] @@ -655,8 +662,14 @@ fn test_simulate_validate_charge_fee_post_execution( if charge_fee { just_not_enough_fee_bound } else { unlimited_fee }, if charge_fee { revert_fee } else { unlimited_fee }, ); - let current_balance = - check_balance(current_balance, &state, account_address, chain_info, &fee_type, charge_fee); + let current_balance = check_balance( + current_balance, + &mut state, + account_address, + chain_info, + &fee_type, + charge_fee, + ); // Second scenario: balance too low. // Execute a transfer, and make sure we get the expected result. @@ -722,7 +735,7 @@ fn test_simulate_validate_charge_fee_post_execution( ); check_balance( current_balance, - &state, + &mut state, account_address, chain_info, &fee_type, diff --git a/crates/blockifier/src/transaction/test_utils.rs b/crates/blockifier/src/transaction/test_utils.rs index eb4b5600e5..701bc6353e 100644 --- a/crates/blockifier/src/transaction/test_utils.rs +++ b/crates/blockifier/src/transaction/test_utils.rs @@ -26,6 +26,7 @@ use crate::context::{BlockContext, ChainInfo}; use crate::execution::contract_class::{ClassInfo, ContractClass}; use crate::state::cached_state::CachedState; use crate::state::state_api::State; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; use crate::test_utils::deploy_account::deploy_account_tx; @@ -99,7 +100,7 @@ pub fn block_context() -> BlockContext { /// Struct containing the data usually needed to initialize a test. pub struct TestInitData { - pub state: CachedState, + pub state: CachedState, pub account_address: ContractAddress, pub contract_address: ContractAddress, pub nonce_manager: NonceManager, @@ -108,7 +109,7 @@ pub struct TestInitData { /// Deploys a new account with the given class hash, funds with both fee tokens, and returns the /// deploy tx and address. pub fn deploy_and_fund_account( - state: &mut CachedState, + state: &mut CachedState, nonce_manager: &mut NonceManager, chain_info: &ChainInfo, deploy_tx_args: DeployAccountTxArgs, @@ -294,7 +295,7 @@ pub fn account_invoke_tx(invoke_args: InvokeTxArgs) -> AccountTransaction { } pub fn run_invoke_tx( - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, invoke_args: InvokeTxArgs, ) -> TransactionExecutionResult { diff --git a/crates/blockifier/src/transaction/transaction_execution.rs b/crates/blockifier/src/transaction/transaction_execution.rs index 48a120b2f2..ea171e4091 100644 --- a/crates/blockifier/src/transaction/transaction_execution.rs +++ b/crates/blockifier/src/transaction/transaction_execution.rs @@ -11,6 +11,7 @@ use crate::execution::entry_point::EntryPointExecutionContext; use crate::fee::receipt::TransactionReceipt; use crate::state::cached_state::TransactionalState; use crate::state::state_api::UpdatableState; +use crate::state::visited_pcs::VisitedPcs; use crate::transaction::account_transaction::AccountTransaction; use crate::transaction::errors::TransactionFeeError; use crate::transaction::objects::{ @@ -108,10 +109,12 @@ impl TransactionInfoCreator for Transaction { } } -impl ExecutableTransaction for L1HandlerTransaction { +impl> ExecutableTransaction + for L1HandlerTransaction +{ fn execute_raw( &self, - state: &mut TransactionalState<'_, U>, + state: &mut TransactionalState<'_, U, V>, block_context: &BlockContext, _execution_flags: ExecutionFlags, ) -> TransactionExecutionResult { @@ -159,10 +162,10 @@ impl ExecutableTransaction for L1HandlerTransaction { } } -impl ExecutableTransaction for Transaction { +impl> ExecutableTransaction for Transaction { fn execute_raw( &self, - state: &mut TransactionalState<'_, U>, + state: &mut TransactionalState<'_, U, V>, block_context: &BlockContext, execution_flags: ExecutionFlags, ) -> TransactionExecutionResult { diff --git a/crates/blockifier/src/transaction/transactions.rs b/crates/blockifier/src/transaction/transactions.rs index 4d27c34269..acc41a7bca 100644 --- a/crates/blockifier/src/transaction/transactions.rs +++ b/crates/blockifier/src/transaction/transactions.rs @@ -29,6 +29,7 @@ use crate::execution::execution_utils::{execute_deployment, update_remaining_gas use crate::state::cached_state::TransactionalState; use crate::state::errors::StateError; use crate::state::state_api::{State, UpdatableState}; +use crate::state::visited_pcs::VisitedPcs; use crate::transaction::constants; use crate::transaction::errors::TransactionExecutionError; use crate::transaction::objects::{ @@ -60,7 +61,7 @@ pub struct ExecutionFlags { pub concurrency_mode: bool, } -pub trait ExecutableTransaction: Sized { +pub trait ExecutableTransaction>: Sized { /// Executes the transaction in a transactional manner /// (if it fails, given state does not modify). fn execute( @@ -96,7 +97,7 @@ pub trait ExecutableTransaction: Sized { /// for automatic handling of such cases. fn execute_raw( &self, - state: &mut TransactionalState<'_, U>, + state: &mut TransactionalState<'_, U, V>, block_context: &BlockContext, execution_flags: ExecutionFlags, ) -> TransactionExecutionResult; diff --git a/crates/blockifier/src/transaction/transactions_test.rs b/crates/blockifier/src/transaction/transactions_test.rs index f348039070..9bed29874d 100644 --- a/crates/blockifier/src/transaction/transactions_test.rs +++ b/crates/blockifier/src/transaction/transactions_test.rs @@ -68,6 +68,7 @@ use crate::fee::receipt::TransactionReceipt; use crate::state::cached_state::{CachedState, StateChangesCount, TransactionalState}; use crate::state::errors::StateError; use crate::state::state_api::{State, StateReader}; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; use crate::test_utils::deploy_account::deploy_account_tx; @@ -305,7 +306,7 @@ fn get_expected_cairo_resources( /// and the sequencer (in both fee types) are as expected (assuming the initial sequencer balances /// are zero). fn validate_final_balances( - state: &mut CachedState, + state: &mut CachedState, chain_info: &ChainInfo, expected_actual_fee: Fee, erc20_account_balance_key: StorageKey, @@ -554,7 +555,7 @@ fn test_invoke_tx( // Verifies the storage after each invoke execution in test_invoke_tx_advanced_operations. fn verify_storage_after_invoke_advanced_operations( - state: &mut CachedState, + state: &mut CachedState, contract_address: ContractAddress, account_address: ContractAddress, index: Felt, @@ -798,7 +799,7 @@ fn test_state_get_fee_token_balance( } fn assert_failure_if_resource_bounds_exceed_balance( - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, invalid_tx: AccountTransaction, ) { @@ -1092,7 +1093,7 @@ fn test_invalid_nonce( calldata: create_trivial_calldata(test_contract.get_instance_address(0)), resource_bounds: max_resource_bounds, }; - let mut transactional_state = TransactionalState::create_transactional(state); + let mut transactional_state = TransactionalState::create_transactional_for_testing(state); // Strict, negative flow: account nonce = 0, incoming tx nonce = 1. let invalid_nonce = nonce!(1_u8); diff --git a/crates/gateway/src/stateful_transaction_validator.rs b/crates/gateway/src/stateful_transaction_validator.rs index f064acf130..51b728ef3f 100644 --- a/crates/gateway/src/stateful_transaction_validator.rs +++ b/crates/gateway/src/stateful_transaction_validator.rs @@ -6,6 +6,7 @@ use blockifier::blockifier::stateful_validator::{ use blockifier::bouncer::BouncerConfig; use blockifier::context::BlockContext; use blockifier::state::cached_state::CachedState; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::transaction::account_transaction::AccountTransaction; use blockifier::versioned_constants::VersionedConstants; #[cfg(test)] @@ -32,7 +33,7 @@ pub struct StatefulTransactionValidator { pub config: StatefulTransactionValidatorConfig, } -type BlockifierStatefulValidator = StatefulValidator>; +type BlockifierStatefulValidator = StatefulValidator, VisitedPcsSet>; // TODO(yair): move the trait to Blockifier. #[cfg_attr(test, automock)] diff --git a/crates/native_blockifier/src/py_block_executor.rs b/crates/native_blockifier/src/py_block_executor.rs index 841619756c..6da2621eea 100644 --- a/crates/native_blockifier/src/py_block_executor.rs +++ b/crates/native_blockifier/src/py_block_executor.rs @@ -8,6 +8,7 @@ use blockifier::context::{BlockContext, ChainInfo, FeeTokenAddresses}; use blockifier::execution::call_info::CallInfo; use blockifier::state::cached_state::CachedState; use blockifier::state::global_cache::GlobalContractCache; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::transaction::objects::{GasVector, ResourcesMapping, TransactionExecutionInfo}; use blockifier::transaction::transaction_execution::Transaction; use blockifier::versioned_constants::VersionedConstants; @@ -75,7 +76,7 @@ pub struct PyBlockExecutor { pub tx_executor_config: TransactionExecutorConfig, pub chain_info: ChainInfo, pub versioned_constants: VersionedConstants, - pub tx_executor: Option>, + pub tx_executor: Option>, /// `Send` trait is required for `pyclass` compatibility as Python objects must be threadsafe. pub storage: Box, pub global_contract_cache: GlobalContractCache, @@ -356,7 +357,7 @@ impl PyBlockExecutor { } impl PyBlockExecutor { - pub fn tx_executor(&mut self) -> &mut TransactionExecutor { + pub fn tx_executor(&mut self) -> &mut TransactionExecutor { self.tx_executor.as_mut().expect("Transaction executor should be initialized") } diff --git a/crates/native_blockifier/src/py_test_utils.rs b/crates/native_blockifier/src/py_test_utils.rs index 0e66423790..e5c7fccc4a 100644 --- a/crates/native_blockifier/src/py_test_utils.rs +++ b/crates/native_blockifier/src/py_test_utils.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use blockifier::execution::contract_class::ContractClassV0; use blockifier::state::cached_state::CachedState; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::test_utils::dict_state_reader::DictStateReader; use starknet_api::core::ClassHash; use starknet_api::{class_hash, felt}; @@ -12,7 +13,7 @@ pub const TOKEN_FOR_TESTING_CONTRACT_PATH: &str = "./src/starkware/starknet/core/test_contract/starknet_compiled_contracts_lib/starkware/\ starknet/core/test_contract/token_for_testing.json"; -pub fn create_py_test_state() -> CachedState { +pub fn create_py_test_state() -> CachedState { let class_hash_to_class = HashMap::from([( class_hash!(TOKEN_FOR_TESTING_CLASS_HASH), ContractClassV0::from_file(TOKEN_FOR_TESTING_CONTRACT_PATH).into(), diff --git a/crates/native_blockifier/src/py_validator.rs b/crates/native_blockifier/src/py_validator.rs index e035f0c4ef..66d1f2fda8 100644 --- a/crates/native_blockifier/src/py_validator.rs +++ b/crates/native_blockifier/src/py_validator.rs @@ -2,6 +2,7 @@ use blockifier::blockifier::stateful_validator::{StatefulValidator, StatefulVali use blockifier::bouncer::BouncerConfig; use blockifier::context::BlockContext; use blockifier::state::cached_state::CachedState; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::transaction::account_transaction::AccountTransaction; use blockifier::transaction::objects::TransactionInfoCreator; use blockifier::transaction::transaction_types::TransactionType; @@ -21,7 +22,7 @@ use crate::state_readers::py_state_reader::PyStateReader; #[pyclass] pub struct PyValidator { - pub stateful_validator: StatefulValidator, + pub stateful_validator: StatefulValidator, pub max_nonce_for_validation_skip: Nonce, } diff --git a/crates/native_blockifier/src/state_readers/papyrus_state_test.rs b/crates/native_blockifier/src/state_readers/papyrus_state_test.rs index e999276084..89a5ba133c 100644 --- a/crates/native_blockifier/src/state_readers/papyrus_state_test.rs +++ b/crates/native_blockifier/src/state_readers/papyrus_state_test.rs @@ -7,6 +7,7 @@ use blockifier::retdata; use blockifier::state::cached_state::CachedState; use blockifier::state::global_cache::{GlobalContractCache, GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST}; use blockifier::state::state_api::StateReader; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::test_utils::contracts::FeatureContract; use blockifier::test_utils::{trivial_external_entry_point_new, CairoVersion}; use indexmap::IndexMap; @@ -56,7 +57,7 @@ fn test_entry_point_with_papyrus_state() -> papyrus_storage::StorageResult<()> { block_number, GlobalContractCache::new(GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST), ); - let mut state = CachedState::from(papyrus_reader); + let mut state: CachedState<_, VisitedPcsSet> = CachedState::from(papyrus_reader); // Call entrypoint that want to write to storage, which updates the cached state's write cache. let key = felt!(1234_u16); diff --git a/crates/papyrus_execution/src/execution_utils.rs b/crates/papyrus_execution/src/execution_utils.rs index 928728acab..196dbc0d0c 100644 --- a/crates/papyrus_execution/src/execution_utils.rs +++ b/crates/papyrus_execution/src/execution_utils.rs @@ -11,6 +11,7 @@ use blockifier::execution::contract_class::{ }; use blockifier::state::cached_state::{CachedState, CommitmentStateDiff, MutRefState}; use blockifier::state::state_api::StateReader; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::transaction::objects::TransactionExecutionInfo; use cairo_vm::types::errors::program_errors::ProgramError; use indexmap::IndexMap; @@ -117,7 +118,10 @@ pub fn get_trace_constructor( /// is a deprecated Declare, the user is required to pass the class hash of the deprecated class as /// it is not provided by the blockifier API. pub fn induced_state_diff( - transactional_state: &mut CachedState>>, + transactional_state: &mut CachedState< + MutRefState<'_, CachedState>, + VisitedPcsSet, + >, deprecated_declared_class_hash: Option, ) -> ExecutionResult { let blockifier_state_diff = CommitmentStateDiff::from(transactional_state.to_state_diff()?); diff --git a/crates/papyrus_execution/src/lib.rs b/crates/papyrus_execution/src/lib.rs index e4502626ff..d967a60db0 100644 --- a/crates/papyrus_execution/src/lib.rs +++ b/crates/papyrus_execution/src/lib.rs @@ -35,6 +35,7 @@ use blockifier::execution::entry_point::{ EntryPointExecutionContext, }; use blockifier::state::cached_state::CachedState; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::transaction::errors::TransactionExecutionError as BlockifierTransactionExecutionError; use blockifier::transaction::objects::{ DeprecatedTransactionInfo, @@ -289,7 +290,7 @@ fn verify_contract_exists( } fn create_block_context( - cached_state: &mut CachedState, + cached_state: &mut CachedState, block_context_number: BlockNumber, chain_id: ChainId, storage_reader: &StorageReader, @@ -684,7 +685,7 @@ impl From<(usize, BlockifierTransactionExecutionError)> for ExecutionError { fn get_10_blocks_ago( block_number: &BlockNumber, - cached_state: &CachedState, + cached_state: &CachedState, ) -> ExecutionResult> { if block_number.0 < 10 { return Ok(None);