From 2176de2e8cd04afd2841a42033cec5c06b520bdc Mon Sep 17 00:00:00 2001 From: AvivYossef-starkware Date: Tue, 17 Dec 2024 13:33:03 +0200 Subject: [PATCH] refactor(blockifier): state reader trait --- .../src/blockifier/transaction_executor.rs | 9 +++-- crates/blockifier/src/bouncer.rs | 3 +- .../src/concurrency/worker_logic.rs | 4 +-- .../src/execution/contract_class.rs | 2 +- .../blockifier/src/execution/entry_point.rs | 12 ++++--- .../src/execution/syscalls/syscall_base.rs | 4 ++- crates/blockifier/src/state/state_api.rs | 9 +++-- crates/blockifier/src/test_utils/contracts.rs | 1 - .../src/test_utils/dict_state_reader.rs | 15 ++++---- .../src/test_utils/initial_test_state.rs | 4 +-- .../src/transaction/account_transaction.rs | 2 +- .../transaction/account_transactions_test.rs | 7 ++-- .../src/transaction/transactions.rs | 35 +++++++++++++++++-- .../src/transaction/transactions_test.rs | 5 ++- .../src/state_reader/offline_state_reader.rs | 24 +++++++++---- .../src/state_reader/test_state_reader.rs | 24 +++++++++---- .../src/py_block_executor_test.rs | 5 +-- crates/native_blockifier/src/py_test_utils.rs | 10 ++++-- .../src/state_readers/py_state_reader.rs | 16 ++++++--- crates/papyrus_execution/src/state_reader.rs | 18 ++++++---- .../src/state_reader_test.rs | 16 ++++++--- .../papyrus_state_reader/src/papyrus_state.rs | 9 +++-- .../starknet_gateway/src/rpc_state_reader.rs | 23 ++++++++---- .../src/rpc_state_reader_test.rs | 5 ++- crates/starknet_gateway/src/state_reader.rs | 7 ++-- .../src/state_reader_test_utils.rs | 7 ++-- 26 files changed, 197 insertions(+), 79 deletions(-) diff --git a/crates/blockifier/src/blockifier/transaction_executor.rs b/crates/blockifier/src/blockifier/transaction_executor.rs index bec2bd5b8cb..f28b8fee95a 100644 --- a/crates/blockifier/src/blockifier/transaction_executor.rs +++ b/crates/blockifier/src/blockifier/transaction_executor.rs @@ -13,6 +13,7 @@ use crate::blockifier::config::TransactionExecutorConfig; use crate::bouncer::{Bouncer, BouncerWeights}; use crate::concurrency::worker_logic::WorkerExecutor; use crate::context::BlockContext; +use crate::execution::contract_class::RunnableCompiledClass; use crate::state::cached_state::{CachedState, CommitmentStateDiff, TransactionalState}; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult}; @@ -156,12 +157,16 @@ impl TransactionExecutor { .visited_pcs .iter() .map(|(class_hash, class_visited_pcs)| -> TransactionExecutorResult<_> { - let contract_class = self + let versioned_contract_class = self .block_state .as_ref() .expect(BLOCK_STATE_ACCESS_ERR) .get_compiled_class(*class_hash)?; - Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?)) + Ok(( + *class_hash, + RunnableCompiledClass::from(versioned_contract_class) + .get_visited_segments(class_visited_pcs)?, + )) }) .collect::>()?; diff --git a/crates/blockifier/src/bouncer.rs b/crates/blockifier/src/bouncer.rs index 99c93451841..846bdf53043 100644 --- a/crates/blockifier/src/bouncer.rs +++ b/crates/blockifier/src/bouncer.rs @@ -13,6 +13,7 @@ use crate::blockifier::transaction_executor::{ TransactionExecutorResult, }; use crate::execution::call_info::ExecutionSummary; +use crate::execution::contract_class::RunnableCompiledClass; use crate::fee::gas_usage::get_onchain_data_segment_length; use crate::fee::resources::TransactionResources; use crate::state::cached_state::{StateChangesKeys, StorageEntry}; @@ -565,7 +566,7 @@ pub fn get_casm_hash_calculation_resources( let mut casm_hash_computation_resources = ExecutionResources::default(); for class_hash in executed_class_hashes { - let class = state_reader.get_compiled_class(*class_hash)?; + let class: RunnableCompiledClass = state_reader.get_compiled_class(*class_hash)?.into(); casm_hash_computation_resources += &class.estimate_casm_hash_computation_resources(); } diff --git a/crates/blockifier/src/concurrency/worker_logic.rs b/crates/blockifier/src/concurrency/worker_logic.rs index 851a462f31f..1e129709bd4 100644 --- a/crates/blockifier/src/concurrency/worker_logic.rs +++ b/crates/blockifier/src/concurrency/worker_logic.rs @@ -15,7 +15,7 @@ use crate::concurrency::utils::lock_mutex_in_array; use crate::concurrency::versioned_state::ThreadSafeVersionedState; use crate::concurrency::TxIndex; use crate::context::BlockContext; -use crate::state::cached_state::{ContractClassMapping, StateMaps, TransactionalState}; +use crate::state::cached_state::{StateMaps, TransactionalState, VersionedContractClassMapping}; use crate::state::state_api::{StateReader, UpdatableState}; use crate::transaction::objects::{TransactionExecutionInfo, TransactionExecutionResult}; use crate::transaction::transaction_execution::Transaction; @@ -32,7 +32,7 @@ pub struct ExecutionTaskOutput { pub reads: StateMaps, // TODO(Yoni): rename to state_diff. pub writes: StateMaps, - pub contract_classes: ContractClassMapping, + pub contract_classes: VersionedContractClassMapping, pub visited_pcs: HashMap>, pub result: TransactionExecutionResult, } diff --git a/crates/blockifier/src/execution/contract_class.rs b/crates/blockifier/src/execution/contract_class.rs index 976397bd182..3fd7a12a7e8 100644 --- a/crates/blockifier/src/execution/contract_class.rs +++ b/crates/blockifier/src/execution/contract_class.rs @@ -68,7 +68,7 @@ pub enum RunnableCompiledClass { } /// Represents a runnable compiled class for Cairo, with the Sierra version (for Cairo 1). -#[derive(Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum VersionedRunnableCompiledClass { Cairo0(RunnableCompiledClass), Cairo1((RunnableCompiledClass, SierraVersion)), diff --git a/crates/blockifier/src/execution/entry_point.rs b/crates/blockifier/src/execution/entry_point.rs index e7fd2e84a74..3eb6fba73f2 100644 --- a/crates/blockifier/src/execution/entry_point.rs +++ b/crates/blockifier/src/execution/entry_point.rs @@ -21,6 +21,7 @@ use starknet_api::transaction::fields::{ use starknet_api::transaction::TransactionVersion; use starknet_types_core::felt::Felt; +use super::contract_class::RunnableCompiledClass; use crate::context::{BlockContext, TransactionContext}; use crate::execution::call_info::CallInfo; use crate::execution::common_hints::ExecutionMode; @@ -148,7 +149,7 @@ impl CallEntryPoint { } // Add class hash to the call, that will appear in the output (call info). self.class_hash = Some(class_hash); - let compiled_class = state.get_compiled_class(class_hash)?; + let compiled_class: RunnableCompiledClass = state.get_compiled_class(class_hash)?.into(); context.revert_infos.0.push(EntryPointRevertInfo::new( self.storage_address, @@ -407,9 +408,12 @@ pub fn execute_constructor_entry_point( remaining_gas: &mut u64, ) -> ConstructorEntryPointExecutionResult { // Ensure the class is declared (by reading it). - let compiled_class = state.get_compiled_class(ctor_context.class_hash).map_err(|error| { - ConstructorEntryPointExecutionError::new(error.into(), &ctor_context, None) - })?; + let compiled_class: RunnableCompiledClass = state + .get_compiled_class(ctor_context.class_hash) + .map_err(|error| { + ConstructorEntryPointExecutionError::new(error.into(), &ctor_context, None) + })? + .into(); let Some(constructor_selector) = compiled_class.constructor_selector() else { // Contract has no constructor. return handle_empty_constructor(&ctor_context, calldata, *remaining_gas) diff --git a/crates/blockifier/src/execution/syscalls/syscall_base.rs b/crates/blockifier/src/execution/syscalls/syscall_base.rs index 7e172065f1c..6abca84dcfc 100644 --- a/crates/blockifier/src/execution/syscalls/syscall_base.rs +++ b/crates/blockifier/src/execution/syscalls/syscall_base.rs @@ -12,6 +12,7 @@ use super::exceeds_event_size_limit; use crate::abi::constants; use crate::execution::call_info::{CallInfo, MessageToL1, OrderedEvent, OrderedL2ToL1Message}; use crate::execution::common_hints::ExecutionMode; +use crate::execution::contract_class::RunnableCompiledClass; use crate::execution::entry_point::{ CallEntryPoint, ConstructorContext, @@ -164,7 +165,8 @@ impl<'state> SyscallHandlerBase<'state> { pub fn replace_class(&mut self, class_hash: ClassHash) -> SyscallResult<()> { // Ensure the class is declared (by reading it), and of type V1. - let compiled_class = self.state.get_compiled_class(class_hash)?; + let compiled_class: RunnableCompiledClass = + self.state.get_compiled_class(class_hash)?.into(); if !is_cairo1(&compiled_class) { return Err(SyscallExecutionError::ForbiddenClassReplacement { class_hash }); diff --git a/crates/blockifier/src/state/state_api.rs b/crates/blockifier/src/state/state_api.rs index 96e24c1e1a3..e248f0b534d 100644 --- a/crates/blockifier/src/state/state_api.rs +++ b/crates/blockifier/src/state/state_api.rs @@ -6,7 +6,7 @@ use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; use super::cached_state::{ContractClassMapping, StateMaps}; -use crate::execution::contract_class::RunnableCompiledClass; +use crate::execution::contract_class::VersionedRunnableCompiledClass; use crate::state::errors::StateError; pub type StateResult = Result; @@ -39,8 +39,11 @@ pub trait StateReader { /// Default: 0 (uninitialized class hash) for an uninitialized contract address. fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult; - /// Returns the compiled class of the given class hash. - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult; + /// Returns the versioned runnable compiled class of the given class hash. + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult; /// Returns the compiled class hash of the given class hash. fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult; diff --git a/crates/blockifier/src/test_utils/contracts.rs b/crates/blockifier/src/test_utils/contracts.rs index 66ae844cced..41c6e8bd472 100644 --- a/crates/blockifier/src/test_utils/contracts.rs +++ b/crates/blockifier/src/test_utils/contracts.rs @@ -190,7 +190,6 @@ impl FeatureContract { self.get_class().try_into().unwrap() } - #[allow(dead_code)] pub fn get_versioned_runnable_class(&self) -> VersionedRunnableCompiledClass { let runnable_class = self.get_runnable_class(); match self.cairo_version() { diff --git a/crates/blockifier/src/test_utils/dict_state_reader.rs b/crates/blockifier/src/test_utils/dict_state_reader.rs index 16e3b3ed83f..48a341728d2 100644 --- a/crates/blockifier/src/test_utils/dict_state_reader.rs +++ b/crates/blockifier/src/test_utils/dict_state_reader.rs @@ -4,7 +4,7 @@ use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; -use crate::execution::contract_class::RunnableCompiledClass; +use crate::execution::contract_class::VersionedRunnableCompiledClass; use crate::state::cached_state::StorageEntry; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult}; @@ -15,7 +15,7 @@ pub struct DictStateReader { pub storage_view: HashMap, pub address_to_nonce: HashMap, pub address_to_class_hash: HashMap, - pub class_hash_to_class: HashMap, + pub class_hash_to_class: HashMap, pub class_hash_to_compiled_class_hash: HashMap, } @@ -35,10 +35,13 @@ impl StateReader for DictStateReader { Ok(nonce) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { - let contract_class = self.class_hash_to_class.get(&class_hash).cloned(); - match contract_class { - Some(contract_class) => Ok(contract_class), + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { + let versioned_contract_class = self.class_hash_to_class.get(&class_hash).cloned(); + match versioned_contract_class { + Some(versioned_contract_class) => Ok(versioned_contract_class), _ => Err(StateError::UndeclaredClassHash(class_hash)), } } diff --git a/crates/blockifier/src/test_utils/initial_test_state.rs b/crates/blockifier/src/test_utils/initial_test_state.rs index 0beeb1a36d5..6df41cdd298 100644 --- a/crates/blockifier/src/test_utils/initial_test_state.rs +++ b/crates/blockifier/src/test_utils/initial_test_state.rs @@ -49,7 +49,7 @@ pub fn test_state_inner( // Declare and deploy account and ERC20 contracts. let erc20 = FeatureContract::ERC20(erc20_contract_version); - class_hash_to_class.insert(erc20.get_class_hash(), erc20.get_runnable_class()); + class_hash_to_class.insert(erc20.get_class_hash(), erc20.get_versioned_runnable_class()); address_to_class_hash .insert(chain_info.fee_token_address(&FeeType::Eth), erc20.get_class_hash()); address_to_class_hash @@ -58,7 +58,7 @@ pub fn test_state_inner( // Set up the rest of the requested contracts. for (contract, n_instances) in contract_instances.iter() { let class_hash = contract.get_class_hash(); - class_hash_to_class.insert(class_hash, contract.get_runnable_class()); + class_hash_to_class.insert(class_hash, contract.get_versioned_runnable_class()); for instance in 0..*n_instances { let instance_address = contract.get_instance_address(instance); address_to_class_hash.insert(instance_address, class_hash); diff --git a/crates/blockifier/src/transaction/account_transaction.rs b/crates/blockifier/src/transaction/account_transaction.rs index 4261c41066c..0199d3b05be 100644 --- a/crates/blockifier/src/transaction/account_transaction.rs +++ b/crates/blockifier/src/transaction/account_transaction.rs @@ -875,7 +875,7 @@ impl ValidatableTransaction for AccountTransaction { })?; // Validate return data. - let compiled_class = state.get_compiled_class(class_hash)?; + let compiled_class: RunnableCompiledClass = state.get_compiled_class(class_hash)?.into(); if is_cairo1(&compiled_class) { // The account contract class is a Cairo 1.0 contract; the `validate` entry point should // return `VALID`. diff --git a/crates/blockifier/src/transaction/account_transactions_test.rs b/crates/blockifier/src/transaction/account_transactions_test.rs index a56d69cfcd4..59b451a7b02 100644 --- a/crates/blockifier/src/transaction/account_transactions_test.rs +++ b/crates/blockifier/src/transaction/account_transactions_test.rs @@ -779,7 +779,10 @@ fn test_fail_declare(block_context: BlockContext, max_fee: Fee) { create_test_init_data(chain_info, CairoVersion::Cairo0); let class_hash = class_hash!(0xdeadeadeaf72_u128); let contract_class = - FeatureContract::Empty(CairoVersion::Cairo1(RunnableCairo1::Casm)).get_class(); + FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm)).get_class(); + let versioned_contract_class = + FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm)) + .get_versioned_runnable_class(); let next_nonce = nonce_manager.next(account_address); // Cannot fail executing a declare tx unless it's V2 or above, and already declared. @@ -789,7 +792,7 @@ fn test_fail_declare(block_context: BlockContext, max_fee: Fee) { sender_address: account_address, ..Default::default() }; - state.set_contract_class(class_hash, contract_class.clone().try_into().unwrap()).unwrap(); + state.set_contract_class(class_hash, versioned_contract_class.clone()).unwrap(); state.set_compiled_class_hash(class_hash, declare_tx_v2.compiled_class_hash).unwrap(); let class_info = calculate_class_info_for_testing(contract_class); let executable_declare = ApiExecutableDeclareTransaction { diff --git a/crates/blockifier/src/transaction/transactions.rs b/crates/blockifier/src/transaction/transactions.rs index b6cb6075005..681947188d5 100644 --- a/crates/blockifier/src/transaction/transactions.rs +++ b/crates/blockifier/src/transaction/transactions.rs @@ -25,6 +25,7 @@ use starknet_api::transaction::{ use crate::context::{BlockContext, TransactionContext}; use crate::execution::call_info::CallInfo; +use crate::execution::contract_class::VersionedRunnableCompiledClass; use crate::execution::entry_point::{ CallEntryPoint, CallType, @@ -204,7 +205,23 @@ impl Executable for DeclareTransaction { // We allow redeclaration of the class for backward compatibility. // In the past, we allowed redeclaration of Cairo 0 contracts since there was // no class commitment (so no need to check if the class is already declared). - state.set_contract_class(class_hash, self.contract_class().try_into()?)?; + let compiled_contract_class = self.contract_class(); + let versioned_compiled_contract_class = { + match compiled_contract_class { + starknet_api::contract_class::ContractClass::V0(_) => { + VersionedRunnableCompiledClass::Cairo0( + compiled_contract_class.try_into()?, + ) + } + starknet_api::contract_class::ContractClass::V1(_) => { + VersionedRunnableCompiledClass::Cairo1(( + compiled_contract_class.try_into()?, + self.class_info.sierra_version.clone(), + )) + } + } + }; + state.set_contract_class(class_hash, versioned_compiled_contract_class)?; } } starknet_api::transaction::DeclareTransaction::V2(DeclareTransactionV2 { @@ -417,7 +434,21 @@ fn try_declare( match state.get_compiled_class(class_hash) { Err(StateError::UndeclaredClassHash(_)) => { // Class is undeclared; declare it. - state.set_contract_class(class_hash, tx.contract_class().try_into()?)?; + let compiled_contract_class = tx.contract_class(); + let versioned_compiled_contract_class = { + match compiled_contract_class { + starknet_api::contract_class::ContractClass::V0(_) => { + VersionedRunnableCompiledClass::Cairo0(compiled_contract_class.try_into()?) + } + starknet_api::contract_class::ContractClass::V1(_) => { + VersionedRunnableCompiledClass::Cairo1(( + compiled_contract_class.try_into()?, + tx.class_info.sierra_version.clone(), + )) + } + } + }; + state.set_contract_class(class_hash, versioned_compiled_contract_class)?; if let Some(compiled_class_hash) = compiled_class_hash { state.set_compiled_class_hash(class_hash, compiled_class_hash)?; } diff --git a/crates/blockifier/src/transaction/transactions_test.rs b/crates/blockifier/src/transaction/transactions_test.rs index 3c80695a94e..93f870e6685 100644 --- a/crates/blockifier/src/transaction/transactions_test.rs +++ b/crates/blockifier/src/transaction/transactions_test.rs @@ -1562,6 +1562,8 @@ fn test_declare_tx( #[case] empty_contract_version: CairoVersion, #[values(false, true)] use_kzg_da: bool, ) { + use crate::execution::contract_class::RunnableCompiledClass; + let block_context = &BlockContext::create_for_account_testing_with_kzg(use_kzg_da); let versioned_constants = &block_context.versioned_constants; let empty_contract = FeatureContract::Empty(empty_contract_version); @@ -1698,7 +1700,8 @@ fn test_declare_tx( ); // Verify class declaration. - let contract_class_from_state = state.get_compiled_class(class_hash).unwrap(); + let contract_class_from_state: RunnableCompiledClass = + state.get_compiled_class(class_hash).unwrap().into(); assert_eq!(contract_class_from_state, class_info.contract_class().try_into().unwrap()); // Checks that redeclaring the same contract fails. diff --git a/crates/blockifier_reexecution/src/state_reader/offline_state_reader.rs b/crates/blockifier_reexecution/src/state_reader/offline_state_reader.rs index f3f5290f8fa..81f1b8c1067 100644 --- a/crates/blockifier_reexecution/src/state_reader/offline_state_reader.rs +++ b/crates/blockifier_reexecution/src/state_reader/offline_state_reader.rs @@ -5,7 +5,10 @@ use blockifier::blockifier::config::TransactionExecutorConfig; use blockifier::blockifier::transaction_executor::TransactionExecutor; use blockifier::bouncer::BouncerConfig; use blockifier::context::BlockContext; -use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::execution::contract_class::{ + RunnableCompiledClass, + VersionedRunnableCompiledClass, +}; use blockifier::state::cached_state::{CommitmentStateDiff, StateMaps}; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader, StateResult}; @@ -157,15 +160,22 @@ impl StateReader for OfflineStateReader { )?) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { match self.get_contract_class(&class_hash)? { StarknetContractClass::Sierra(sierra) => { - let (casm, _) = sierra_to_versioned_contract_class_v1(sierra).unwrap(); - Ok(casm.try_into().unwrap()) - } - StarknetContractClass::Legacy(legacy) => { - Ok(legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap()) + let (casm, sierra_version) = sierra_to_versioned_contract_class_v1(sierra).unwrap(); + let runnable_compiled_class: RunnableCompiledClass = casm.try_into().unwrap(); + Ok(VersionedRunnableCompiledClass::Cairo1(( + runnable_compiled_class, + sierra_version, + ))) } + StarknetContractClass::Legacy(legacy) => Ok(VersionedRunnableCompiledClass::Cairo0( + legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap(), + )), } } diff --git a/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs b/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs index 3158aea484a..f15ed796784 100644 --- a/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs +++ b/crates/blockifier_reexecution/src/state_reader/test_state_reader.rs @@ -7,7 +7,10 @@ use blockifier::blockifier::config::TransactionExecutorConfig; use blockifier::blockifier::transaction_executor::TransactionExecutor; use blockifier::bouncer::BouncerConfig; use blockifier::context::BlockContext; -use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::execution::contract_class::{ + RunnableCompiledClass, + VersionedRunnableCompiledClass, +}; use blockifier::state::cached_state::CommitmentStateDiff; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader, StateResult}; @@ -128,18 +131,25 @@ impl StateReader for TestStateReader { /// Returns the contract class of the given class hash. /// Compile the contract class if it is Sierra. - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let contract_class = retry_request!(self.retry_config, || self.get_contract_class(&class_hash))?; match contract_class { StarknetContractClass::Sierra(sierra) => { - let (casm, _) = sierra_to_versioned_contract_class_v1(sierra).unwrap(); - Ok(RunnableCompiledClass::try_from(casm).unwrap()) - } - StarknetContractClass::Legacy(legacy) => { - Ok(legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap()) + let (casm, sierra_version) = sierra_to_versioned_contract_class_v1(sierra).unwrap(); + let runnable_contract_class: RunnableCompiledClass = casm.try_into().unwrap(); + Ok(VersionedRunnableCompiledClass::Cairo1(( + runnable_contract_class, + sierra_version, + ))) } + StarknetContractClass::Legacy(legacy) => Ok(VersionedRunnableCompiledClass::Cairo0( + legacy_to_contract_class_v0(legacy).unwrap().try_into().unwrap(), + )), } } diff --git a/crates/native_blockifier/src/py_block_executor_test.rs b/crates/native_blockifier/src/py_block_executor_test.rs index d223df4657a..5881606a1df 100644 --- a/crates/native_blockifier/src/py_block_executor_test.rs +++ b/crates/native_blockifier/src/py_block_executor_test.rs @@ -71,13 +71,14 @@ fn global_contract_cache_update() { assert_eq!(block_executor.global_contract_cache.lock().cache_size(), 0); - let queried_contract_class = block_executor + let queried_contract_class: RunnableCompiledClass = block_executor .tx_executor() .block_state .as_ref() .expect(BLOCK_STATE_ACCESS_ERR) .get_compiled_class(class_hash) - .unwrap(); + .unwrap() + .into(); assert_eq!(queried_contract_class, contract_class); assert_eq!(block_executor.global_contract_cache.lock().cache_size(), 1); diff --git a/crates/native_blockifier/src/py_test_utils.rs b/crates/native_blockifier/src/py_test_utils.rs index 2ba5ac4d485..5a7d94e4697 100644 --- a/crates/native_blockifier/src/py_test_utils.rs +++ b/crates/native_blockifier/src/py_test_utils.rs @@ -1,6 +1,10 @@ use std::collections::HashMap; -use blockifier::execution::contract_class::CompiledClassV0; +use blockifier::execution::contract_class::{ + CompiledClassV0, + RunnableCompiledClass, + VersionedRunnableCompiledClass, +}; use blockifier::state::cached_state::CachedState; use blockifier::test_utils::dict_state_reader::DictStateReader; use blockifier::test_utils::struct_impls::LoadContractFromFile; @@ -16,7 +20,9 @@ pub const TOKEN_FOR_TESTING_CONTRACT_PATH: &str = pub fn create_py_test_state() -> CachedState { let contract_class: CompiledClassV0 = ContractClass::from_file(TOKEN_FOR_TESTING_CONTRACT_PATH).try_into().unwrap(); + let versioned_contract_class = + VersionedRunnableCompiledClass::Cairo0(RunnableCompiledClass::from(contract_class)); let class_hash_to_class = - HashMap::from([(class_hash!(TOKEN_FOR_TESTING_CLASS_HASH), contract_class.into())]); + HashMap::from([(class_hash!(TOKEN_FOR_TESTING_CLASS_HASH), versioned_contract_class)]); CachedState::from(DictStateReader { class_hash_to_class, ..Default::default() }) } diff --git a/crates/native_blockifier/src/state_readers/py_state_reader.rs b/crates/native_blockifier/src/state_readers/py_state_reader.rs index 7f1f8a8fe30..d020ff7f1b8 100644 --- a/crates/native_blockifier/src/state_readers/py_state_reader.rs +++ b/crates/native_blockifier/src/state_readers/py_state_reader.rs @@ -2,6 +2,7 @@ use blockifier::execution::contract_class::{ CompiledClassV0, CompiledClassV1, RunnableCompiledClass, + VersionedRunnableCompiledClass, }; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader, StateResult}; @@ -70,8 +71,11 @@ impl StateReader for PyStateReader { .map_err(|err| StateError::StateReadError(err.to_string())) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { - Python::with_gil(|py| -> Result { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { + Python::with_gil(|py| -> Result { let args = (PyFelt::from(class_hash),); let py_versioned_raw_compiled_class: &PyTuple = self .state_reader_proxy @@ -87,9 +91,11 @@ impl StateReader for PyStateReader { // Extract and process the Sierra version let (minor, major, patch): (u64, u64, u64) = py_versioned_raw_compiled_class.get_item(1)?.extract()?; - // TODO(Aviv): Return it in the next PR after the change in the StateReader API. - let _sierra_version = SierraVersion::new(major, minor, patch); - Ok(runnable_compiled_class) + let sierra_version = SierraVersion::new(major, minor, patch); + if sierra_version == SierraVersion::DEPRECATED { + return Ok(VersionedRunnableCompiledClass::Cairo0(runnable_compiled_class)); + } + Ok(VersionedRunnableCompiledClass::Cairo1((runnable_compiled_class, sierra_version))) }) .map_err(|err| { if Python::with_gil(|py| err.is_instance_of::(py)) { diff --git a/crates/papyrus_execution/src/state_reader.rs b/crates/papyrus_execution/src/state_reader.rs index 2195ec015c5..890a4970090 100644 --- a/crates/papyrus_execution/src/state_reader.rs +++ b/crates/papyrus_execution/src/state_reader.rs @@ -77,7 +77,10 @@ impl BlockifierStateReader for ExecutionStateReader { .unwrap_or_default()) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { if let Some(pending_casm) = self .maybe_pending_data .as_ref() @@ -91,9 +94,10 @@ impl BlockifierStateReader for ExecutionStateReader { let runnable_compiled_class = RunnableCompiledClass::V1( CompiledClassV1::try_from(pending_casm).map_err(StateError::ProgramError)?, ); - let _sierra_version = SierraVersion::extract_from_program(&sierra.sierra_program)?; - // TODO(AVIV): Use the sierra version when the return type is updated. - return Ok(runnable_compiled_class); + return Ok(VersionedRunnableCompiledClass::Cairo1(( + runnable_compiled_class, + SierraVersion::extract_from_program(&sierra.sierra_program)?, + ))); } } @@ -102,17 +106,17 @@ impl BlockifierStateReader for ExecutionStateReader { .as_ref() .and_then(|pending_data| pending_data.classes.get_class(class_hash)) { - return Ok(RunnableCompiledClass::V0( + return Ok(VersionedRunnableCompiledClass::Cairo0(RunnableCompiledClass::V0( CompiledClassV0::try_from(pending_deprecated_class) .map_err(StateError::ProgramError)?, - )); + ))); } match get_versioned_contract_class( &self.storage_reader.begin_ro_txn().map_err(storage_err_to_state_err)?, &class_hash, self.state_number, ) { - Ok(Some(versioned_contract_class)) => Ok(versioned_contract_class.into()), + Ok(Some(versioned_contract_class)) => Ok(versioned_contract_class), Ok(None) => Err(StateError::UndeclaredClassHash(class_hash)), Err(ExecutionUtilsError::CasmTableNotSynced) => { self.missing_compiled_class.set(Some(class_hash)); diff --git a/crates/papyrus_execution/src/state_reader_test.rs b/crates/papyrus_execution/src/state_reader_test.rs index 8e6a3c9558d..49bc3e1db52 100644 --- a/crates/papyrus_execution/src/state_reader_test.rs +++ b/crates/papyrus_execution/src/state_reader_test.rs @@ -183,8 +183,8 @@ fn read_state() { assert_eq!(nonce_after_block_1, nonce0); let class_hash_after_block_1 = state_reader1.get_class_hash_at(address0).unwrap(); assert_eq!(class_hash_after_block_1, class_hash0); - let compiled_contract_class_after_block_1 = - state_reader1.get_compiled_class(class_hash0).unwrap(); + let compiled_contract_class_after_block_1: RunnableCompiledClass = + state_reader1.get_compiled_class(class_hash0).unwrap().into(); assert_eq!(compiled_contract_class_after_block_1, blockifier_casm0); // Test that an error is returned if we try to get a missing casm, and the field @@ -233,13 +233,19 @@ fn read_state() { assert_eq!(state_reader2.get_compiled_class_hash(class_hash2).unwrap(), compiled_class_hash2); assert_eq!(state_reader2.get_nonce_at(address0).unwrap(), nonce0); assert_eq!(state_reader2.get_nonce_at(address2).unwrap(), nonce1); - assert_eq!(state_reader2.get_compiled_class(class_hash0).unwrap(), blockifier_casm0); - assert_eq!(state_reader2.get_compiled_class(class_hash2).unwrap(), blockifier_casm1); + assert_eq!( + RunnableCompiledClass::from(state_reader2.get_compiled_class(class_hash0).unwrap()), + blockifier_casm0 + ); + assert_eq!( + RunnableCompiledClass::from(state_reader2.get_compiled_class(class_hash2).unwrap()), + blockifier_casm1 + ); // Test that an error is returned if we only got the class without the casm. state_reader2.get_compiled_class(class_hash3).unwrap_err(); // Test that if the class is deprecated it is returned. assert_eq!( - state_reader2.get_compiled_class(class_hash4).unwrap(), + RunnableCompiledClass::from(state_reader2.get_compiled_class(class_hash4).unwrap()), RunnableCompiledClass::V0(CompiledClassV0::try_from(class1).unwrap()) ); diff --git a/crates/papyrus_state_reader/src/papyrus_state.rs b/crates/papyrus_state_reader/src/papyrus_state.rs index bfa5e01c41f..7a900c1c2c5 100644 --- a/crates/papyrus_state_reader/src/papyrus_state.rs +++ b/crates/papyrus_state_reader/src/papyrus_state.rs @@ -130,18 +130,21 @@ impl StateReader for PapyrusReader { } } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { // Assumption: the global cache is cleared upon reverted blocks. let versioned_contract_class = self.global_class_hash_to_class.get(&class_hash); match versioned_contract_class { - Some(contract_class) => Ok(RunnableCompiledClass::from(contract_class)), + Some(versioned_contract_class) => Ok(versioned_contract_class), None => { let versioned_contract_class_from_db = self.get_compiled_class_inner(class_hash)?; // The class was declared in a previous (finalized) state; update the global cache. self.global_class_hash_to_class .set(class_hash, versioned_contract_class_from_db.clone()); - Ok(RunnableCompiledClass::from(versioned_contract_class_from_db)) + Ok(versioned_contract_class_from_db) } } } diff --git a/crates/starknet_gateway/src/rpc_state_reader.rs b/crates/starknet_gateway/src/rpc_state_reader.rs index 62e637a2925..8b1ad0da24c 100644 --- a/crates/starknet_gateway/src/rpc_state_reader.rs +++ b/crates/starknet_gateway/src/rpc_state_reader.rs @@ -2,6 +2,7 @@ use blockifier::execution::contract_class::{ CompiledClassV0, CompiledClassV1, RunnableCompiledClass, + VersionedRunnableCompiledClass, }; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; @@ -139,21 +140,29 @@ impl BlockifierStateReader for RpcStateReader { } } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { let get_compiled_class_params = GetCompiledClassParams { class_hash, block_id: self.block_id }; let result = self.send_rpc_request("starknet_getCompiledContractClass", get_compiled_class_params)?; - let (contract_class, _): (CompiledContractClass, SierraVersion) = + let (contract_class, sierra_version): (CompiledContractClass, SierraVersion) = serde_json::from_value(result).map_err(serde_err_to_state_err)?; - match contract_class { - CompiledContractClass::V1(contract_class_v1) => Ok(RunnableCompiledClass::V1( + let runnable_contract_class = match contract_class { + CompiledContractClass::V1(contract_class_v1) => RunnableCompiledClass::V1( CompiledClassV1::try_from(contract_class_v1).map_err(StateError::ProgramError)?, - )), - CompiledContractClass::V0(contract_class_v0) => Ok(RunnableCompiledClass::V0( + ), + CompiledContractClass::V0(contract_class_v0) => RunnableCompiledClass::V0( CompiledClassV0::try_from(contract_class_v0).map_err(StateError::ProgramError)?, - )), + ), + }; + if sierra_version == SierraVersion::DEPRECATED { + Ok(VersionedRunnableCompiledClass::Cairo0(runnable_contract_class)) + } else { + Ok(VersionedRunnableCompiledClass::Cairo1((runnable_contract_class, sierra_version))) } } diff --git a/crates/starknet_gateway/src/rpc_state_reader_test.rs b/crates/starknet_gateway/src/rpc_state_reader_test.rs index 514518636ad..9c26d6a664a 100644 --- a/crates/starknet_gateway/src/rpc_state_reader_test.rs +++ b/crates/starknet_gateway/src/rpc_state_reader_test.rs @@ -193,7 +193,10 @@ async fn test_get_compiled_class() { .await .unwrap() .unwrap(); - assert_eq!(result, RunnableCompiledClass::V1(expected_result.try_into().unwrap())); + assert_eq!( + RunnableCompiledClass::from(result), + RunnableCompiledClass::V1(expected_result.try_into().unwrap()) + ); mock.assert_async().await; } diff --git a/crates/starknet_gateway/src/state_reader.rs b/crates/starknet_gateway/src/state_reader.rs index b026bbe6ef0..d545f15791c 100644 --- a/crates/starknet_gateway/src/state_reader.rs +++ b/crates/starknet_gateway/src/state_reader.rs @@ -1,4 +1,4 @@ -use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::execution::contract_class::VersionedRunnableCompiledClass; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; #[cfg(test)] @@ -44,7 +44,10 @@ impl BlockifierStateReader for Box { self.as_ref().get_class_hash_at(contract_address) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { self.as_ref().get_compiled_class(class_hash) } diff --git a/crates/starknet_gateway/src/state_reader_test_utils.rs b/crates/starknet_gateway/src/state_reader_test_utils.rs index 80c8fc8b516..4971c0d1c27 100644 --- a/crates/starknet_gateway/src/state_reader_test_utils.rs +++ b/crates/starknet_gateway/src/state_reader_test_utils.rs @@ -1,5 +1,5 @@ use blockifier::context::BlockContext; -use blockifier::execution::contract_class::RunnableCompiledClass; +use blockifier::execution::contract_class::VersionedRunnableCompiledClass; use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; use blockifier::test_utils::contracts::FeatureContract; @@ -43,7 +43,10 @@ impl BlockifierStateReader for TestStateReader { self.blockifier_state_reader.get_class_hash_at(contract_address) } - fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class( + &self, + class_hash: ClassHash, + ) -> StateResult { self.blockifier_state_reader.get_compiled_class(class_hash) }