From 5ad0edfdf69b4ac5ac32f83364a50e321b5fd9e7 Mon Sep 17 00:00:00 2001 From: AvivYossef-starkware Date: Mon, 9 Dec 2024 08:58:54 +0200 Subject: [PATCH] refactor(blockifier): state reader trait --- .../src/blockifier/transaction_executor.rs | 9 +++- crates/blockifier/src/bouncer.rs | 3 +- .../blockifier/src/concurrency/fee_utils.rs | 4 +- .../blockifier/src/concurrency/flow_test.rs | 12 +++-- .../src/concurrency/versioned_state.rs | 40 +++++++++------ .../src/concurrency/versioned_state_test.rs | 51 +++++++++++-------- .../src/concurrency/worker_logic.rs | 4 +- .../src/execution/contract_class.rs | 2 +- .../blockifier/src/execution/entry_point.rs | 12 +++-- .../src/execution/syscalls/syscall_base.rs | 4 +- crates/blockifier/src/state/cached_state.rs | 29 +++++++---- .../blockifier/src/state/cached_state_test.rs | 12 ++--- crates/blockifier/src/state/state_api.rs | 17 ++++--- crates/blockifier/src/test_utils/contracts.rs | 1 - .../src/test_utils/dict_state_reader.rs | 15 +++--- .../src/test_utils/initial_test_state.rs | 4 +- .../src/transaction/account_transaction.rs | 2 +- .../transaction/account_transactions_test.rs | 7 ++- .../src/transaction/transactions.rs | 35 ++++++++++++- .../src/transaction/transactions_test.rs | 5 +- .../src/state_reader/offline_state_reader.rs | 24 ++++++--- .../src/state_reader/test_state_reader.rs | 24 ++++++--- .../src/py_block_executor_test.rs | 5 +- crates/native_blockifier/src/py_test_utils.rs | 10 +++- .../src/state_readers/py_state_reader.rs | 16 ++++-- crates/papyrus_execution/src/state_reader.rs | 18 ++++--- .../src/state_reader_test.rs | 16 ++++-- .../papyrus_state_reader/src/papyrus_state.rs | 9 ++-- .../starknet_gateway/src/rpc_state_reader.rs | 23 ++++++--- .../src/rpc_state_reader_test.rs | 5 +- crates/starknet_gateway/src/state_reader.rs | 7 ++- .../src/state_reader_test_utils.rs | 7 ++- 32 files changed, 289 insertions(+), 143 deletions(-) diff --git a/crates/blockifier/src/blockifier/transaction_executor.rs b/crates/blockifier/src/blockifier/transaction_executor.rs index bec2bd5b8cb..f28b8fee95a 100644 --- a/crates/blockifier/src/blockifier/transaction_executor.rs +++ b/crates/blockifier/src/blockifier/transaction_executor.rs @@ -13,6 +13,7 @@ use crate::blockifier::config::TransactionExecutorConfig; use crate::bouncer::{Bouncer, BouncerWeights}; use crate::concurrency::worker_logic::WorkerExecutor; use crate::context::BlockContext; +use crate::execution::contract_class::RunnableCompiledClass; use crate::state::cached_state::{CachedState, CommitmentStateDiff, TransactionalState}; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult}; @@ -156,12 +157,16 @@ impl TransactionExecutor { .visited_pcs .iter() .map(|(class_hash, class_visited_pcs)| -> TransactionExecutorResult<_> { - let contract_class = self + let versioned_contract_class = self .block_state .as_ref() .expect(BLOCK_STATE_ACCESS_ERR) .get_compiled_class(*class_hash)?; - Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?)) + Ok(( + *class_hash, + RunnableCompiledClass::from(versioned_contract_class) + .get_visited_segments(class_visited_pcs)?, + )) }) .collect::>()?; diff --git a/crates/blockifier/src/bouncer.rs b/crates/blockifier/src/bouncer.rs index 99c93451841..846bdf53043 100644 --- a/crates/blockifier/src/bouncer.rs +++ b/crates/blockifier/src/bouncer.rs @@ -13,6 +13,7 @@ use crate::blockifier::transaction_executor::{ TransactionExecutorResult, }; use crate::execution::call_info::ExecutionSummary; +use crate::execution::contract_class::RunnableCompiledClass; use crate::fee::gas_usage::get_onchain_data_segment_length; use crate::fee::resources::TransactionResources; use crate::state::cached_state::{StateChangesKeys, StorageEntry}; @@ -565,7 +566,7 @@ pub fn get_casm_hash_calculation_resources( let mut casm_hash_computation_resources = ExecutionResources::default(); for class_hash in executed_class_hashes { - let class = state_reader.get_compiled_class(*class_hash)?; + let class: RunnableCompiledClass = state_reader.get_compiled_class(*class_hash)?.into(); casm_hash_computation_resources += &class.estimate_casm_hash_computation_resources(); } diff --git a/crates/blockifier/src/concurrency/fee_utils.rs b/crates/blockifier/src/concurrency/fee_utils.rs index 8ea34ea0654..dff673281af 100644 --- a/crates/blockifier/src/concurrency/fee_utils.rs +++ b/crates/blockifier/src/concurrency/fee_utils.rs @@ -8,7 +8,7 @@ use starknet_types_core::felt::Felt; use crate::context::{BlockContext, TransactionContext}; use crate::execution::call_info::CallInfo; use crate::fee::fee_utils::get_sequencer_balance_keys; -use crate::state::cached_state::{ContractClassMapping, StateMaps}; +use crate::state::cached_state::{StateMaps, VersionedContractClassMapping}; use crate::state::state_api::UpdatableState; use crate::transaction::objects::TransactionExecutionInfo; @@ -118,5 +118,5 @@ pub fn add_fee_to_sequencer_balance( ]), ..StateMaps::default() }; - state.apply_writes(&writes, &ContractClassMapping::default(), &HashMap::default()); + state.apply_writes(&writes, &VersionedContractClassMapping::default(), &HashMap::default()); } diff --git a/crates/blockifier/src/concurrency/flow_test.rs b/crates/blockifier/src/concurrency/flow_test.rs index 5b828c32600..0ddf5913dc7 100644 --- a/crates/blockifier/src/concurrency/flow_test.rs +++ b/crates/blockifier/src/concurrency/flow_test.rs @@ -9,7 +9,7 @@ use crate::abi::sierra_types::{SierraType, SierraU128}; use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus}; use crate::concurrency::test_utils::{safe_versioned_state_for_testing, DEFAULT_CHUNK_SIZE}; use crate::concurrency::versioned_state::ThreadSafeVersionedState; -use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps}; +use crate::state::cached_state::{CachedState, StateMaps, VersionedContractClassMapping}; use crate::state::state_api::UpdatableState; use crate::test_utils::dict_state_reader::DictStateReader; @@ -43,14 +43,15 @@ fn scheduler_flow_test( get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state); let reads_valid = state_proxy.validate_reads(&reads); if !reads_valid { - state_proxy.delete_writes(&writes, &ContractClassMapping::default()); + state_proxy + .delete_writes(&writes, &VersionedContractClassMapping::default()); let (_, new_writes) = get_reads_writes_for( Task::ExecutionTask(tx_index), &versioned_state, ); state_proxy.apply_writes( &new_writes, - &ContractClassMapping::default(), + &VersionedContractClassMapping::default(), &HashMap::default(), ); scheduler.finish_execution_during_commit(tx_index); @@ -63,7 +64,7 @@ fn scheduler_flow_test( get_reads_writes_for(Task::ExecutionTask(tx_index), &versioned_state); versioned_state.pin_version(tx_index).apply_writes( &writes, - &ContractClassMapping::default(), + &VersionedContractClassMapping::default(), &HashMap::default(), ); scheduler.finish_execution(tx_index); @@ -76,7 +77,8 @@ fn scheduler_flow_test( let read_set_valid = state_proxy.validate_reads(&reads); let aborted = !read_set_valid && scheduler.try_validation_abort(tx_index); if aborted { - state_proxy.delete_writes(&writes, &ContractClassMapping::default()); + state_proxy + .delete_writes(&writes, &VersionedContractClassMapping::default()); scheduler.finish_abort(tx_index) } else { Task::AskForTask diff --git a/crates/blockifier/src/concurrency/versioned_state.rs b/crates/blockifier/src/concurrency/versioned_state.rs index 1d3c9a9270a..1801fa40ca1 100644 --- a/crates/blockifier/src/concurrency/versioned_state.rs +++ b/crates/blockifier/src/concurrency/versioned_state.rs @@ -7,8 +7,8 @@ use starknet_types_core::felt::Felt; use crate::concurrency::versioned_storage::VersionedStorage; use crate::concurrency::TxIndex; -use crate::execution::contract_class::RunnableCompiledClass; -use crate::state::cached_state::{ContractClassMapping, StateMaps}; +use crate::execution::contract_class::VersionedRunnableCompiledClass; +use crate::state::cached_state::{StateMaps, VersionedContractClassMapping}; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult, UpdatableState}; @@ -34,7 +34,8 @@ pub struct VersionedState { // the compiled contract classes mapping. Each key with value false, sohuld not apprear // in the compiled contract classes mapping. declared_contracts: VersionedStorage, - compiled_contract_classes: VersionedStorage, + versioned_compiled_contract_classes: + VersionedStorage, } impl VersionedState { @@ -45,7 +46,7 @@ impl VersionedState { nonces: VersionedStorage::default(), class_hashes: VersionedStorage::default(), compiled_class_hashes: VersionedStorage::default(), - compiled_contract_classes: VersionedStorage::default(), + versioned_compiled_contract_classes: VersionedStorage::default(), declared_contracts: VersionedStorage::default(), } } @@ -121,7 +122,7 @@ impl VersionedState { let is_declared = self.declared_contracts.read(tx_index, class_hash).expect(READ_ERR); assert_eq!( is_declared, - self.compiled_contract_classes.read(tx_index, class_hash).is_some(), + self.versioned_compiled_contract_classes.read(tx_index, class_hash).is_some(), "The declared contracts mapping should match the compiled contract classes \ mapping." ); @@ -139,7 +140,7 @@ impl VersionedState { &mut self, tx_index: TxIndex, writes: &StateMaps, - class_hash_to_class: &ContractClassMapping, + class_hash_to_class: &VersionedContractClassMapping, ) { for (&key, &value) in &writes.storage { self.storage.write(tx_index, key, value); @@ -154,13 +155,13 @@ impl VersionedState { self.compiled_class_hashes.write(tx_index, key, value); } for (&key, value) in class_hash_to_class { - self.compiled_contract_classes.write(tx_index, key, value.clone()); + self.versioned_compiled_contract_classes.write(tx_index, key, value.clone()); } for (&key, &value) in &writes.declared_contracts { self.declared_contracts.write(tx_index, key, value); assert_eq!( value, - self.compiled_contract_classes.read(tx_index, key).is_some(), + self.versioned_compiled_contract_classes.read(tx_index, key).is_some(), "The declared contracts mapping should match the compiled contract classes \ mapping." ); @@ -171,7 +172,7 @@ impl VersionedState { &mut self, tx_index: TxIndex, writes: &StateMaps, - class_hash_to_class: &ContractClassMapping, + class_hash_to_class: &VersionedContractClassMapping, ) { for &key in writes.storage.keys() { self.storage.delete_write(key, tx_index); @@ -189,7 +190,7 @@ impl VersionedState { self.declared_contracts.delete_write(key, tx_index); } for &key in class_hash_to_class.keys() { - self.compiled_contract_classes.delete_write(key, tx_index); + self.versioned_compiled_contract_classes.delete_write(key, tx_index); } } @@ -210,7 +211,7 @@ impl VersionedState { let commit_index = n_committed_txs - 1; let writes = self.get_writes_up_to_index(commit_index); let class_hash_to_class = - self.compiled_contract_classes.get_writes_up_to_index(commit_index); + self.versioned_compiled_contract_classes.get_writes_up_to_index(commit_index); let mut state = self.into_initial_state(); state.apply_writes(&writes, &class_hash_to_class, &visited_pcs); state @@ -266,7 +267,11 @@ impl VersionedStateProxy { self.state().validate_reads(self.tx_index, reads) } - pub fn delete_writes(&self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping) { + pub fn delete_writes( + &self, + writes: &StateMaps, + class_hash_to_class: &VersionedContractClassMapping, + ) { self.state().delete_writes(self.tx_index, writes, class_hash_to_class); } } @@ -276,7 +281,7 @@ impl UpdatableState for VersionedStateProxy { fn apply_writes( &mut self, writes: &StateMaps, - class_hash_to_class: &ContractClassMapping, + class_hash_to_class: &VersionedContractClassMapping, _visited_pcs: &HashMap>, ) { self.state().apply_writes(self.tx_index, writes, class_hash_to_class) @@ -336,15 +341,18 @@ impl StateReader for VersionedStateProxy { } } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let mut state = self.state(); - match state.compiled_contract_classes.read(self.tx_index, class_hash) { + match state.versioned_compiled_contract_classes.read(self.tx_index, class_hash) { Some(value) => Ok(value), None => match state.initial_state.get_compiled_class(class_hash) { Ok(initial_value) => { state.declared_contracts.set_initial_value(class_hash, true); state - .compiled_contract_classes + .versioned_compiled_contract_classes .set_initial_value(class_hash, initial_value.clone()); Ok(initial_value) } diff --git a/crates/blockifier/src/concurrency/versioned_state_test.rs b/crates/blockifier/src/concurrency/versioned_state_test.rs index 6db3b95a1de..5fa29d2c9e4 100644 --- a/crates/blockifier/src/concurrency/versioned_state_test.rs +++ b/crates/blockifier/src/concurrency/versioned_state_test.rs @@ -34,9 +34,9 @@ use crate::concurrency::TxIndex; use crate::context::BlockContext; use crate::state::cached_state::{ CachedState, - ContractClassMapping, StateMaps, TransactionalState, + VersionedContractClassMapping, }; use crate::state::errors::StateError; use crate::state::state_api::{State, StateReader, UpdatableState}; @@ -72,7 +72,7 @@ fn test_versioned_state_proxy() { let class_hash = class_hash!(27_u8); let another_class_hash = class_hash!(28_u8); let compiled_class_hash = compiled_class_hash!(29_u8); - let contract_class = test_contract.get_runnable_class(); + let contract_class = test_contract.get_versioned_runnable_class(); // Create the versioned state let cached_state = CachedState::from(DictStateReader { @@ -117,7 +117,7 @@ fn test_versioned_state_proxy() { let compiled_class_hash_v18 = compiled_class_hash!(30_u8); let contract_class_v11 = FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm)) - .get_runnable_class(); + .get_versioned_runnable_class(); versioned_state_proxys[3].state().apply_writes( 3, @@ -404,10 +404,11 @@ fn test_false_validate_reads_declared_contracts( ..Default::default() }; let version_state_proxy = safe_versioned_state.pin_version(0); - let compiled_contract_calss = + let versioned_compiled_contract_class = FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm)) - .get_runnable_class(); - let class_hash_to_class = HashMap::from([(class_hash!(1_u8), compiled_contract_calss)]); + .get_versioned_runnable_class(); + let class_hash_to_class = + HashMap::from([(class_hash!(1_u8), versioned_compiled_contract_class)]); version_state_proxy.state().apply_writes(0, &tx_0_writes, &class_hash_to_class); assert!(!safe_versioned_state.pin_version(1).validate_reads(&tx_1_reads)); } @@ -431,11 +432,15 @@ fn test_apply_writes( assert_eq!(transactional_states[0].cache.borrow().writes.class_hashes.len(), 1); // Transaction 0 contract class. - let contract_class_0 = + + let versioned_contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm)) - .get_runnable_class(); + .get_versioned_runnable_class(); + assert!(transactional_states[0].class_hash_to_class.borrow().is_empty()); - transactional_states[0].set_contract_class(class_hash, contract_class_0.clone()).unwrap(); + transactional_states[0] + .set_contract_class(class_hash, versioned_contract_class_0.clone()) + .unwrap(); assert_eq!(transactional_states[0].class_hash_to_class.borrow().len(), 1); safe_versioned_state.pin_version(0).apply_writes( @@ -444,7 +449,10 @@ fn test_apply_writes( &HashMap::default(), ); assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0); - assert!(transactional_states[1].get_compiled_class(class_hash).unwrap() == contract_class_0); + assert!( + transactional_states[1].get_compiled_class(class_hash).unwrap() + == versioned_contract_class_0 + ); } #[rstest] @@ -513,7 +521,7 @@ fn test_delete_writes( tx_state .set_contract_class( feature_contract.get_class_hash(), - feature_contract.get_runnable_class(), + feature_contract.get_versioned_runnable_class(), ) .unwrap(); safe_versioned_state.pin_version(i).apply_writes( @@ -546,7 +554,7 @@ fn test_delete_writes( .0 .lock() .unwrap() - .compiled_contract_classes + .versioned_compiled_contract_classes .get_writes_of_index(tx_index) .is_empty(), should_be_empty @@ -573,8 +581,10 @@ fn test_delete_writes_completeness( )]), declared_contracts: HashMap::from([(feature_contract.get_class_hash(), true)]), }; - let class_hash_to_class_writes = - HashMap::from([(feature_contract.get_class_hash(), feature_contract.get_runnable_class())]); + let class_hash_to_class_writes = HashMap::from([( + feature_contract.get_class_hash(), + feature_contract.get_versioned_runnable_class(), + )]); let tx_index = 0; let mut versioned_state_proxy = safe_versioned_state.pin_version(tx_index); @@ -593,7 +603,7 @@ fn test_delete_writes_completeness( .0 .lock() .unwrap() - .compiled_contract_classes + .versioned_compiled_contract_classes .get_writes_of_index(tx_index), class_hash_to_class_writes ); @@ -608,9 +618,9 @@ fn test_delete_writes_completeness( .0 .lock() .unwrap() - .compiled_contract_classes + .versioned_compiled_contract_classes .get_writes_of_index(tx_index), - ContractClassMapping::default() + VersionedContractClassMapping::default() ); } @@ -637,10 +647,11 @@ fn test_versioned_proxy_state_flow( transactional_states[3].set_class_hash_at(contract_address, class_hash_3).unwrap(); // Clients contract class values. - let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo0).get_runnable_class(); + let contract_class_0 = + FeatureContract::TestContract(CairoVersion::Cairo0).get_versioned_runnable_class(); let contract_class_2 = - FeatureContract::AccountWithLongValidate(CairoVersion::Cairo1(RunnableCairo1::Casm)) - .get_runnable_class(); + FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm)) + .get_versioned_runnable_class(); transactional_states[0].set_contract_class(class_hash, contract_class_0).unwrap(); transactional_states[2].set_contract_class(class_hash, contract_class_2.clone()).unwrap(); diff --git a/crates/blockifier/src/concurrency/worker_logic.rs b/crates/blockifier/src/concurrency/worker_logic.rs index 851a462f31f..1e129709bd4 100644 --- a/crates/blockifier/src/concurrency/worker_logic.rs +++ b/crates/blockifier/src/concurrency/worker_logic.rs @@ -15,7 +15,7 @@ use crate::concurrency::utils::lock_mutex_in_array; use crate::concurrency::versioned_state::ThreadSafeVersionedState; use crate::concurrency::TxIndex; use crate::context::BlockContext; -use crate::state::cached_state::{ContractClassMapping, StateMaps, TransactionalState}; +use crate::state::cached_state::{StateMaps, TransactionalState, VersionedContractClassMapping}; use crate::state::state_api::{StateReader, UpdatableState}; use crate::transaction::objects::{TransactionExecutionInfo, TransactionExecutionResult}; use crate::transaction::transaction_execution::Transaction; @@ -32,7 +32,7 @@ pub struct ExecutionTaskOutput { pub reads: StateMaps, // TODO(Yoni): rename to state_diff. pub writes: StateMaps, - pub contract_classes: ContractClassMapping, + pub contract_classes: VersionedContractClassMapping, pub visited_pcs: HashMap>, pub result: TransactionExecutionResult, } diff --git a/crates/blockifier/src/execution/contract_class.rs b/crates/blockifier/src/execution/contract_class.rs index 976397bd182..3fd7a12a7e8 100644 --- a/crates/blockifier/src/execution/contract_class.rs +++ b/crates/blockifier/src/execution/contract_class.rs @@ -68,7 +68,7 @@ pub enum RunnableCompiledClass { } /// Represents a runnable compiled class for Cairo, with the Sierra version (for Cairo 1). -#[derive(Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum VersionedRunnableCompiledClass { Cairo0(RunnableCompiledClass), Cairo1((RunnableCompiledClass, SierraVersion)), diff --git a/crates/blockifier/src/execution/entry_point.rs b/crates/blockifier/src/execution/entry_point.rs index e7fd2e84a74..3eb6fba73f2 100644 --- a/crates/blockifier/src/execution/entry_point.rs +++ b/crates/blockifier/src/execution/entry_point.rs @@ -21,6 +21,7 @@ use starknet_api::transaction::fields::{ use starknet_api::transaction::TransactionVersion; use starknet_types_core::felt::Felt; +use super::contract_class::RunnableCompiledClass; use crate::context::{BlockContext, TransactionContext}; use crate::execution::call_info::CallInfo; use crate::execution::common_hints::ExecutionMode; @@ -148,7 +149,7 @@ impl CallEntryPoint { } // Add class hash to the call, that will appear in the output (call info). self.class_hash = Some(class_hash); - let compiled_class = state.get_compiled_class(class_hash)?; + let compiled_class: RunnableCompiledClass = state.get_compiled_class(class_hash)?.into(); context.revert_infos.0.push(EntryPointRevertInfo::new( self.storage_address, @@ -407,9 +408,12 @@ pub fn execute_constructor_entry_point( remaining_gas: &mut u64, ) -> ConstructorEntryPointExecutionResult { // Ensure the class is declared (by reading it). - let compiled_class = state.get_compiled_class(ctor_context.class_hash).map_err(|error| { - ConstructorEntryPointExecutionError::new(error.into(), &ctor_context, None) - })?; + let compiled_class: RunnableCompiledClass = state + .get_compiled_class(ctor_context.class_hash) + .map_err(|error| { + ConstructorEntryPointExecutionError::new(error.into(), &ctor_context, None) + })? + .into(); let Some(constructor_selector) = compiled_class.constructor_selector() else { // Contract has no constructor. return handle_empty_constructor(&ctor_context, calldata, *remaining_gas) diff --git a/crates/blockifier/src/execution/syscalls/syscall_base.rs b/crates/blockifier/src/execution/syscalls/syscall_base.rs index 7e172065f1c..6abca84dcfc 100644 --- a/crates/blockifier/src/execution/syscalls/syscall_base.rs +++ b/crates/blockifier/src/execution/syscalls/syscall_base.rs @@ -12,6 +12,7 @@ use super::exceeds_event_size_limit; use crate::abi::constants; use crate::execution::call_info::{CallInfo, MessageToL1, OrderedEvent, OrderedL2ToL1Message}; use crate::execution::common_hints::ExecutionMode; +use crate::execution::contract_class::RunnableCompiledClass; use crate::execution::entry_point::{ CallEntryPoint, ConstructorContext, @@ -164,7 +165,8 @@ impl<'state> SyscallHandlerBase<'state> { pub fn replace_class(&mut self, class_hash: ClassHash) -> SyscallResult<()> { // Ensure the class is declared (by reading it), and of type V1. - let compiled_class = self.state.get_compiled_class(class_hash)?; + let compiled_class: RunnableCompiledClass = + self.state.get_compiled_class(class_hash)?.into(); if !is_cairo1(&compiled_class) { return Err(SyscallExecutionError::ForbiddenClassReplacement { class_hash }); diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index f7b2a6b9b68..fbd716be8f1 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -8,7 +8,7 @@ use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; use crate::context::TransactionContext; -use crate::execution::contract_class::RunnableCompiledClass; +use crate::execution::contract_class::VersionedRunnableCompiledClass; use crate::state::errors::StateError; use crate::state::state_api::{State, StateReader, StateResult, UpdatableState}; use crate::transaction::objects::TransactionExecutionInfo; @@ -18,7 +18,7 @@ use crate::utils::{strict_subtract_mappings, subtract_mappings}; #[path = "cached_state_test.rs"] mod test; -pub type ContractClassMapping = HashMap; +pub type VersionedContractClassMapping = HashMap; /// Caches read and write requests. /// @@ -30,7 +30,7 @@ pub struct CachedState { // Invariant: read/write access is managed by CachedState. // Using interior mutability to update caches during `State`'s immutable getters. pub(crate) cache: RefCell, - pub(crate) class_hash_to_class: RefCell, + pub(crate) class_hash_to_class: RefCell, /// A map from class hash to the set of PC values that were visited in the class. pub visited_pcs: HashMap>, } @@ -66,7 +66,7 @@ impl CachedState { pub fn update_cache( &mut self, write_updates: &StateMaps, - local_contract_cache_updates: ContractClassMapping, + local_contract_cache_updates: VersionedContractClassMapping, ) { // Check consistency between declared_contracts and class_hash_to_class. for (&key, &value) in &write_updates.declared_contracts { @@ -114,7 +114,7 @@ impl UpdatableState for CachedState { fn apply_writes( &mut self, writes: &StateMaps, - class_hash_to_class: &ContractClassMapping, + class_hash_to_class: &VersionedContractClassMapping, visited_pcs: &HashMap>, ) { // TODO(Noa,15/5/24): Reconsider the clone. @@ -178,7 +178,10 @@ impl StateReader for CachedState { Ok(*class_hash) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let mut cache = self.cache.borrow_mut(); let class_hash_to_class = &mut *self.class_hash_to_class.borrow_mut(); @@ -202,12 +205,12 @@ impl StateReader for CachedState { } } - let contract_class = class_hash_to_class + let versioned_contract_class = class_hash_to_class .get(&class_hash) .cloned() .expect("The class hash must appear in the cache."); - Ok(contract_class) + Ok(versioned_contract_class) } fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult { @@ -261,9 +264,10 @@ impl State for CachedState { fn set_contract_class( &mut self, class_hash: ClassHash, - contract_class: RunnableCompiledClass, + versioned_contract_class: VersionedRunnableCompiledClass, ) -> StateResult<()> { - self.class_hash_to_class.get_mut().insert(class_hash, contract_class); + // TODO(Aviv): Use the actual Sierra version after the change in StateReader trait. + self.class_hash_to_class.get_mut().insert(class_hash, versioned_contract_class); let mut cache = self.cache.borrow_mut(); cache.declare_contract(class_hash); Ok(()) @@ -567,7 +571,10 @@ impl StateReader for MutRefState<'_, S> { self.0.get_class_hash_at(contract_address) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { self.0.get_compiled_class(class_hash) } diff --git a/crates/blockifier/src/state/cached_state_test.rs b/crates/blockifier/src/state/cached_state_test.rs index 824acd01b16..5961e62d112 100644 --- a/crates/blockifier/src/state/cached_state_test.rs +++ b/crates/blockifier/src/state/cached_state_test.rs @@ -26,7 +26,7 @@ const CONTRACT_ADDRESS: &str = "0x100"; fn set_initial_state_values( state: &mut CachedState, - class_hash_to_class: ContractClassMapping, + class_hash_to_class: VersionedContractClassMapping, nonce_initial_values: HashMap, class_hash_initial_values: HashMap, storage_initial_values: HashMap, @@ -117,7 +117,7 @@ fn declare_contract() { let mut state = CachedState::from(DictStateReader { ..Default::default() }); let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let class_hash = test_contract.get_class_hash(); - let contract_class = test_contract.get_runnable_class(); + let contract_class = test_contract.get_versioned_runnable_class(); assert_eq!(state.cache.borrow().writes.declared_contracts.get(&class_hash), None); assert_eq!(state.cache.borrow().initial_reads.declared_contracts.get(&class_hash), None); @@ -176,7 +176,7 @@ fn get_contract_class() { let state = test_state(&ChainInfo::create_for_testing(), Fee(0), &[(test_contract, 0)]); assert_eq!( state.get_compiled_class(test_contract.get_class_hash()).unwrap(), - test_contract.get_runnable_class() + test_contract.get_versioned_runnable_class() ); // Negative flow. @@ -224,7 +224,7 @@ fn cached_state_state_diff_conversion() { let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let test_class_hash = test_contract.get_class_hash(); let class_hash_to_class = - HashMap::from([(test_class_hash, test_contract.get_runnable_class())]); + HashMap::from([(test_class_hash, test_contract.get_versioned_runnable_class())]); let nonce_initial_values = HashMap::new(); @@ -531,7 +531,7 @@ fn test_contract_cache_is_used() { // cache. let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let class_hash = test_contract.get_class_hash(); - let contract_class = test_contract.get_runnable_class(); + let contract_class = test_contract.get_versioned_runnable_class(); let mut reader = DictStateReader::default(); reader.class_hash_to_class.insert(class_hash, contract_class.clone()); let state = CachedState::new(reader); @@ -540,7 +540,7 @@ fn test_contract_cache_is_used() { assert!(state.class_hash_to_class.borrow().get(&class_hash).is_none()); // Check state uses the cache. - assert_eq!(state.get_compiled_class(class_hash).unwrap(), contract_class); + assert_eq!(state.get_compiled_class(class_hash).unwrap(), contract_class.clone()); assert_eq!(state.class_hash_to_class.borrow().get(&class_hash).unwrap(), &contract_class); } diff --git a/crates/blockifier/src/state/state_api.rs b/crates/blockifier/src/state/state_api.rs index 96e24c1e1a3..088028794ff 100644 --- a/crates/blockifier/src/state/state_api.rs +++ b/crates/blockifier/src/state/state_api.rs @@ -5,8 +5,8 @@ use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; -use super::cached_state::{ContractClassMapping, StateMaps}; -use crate::execution::contract_class::RunnableCompiledClass; +use super::cached_state::{StateMaps, VersionedContractClassMapping}; +use crate::execution::contract_class::VersionedRunnableCompiledClass; use crate::state::errors::StateError; pub type StateResult = Result; @@ -39,8 +39,11 @@ pub trait StateReader { /// Default: 0 (uninitialized class hash) for an uninitialized contract address. fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult; - /// Returns the compiled class of the given class hash. - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult; + /// Returns the versioned runnable compiled class of the given class hash. + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult; /// Returns the compiled class hash of the given class hash. fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult; @@ -89,11 +92,11 @@ pub trait State: StateReader { class_hash: ClassHash, ) -> StateResult<()>; - /// Sets the given contract class under the given class hash. + /// Sets the given versioned contract class under the given class hash. fn set_contract_class( &mut self, class_hash: ClassHash, - contract_class: RunnableCompiledClass, + contract_class: VersionedRunnableCompiledClass, ) -> StateResult<()>; /// Sets the given compiled class hash under the given class hash. @@ -114,7 +117,7 @@ pub trait UpdatableState: StateReader { fn apply_writes( &mut self, writes: &StateMaps, - class_hash_to_class: &ContractClassMapping, + class_hash_to_class: &VersionedContractClassMapping, visited_pcs: &HashMap>, ); } diff --git a/crates/blockifier/src/test_utils/contracts.rs b/crates/blockifier/src/test_utils/contracts.rs index 66ae844cced..41c6e8bd472 100644 --- a/crates/blockifier/src/test_utils/contracts.rs +++ b/crates/blockifier/src/test_utils/contracts.rs @@ -190,7 +190,6 @@ impl FeatureContract { self.get_class().try_into().unwrap() } - #[allow(dead_code)] pub fn get_versioned_runnable_class(&self) -> VersionedRunnableCompiledClass { let runnable_class = self.get_runnable_class(); match self.cairo_version() { diff --git a/crates/blockifier/src/test_utils/dict_state_reader.rs b/crates/blockifier/src/test_utils/dict_state_reader.rs index 16e3b3ed83f..48a341728d2 100644 --- a/crates/blockifier/src/test_utils/dict_state_reader.rs +++ b/crates/blockifier/src/test_utils/dict_state_reader.rs @@ -4,7 +4,7 @@ use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; -use crate::execution::contract_class::RunnableCompiledClass; +use crate::execution::contract_class::VersionedRunnableCompiledClass; use crate::state::cached_state::StorageEntry; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult}; @@ -15,7 +15,7 @@ pub struct DictStateReader { pub storage_view: HashMap, pub address_to_nonce: HashMap, pub address_to_class_hash: HashMap, - pub class_hash_to_class: HashMap, + pub class_hash_to_class: HashMap, pub class_hash_to_compiled_class_hash: HashMap, } @@ -35,10 +35,13 @@ impl StateReader for DictStateReader { Ok(nonce) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { - let contract_class = self.class_hash_to_class.get(&class_hash).cloned(); - match contract_class { - Some(contract_class) => Ok(contract_class), + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { + let versioned_contract_class = self.class_hash_to_class.get(&class_hash).cloned(); + match versioned_contract_class { + Some(versioned_contract_class) => Ok(versioned_contract_class), _ => Err(StateError::UndeclaredClassHash(class_hash)), } } diff --git a/crates/blockifier/src/test_utils/initial_test_state.rs b/crates/blockifier/src/test_utils/initial_test_state.rs index 0beeb1a36d5..6df41cdd298 100644 --- a/crates/blockifier/src/test_utils/initial_test_state.rs +++ b/crates/blockifier/src/test_utils/initial_test_state.rs @@ -49,7 +49,7 @@ pub fn test_state_inner( // Declare and deploy account and ERC20 contracts. let erc20 = FeatureContract::ERC20(erc20_contract_version); - class_hash_to_class.insert(erc20.get_class_hash(), erc20.get_runnable_class()); + class_hash_to_class.insert(erc20.get_class_hash(), erc20.get_versioned_runnable_class()); address_to_class_hash .insert(chain_info.fee_token_address(&FeeType::Eth), erc20.get_class_hash()); address_to_class_hash @@ -58,7 +58,7 @@ pub fn test_state_inner( // Set up the rest of the requested contracts. for (contract, n_instances) in contract_instances.iter() { let class_hash = contract.get_class_hash(); - class_hash_to_class.insert(class_hash, contract.get_runnable_class()); + class_hash_to_class.insert(class_hash, contract.get_versioned_runnable_class()); for instance in 0..*n_instances { let instance_address = contract.get_instance_address(instance); address_to_class_hash.insert(instance_address, class_hash); diff --git a/crates/blockifier/src/transaction/account_transaction.rs b/crates/blockifier/src/transaction/account_transaction.rs index 4261c41066c..0199d3b05be 100644 --- a/crates/blockifier/src/transaction/account_transaction.rs +++ b/crates/blockifier/src/transaction/account_transaction.rs @@ -875,7 +875,7 @@ impl ValidatableTransaction for AccountTransaction { })?; // Validate return data. - let compiled_class = state.get_compiled_class(class_hash)?; + let compiled_class: RunnableCompiledClass = state.get_compiled_class(class_hash)?.into(); if is_cairo1(&compiled_class) { // The account contract class is a Cairo 1.0 contract; the `validate` entry point should // return `VALID`. diff --git a/crates/blockifier/src/transaction/account_transactions_test.rs b/crates/blockifier/src/transaction/account_transactions_test.rs index 2deb9e53abe..7e4dee41630 100644 --- a/crates/blockifier/src/transaction/account_transactions_test.rs +++ b/crates/blockifier/src/transaction/account_transactions_test.rs @@ -777,7 +777,10 @@ fn test_fail_declare(block_context: BlockContext, max_fee: Fee) { create_test_init_data(chain_info, CairoVersion::Cairo0); let class_hash = class_hash!(0xdeadeadeaf72_u128); let contract_class = - FeatureContract::Empty(CairoVersion::Cairo1(RunnableCairo1::Casm)).get_class(); + FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm)).get_class(); + let versioned_contract_class = + FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm)) + .get_versioned_runnable_class(); let next_nonce = nonce_manager.next(account_address); // Cannot fail executing a declare tx unless it's V2 or above, and already declared. @@ -787,7 +790,7 @@ fn test_fail_declare(block_context: BlockContext, max_fee: Fee) { sender_address: account_address, ..Default::default() }; - state.set_contract_class(class_hash, contract_class.clone().try_into().unwrap()).unwrap(); + state.set_contract_class(class_hash, versioned_contract_class.clone()).unwrap(); state.set_compiled_class_hash(class_hash, declare_tx_v2.compiled_class_hash).unwrap(); let class_info = calculate_class_info_for_testing(contract_class); let executable_declare = ApiExecutableDeclareTransaction { diff --git a/crates/blockifier/src/transaction/transactions.rs b/crates/blockifier/src/transaction/transactions.rs index b6cb6075005..681947188d5 100644 --- a/crates/blockifier/src/transaction/transactions.rs +++ b/crates/blockifier/src/transaction/transactions.rs @@ -25,6 +25,7 @@ use starknet_api::transaction::{ use crate::context::{BlockContext, TransactionContext}; use crate::execution::call_info::CallInfo; +use crate::execution::contract_class::VersionedRunnableCompiledClass; use crate::execution::entry_point::{ CallEntryPoint, CallType, @@ -204,7 +205,23 @@ impl Executable for DeclareTransaction { // We allow redeclaration of the class for backward compatibility. // In the past, we allowed redeclaration of Cairo 0 contracts since there was // no class commitment (so no need to check if the class is already declared). - state.set_contract_class(class_hash, self.contract_class().try_into()?)?; + let compiled_contract_class = self.contract_class(); + let versioned_compiled_contract_class = { + match compiled_contract_class { + starknet_api::contract_class::ContractClass::V0(_) => { + VersionedRunnableCompiledClass::Cairo0( + compiled_contract_class.try_into()?, + ) + } + starknet_api::contract_class::ContractClass::V1(_) => { + VersionedRunnableCompiledClass::Cairo1(( + compiled_contract_class.try_into()?, + self.class_info.sierra_version.clone(), + )) + } + } + }; + state.set_contract_class(class_hash, versioned_compiled_contract_class)?; } } starknet_api::transaction::DeclareTransaction::V2(DeclareTransactionV2 { @@ -417,7 +434,21 @@ fn try_declare( match state.get_compiled_class(class_hash) { Err(StateError::UndeclaredClassHash(_)) => { // Class is undeclared; declare it. - state.set_contract_class(class_hash, tx.contract_class().try_into()?)?; + let compiled_contract_class = tx.contract_class(); + let versioned_compiled_contract_class = { + match compiled_contract_class { + starknet_api::contract_class::ContractClass::V0(_) => { + VersionedRunnableCompiledClass::Cairo0(compiled_contract_class.try_into()?) + } + starknet_api::contract_class::ContractClass::V1(_) => { + VersionedRunnableCompiledClass::Cairo1(( + compiled_contract_class.try_into()?, + tx.class_info.sierra_version.clone(), + )) + } + } + }; + state.set_contract_class(class_hash, versioned_compiled_contract_class)?; if let Some(compiled_class_hash) = compiled_class_hash { state.set_compiled_class_hash(class_hash, compiled_class_hash)?; } diff --git a/crates/blockifier/src/transaction/transactions_test.rs b/crates/blockifier/src/transaction/transactions_test.rs index 19cebde9cd1..458f54f97ca 100644 --- a/crates/blockifier/src/transaction/transactions_test.rs +++ b/crates/blockifier/src/transaction/transactions_test.rs @@ -1562,6 +1562,8 @@ fn test_declare_tx( #[case] empty_contract_version: CairoVersion, #[values(false, true)] use_kzg_da: bool, ) { + use crate::execution::contract_class::RunnableCompiledClass; + let block_context = &BlockContext::create_for_account_testing_with_kzg(use_kzg_da); let versioned_constants = &block_context.versioned_constants; let empty_contract = FeatureContract::Empty(empty_contract_version); @@ -1698,7 +1700,8 @@ fn test_declare_tx( ); // Verify class declaration. - let contract_class_from_state = state.get_compiled_class(class_hash).unwrap(); + let contract_class_from_state: RunnableCompiledClass = + state.get_compiled_class(class_hash).unwrap().into(); assert_eq!(contract_class_from_state, class_info.contract_class().try_into().unwrap()); // Checks that redeclaring the same contract fails. diff --git a/crates/blockifier_reexecution/src/state_reader/offline_state_reader.rs b/crates/blockifier_reexecution/src/state_reader/offline_state_reader.rs index f3f5290f8fa..81f1b8c1067 100644 --- a/crates/blockifier_reexecution/src/state_reader/offline_state_reader.rs +++ b/crates/blockifier_reexecution/src/state_reader/offline_state_reader.rs @@ -5,7 +5,10 @@ use blockifier::blockifier::config::TransactionExecutorConfig; use blockifier::blockifier::transaction_executor::TransactionExecutor; use blockifier::bouncer::BouncerConfig; use blockifier::context::BlockContext; -use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::execution::contract_class::{ + RunnableCompiledClass, + VersionedRunnableCompiledClass, +}; use blockifier::state::cached_state::{CommitmentStateDiff, StateMaps}; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader, StateResult}; @@ -157,15 +160,22 @@ impl StateReader for OfflineStateReader { )?) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { match self.get_contract_class(&class_hash)? { StarknetContractClass::Sierra(sierra) => { - let (casm, _) = sierra_to_versioned_contract_class_v1(sierra).unwrap(); - Ok(casm.try_into().unwrap()) - } - StarknetContractClass::Legacy(legacy) => { - Ok(legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap()) + let (casm, sierra_version) = sierra_to_versioned_contract_class_v1(sierra).unwrap(); + let runnable_compiled_class: RunnableCompiledClass = casm.try_into().unwrap(); + Ok(VersionedRunnableCompiledClass::Cairo1(( + runnable_compiled_class, + sierra_version, + ))) } + StarknetContractClass::Legacy(legacy) => Ok(VersionedRunnableCompiledClass::Cairo0( + legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap(), + )), } } diff --git a/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs b/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs index 3158aea484a..f15ed796784 100644 --- a/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs +++ b/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs @@ -7,7 +7,10 @@ use blockifier::blockifier::config::TransactionExecutorConfig; use blockifier::blockifier::transaction_executor::TransactionExecutor; use blockifier::bouncer::BouncerConfig; use blockifier::context::BlockContext; -use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::execution::contract_class::{ + RunnableCompiledClass, + VersionedRunnableCompiledClass, +}; use blockifier::state::cached_state::CommitmentStateDiff; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader, StateResult}; @@ -128,18 +131,25 @@ impl StateReader for TestStateReader { /// Returns the contract class of the given class hash. /// Compile the contract class if it is Sierra. - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let contract_class = retry_request!(self.retry_config, || self.get_contract_class(&class_hash))?; match contract_class { StarknetContractClass::Sierra(sierra) => { - let (casm, _) = sierra_to_versioned_contract_class_v1(sierra).unwrap(); - Ok(RunnableCompiledClass::try_from(casm).unwrap()) - } - StarknetContractClass::Legacy(legacy) => { - Ok(legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap()) + let (casm, sierra_version) = sierra_to_versioned_contract_class_v1(sierra).unwrap(); + let runnable_contract_class: RunnableCompiledClass = casm.try_into().unwrap(); + Ok(VersionedRunnableCompiledClass::Cairo1(( + runnable_contract_class, + sierra_version, + ))) } + StarknetContractClass::Legacy(legacy) => Ok(VersionedRunnableCompiledClass::Cairo0( + legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap(), + )), } } diff --git a/crates/native_blockifier/src/py_block_executor_test.rs b/crates/native_blockifier/src/py_block_executor_test.rs index d223df4657a..5881606a1df 100644 --- a/crates/native_blockifier/src/py_block_executor_test.rs +++ b/crates/native_blockifier/src/py_block_executor_test.rs @@ -71,13 +71,14 @@ fn global_contract_cache_update() { assert_eq!(block_executor.global_contract_cache.lock().cache_size(), 0); - let queried_contract_class = block_executor + let queried_contract_class: RunnableCompiledClass = block_executor .tx_executor() .block_state .as_ref() .expect(BLOCK_STATE_ACCESS_ERR) .get_compiled_class(class_hash) - .unwrap(); + .unwrap() + .into(); assert_eq!(queried_contract_class, contract_class); assert_eq!(block_executor.global_contract_cache.lock().cache_size(), 1); diff --git a/crates/native_blockifier/src/py_test_utils.rs b/crates/native_blockifier/src/py_test_utils.rs index 2ba5ac4d485..5a7d94e4697 100644 --- a/crates/native_blockifier/src/py_test_utils.rs +++ b/crates/native_blockifier/src/py_test_utils.rs @@ -1,6 +1,10 @@ use std::collections::HashMap; -use blockifier::execution::contract_class::CompiledClassV0; +use blockifier::execution::contract_class::{ + CompiledClassV0, + RunnableCompiledClass, + VersionedRunnableCompiledClass, +}; use blockifier::state::cached_state::CachedState; use blockifier::test_utils::dict_state_reader::DictStateReader; use blockifier::test_utils::struct_impls::LoadContractFromFile; @@ -16,7 +20,9 @@ pub const TOKEN_FOR_TESTING_CONTRACT_PATH: &str = pub fn create_py_test_state() -> CachedState { let contract_class: CompiledClassV0 = ContractClass::from_file(TOKEN_FOR_TESTING_CONTRACT_PATH).try_into().unwrap(); + let versioned_contract_class = + VersionedRunnableCompiledClass::Cairo0(RunnableCompiledClass::from(contract_class)); let class_hash_to_class = - HashMap::from([(class_hash!(TOKEN_FOR_TESTING_CLASS_HASH), contract_class.into())]); + HashMap::from([(class_hash!(TOKEN_FOR_TESTING_CLASS_HASH), versioned_contract_class)]); CachedState::from(DictStateReader { class_hash_to_class, ..Default::default() }) } diff --git a/crates/native_blockifier/src/state_readers/py_state_reader.rs b/crates/native_blockifier/src/state_readers/py_state_reader.rs index 7f1f8a8fe30..d020ff7f1b8 100644 --- a/crates/native_blockifier/src/state_readers/py_state_reader.rs +++ b/crates/native_blockifier/src/state_readers/py_state_reader.rs @@ -2,6 +2,7 @@ use blockifier::execution::contract_class::{ CompiledClassV0, CompiledClassV1, RunnableCompiledClass, + VersionedRunnableCompiledClass, }; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader, StateResult}; @@ -70,8 +71,11 @@ impl StateReader for PyStateReader { .map_err(|err| StateError::StateReadError(err.to_string())) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { - Python::with_gil(|py| -> Result { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { + Python::with_gil(|py| -> Result { let args = (PyFelt::from(class_hash),); let py_versioned_raw_compiled_class: &PyTuple = self .state_reader_proxy @@ -87,9 +91,11 @@ impl StateReader for PyStateReader { // Extract and process the Sierra version let (minor, major, patch): (u64, u64, u64) = py_versioned_raw_compiled_class.get_item(1)?.extract()?; - // TODO(Aviv): Return it in the next PR after the change in the StateReader API. - let _sierra_version = SierraVersion::new(major, minor, patch); - Ok(runnable_compiled_class) + let sierra_version = SierraVersion::new(major, minor, patch); + if sierra_version == SierraVersion::DEPRECATED { + return Ok(VersionedRunnableCompiledClass::Cairo0(runnable_compiled_class)); + } + Ok(VersionedRunnableCompiledClass::Cairo1((runnable_compiled_class, sierra_version))) }) .map_err(|err| { if Python::with_gil(|py| err.is_instance_of::(py)) { diff --git a/crates/papyrus_execution/src/state_reader.rs b/crates/papyrus_execution/src/state_reader.rs index 35ff6b4ee39..ce5284a7982 100644 --- a/crates/papyrus_execution/src/state_reader.rs +++ b/crates/papyrus_execution/src/state_reader.rs @@ -76,7 +76,10 @@ impl BlockifierStateReader for ExecutionStateReader { .unwrap_or_default()) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { if let Some(pending_casm) = self .maybe_pending_data .as_ref() @@ -90,9 +93,10 @@ impl BlockifierStateReader for ExecutionStateReader { let runnable_compiled_class = RunnableCompiledClass::V1( CompiledClassV1::try_from(pending_casm).map_err(StateError::ProgramError)?, ); - let _sierra_version = SierraVersion::extract_from_program(&sierra.sierra_program)?; - // TODO(AVIV): Use the sierra version when the return type is updated. - return Ok(runnable_compiled_class); + return Ok(VersionedRunnableCompiledClass::Cairo1(( + runnable_compiled_class, + SierraVersion::extract_from_program(&sierra.sierra_program)?, + ))); } } @@ -101,17 +105,17 @@ impl BlockifierStateReader for ExecutionStateReader { .as_ref() .and_then(|pending_data| pending_data.classes.get_class(class_hash)) { - return Ok(RunnableCompiledClass::V0( + return Ok(VersionedRunnableCompiledClass::Cairo0(RunnableCompiledClass::V0( CompiledClassV0::try_from(pending_deprecated_class) .map_err(StateError::ProgramError)?, - )); + ))); } match get_versioned_contract_class( &self.storage_reader.begin_ro_txn().map_err(storage_err_to_state_err)?, &class_hash, self.state_number, ) { - Ok(Some(versioned_contract_class)) => Ok(versioned_contract_class.into()), + Ok(Some(versioned_contract_class)) => Ok(versioned_contract_class), Ok(None) => Err(StateError::UndeclaredClassHash(class_hash)), Err(ExecutionUtilsError::CasmTableNotSynced) => { self.missing_compiled_class.set(Some(class_hash)); diff --git a/crates/papyrus_execution/src/state_reader_test.rs b/crates/papyrus_execution/src/state_reader_test.rs index 8e6a3c9558d..49bc3e1db52 100644 --- a/crates/papyrus_execution/src/state_reader_test.rs +++ b/crates/papyrus_execution/src/state_reader_test.rs @@ -183,8 +183,8 @@ fn read_state() { assert_eq!(nonce_after_block_1, nonce0); let class_hash_after_block_1 = state_reader1.get_class_hash_at(address0).unwrap(); assert_eq!(class_hash_after_block_1, class_hash0); - let compiled_contract_class_after_block_1 = - state_reader1.get_compiled_class(class_hash0).unwrap(); + let compiled_contract_class_after_block_1: RunnableCompiledClass = + state_reader1.get_compiled_class(class_hash0).unwrap().into(); assert_eq!(compiled_contract_class_after_block_1, blockifier_casm0); // Test that an error is returned if we try to get a missing casm, and the field @@ -233,13 +233,19 @@ fn read_state() { assert_eq!(state_reader2.get_compiled_class_hash(class_hash2).unwrap(), compiled_class_hash2); assert_eq!(state_reader2.get_nonce_at(address0).unwrap(), nonce0); assert_eq!(state_reader2.get_nonce_at(address2).unwrap(), nonce1); - assert_eq!(state_reader2.get_compiled_class(class_hash0).unwrap(), blockifier_casm0); - assert_eq!(state_reader2.get_compiled_class(class_hash2).unwrap(), blockifier_casm1); + assert_eq!( + RunnableCompiledClass::from(state_reader2.get_compiled_class(class_hash0).unwrap()), + blockifier_casm0 + ); + assert_eq!( + RunnableCompiledClass::from(state_reader2.get_compiled_class(class_hash2).unwrap()), + blockifier_casm1 + ); // Test that an error is returned if we only got the class without the casm. state_reader2.get_compiled_class(class_hash3).unwrap_err(); // Test that if the class is deprecated it is returned. assert_eq!( - state_reader2.get_compiled_class(class_hash4).unwrap(), + RunnableCompiledClass::from(state_reader2.get_compiled_class(class_hash4).unwrap()), RunnableCompiledClass::V0(CompiledClassV0::try_from(class1).unwrap()) ); diff --git a/crates/papyrus_state_reader/src/papyrus_state.rs b/crates/papyrus_state_reader/src/papyrus_state.rs index bfa5e01c41f..7a900c1c2c5 100644 --- a/crates/papyrus_state_reader/src/papyrus_state.rs +++ b/crates/papyrus_state_reader/src/papyrus_state.rs @@ -130,18 +130,21 @@ impl StateReader for PapyrusReader { } } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { // Assumption: the global cache is cleared upon reverted blocks. let versioned_contract_class = self.global_class_hash_to_class.get(&class_hash); match versioned_contract_class { - Some(contract_class) => Ok(RunnableCompiledClass::from(contract_class)), + Some(versioned_contract_class) => Ok(versioned_contract_class), None => { let versioned_contract_class_from_db = self.get_compiled_class_inner(class_hash)?; // The class was declared in a previous (finalized) state; update the global cache. self.global_class_hash_to_class .set(class_hash, versioned_contract_class_from_db.clone()); - Ok(RunnableCompiledClass::from(versioned_contract_class_from_db)) + Ok(versioned_contract_class_from_db) } } } diff --git a/crates/starknet_gateway/src/rpc_state_reader.rs b/crates/starknet_gateway/src/rpc_state_reader.rs index 62e637a2925..8b1ad0da24c 100644 --- a/crates/starknet_gateway/src/rpc_state_reader.rs +++ b/crates/starknet_gateway/src/rpc_state_reader.rs @@ -2,6 +2,7 @@ use blockifier::execution::contract_class::{ CompiledClassV0, CompiledClassV1, RunnableCompiledClass, + VersionedRunnableCompiledClass, }; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; @@ -139,21 +140,29 @@ impl BlockifierStateReader for RpcStateReader { } } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let get_compiled_class_params = GetCompiledClassParams { class_hash, block_id: self.block_id }; let result = self.send_rpc_request("starknet_getCompiledContractClass", get_compiled_class_params)?; - let (contract_class, _): (CompiledContractClass, SierraVersion) = + let (contract_class, sierra_version): (CompiledContractClass, SierraVersion) = serde_json::from_value(result).map_err(serde_err_to_state_err)?; - match contract_class { - CompiledContractClass::V1(contract_class_v1) => Ok(RunnableCompiledClass::V1( + let runnable_contract_class = match contract_class { + CompiledContractClass::V1(contract_class_v1) => RunnableCompiledClass::V1( CompiledClassV1::try_from(contract_class_v1).map_err(StateError::ProgramError)?, - )), - CompiledContractClass::V0(contract_class_v0) => Ok(RunnableCompiledClass::V0( + ), + CompiledContractClass::V0(contract_class_v0) => RunnableCompiledClass::V0( CompiledClassV0::try_from(contract_class_v0).map_err(StateError::ProgramError)?, - )), + ), + }; + if sierra_version == SierraVersion::DEPRECATED { + Ok(VersionedRunnableCompiledClass::Cairo0(runnable_contract_class)) + } else { + Ok(VersionedRunnableCompiledClass::Cairo1((runnable_contract_class, sierra_version))) } } diff --git a/crates/starknet_gateway/src/rpc_state_reader_test.rs b/crates/starknet_gateway/src/rpc_state_reader_test.rs index 514518636ad..9c26d6a664a 100644 --- a/crates/starknet_gateway/src/rpc_state_reader_test.rs +++ b/crates/starknet_gateway/src/rpc_state_reader_test.rs @@ -193,7 +193,10 @@ async fn test_get_compiled_class() { .await .unwrap() .unwrap(); - assert_eq!(result, RunnableCompiledClass::V1(expected_result.try_into().unwrap())); + assert_eq!( + RunnableCompiledClass::from(result), + RunnableCompiledClass::V1(expected_result.try_into().unwrap()) + ); mock.assert_async().await; } diff --git a/crates/starknet_gateway/src/state_reader.rs b/crates/starknet_gateway/src/state_reader.rs index b026bbe6ef0..d545f15791c 100644 --- a/crates/starknet_gateway/src/state_reader.rs +++ b/crates/starknet_gateway/src/state_reader.rs @@ -1,4 +1,4 @@ -use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::execution::contract_class::VersionedRunnableCompiledClass; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; #[cfg(test)] @@ -44,7 +44,10 @@ impl BlockifierStateReader for Box { self.as_ref().get_class_hash_at(contract_address) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { self.as_ref().get_compiled_class(class_hash) } diff --git a/crates/starknet_gateway/src/state_reader_test_utils.rs b/crates/starknet_gateway/src/state_reader_test_utils.rs index 80c8fc8b516..4971c0d1c27 100644 --- a/crates/starknet_gateway/src/state_reader_test_utils.rs +++ b/crates/starknet_gateway/src/state_reader_test_utils.rs @@ -1,5 +1,5 @@ use blockifier::context::BlockContext; -use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::execution::contract_class::VersionedRunnableCompiledClass; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; use blockifier::test_utils::contracts::FeatureContract; @@ -43,7 +43,10 @@ impl BlockifierStateReader for TestStateReader { self.blockifier_state_reader.get_class_hash_at(contract_address) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { self.blockifier_state_reader.get_compiled_class(class_hash) }