Skip to content

Commit

Permalink
refactor(blockifier): state reader trait
Browse files Browse the repository at this point in the history
  • Loading branch information
AvivYossef-starkware committed Dec 15, 2024
1 parent f8977d8 commit c8f8585
Show file tree
Hide file tree
Showing 32 changed files with 285 additions and 140 deletions.
9 changes: 7 additions & 2 deletions crates/blockifier/src/blockifier/transaction_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::blockifier::config::TransactionExecutorConfig;
use crate::bouncer::{Bouncer, BouncerWeights};
use crate::concurrency::worker_logic::WorkerExecutor;
use crate::context::BlockContext;
use crate::execution::contract_class::RunnableCompiledClass;
use crate::state::cached_state::{CachedState, CommitmentStateDiff, TransactionalState};
use crate::state::errors::StateError;
use crate::state::state_api::{StateReader, StateResult};
Expand Down Expand Up @@ -156,12 +157,16 @@ impl<S: StateReader> TransactionExecutor<S> {
.visited_pcs
.iter()
.map(|(class_hash, class_visited_pcs)| -> TransactionExecutorResult<_> {
let contract_class = self
let versioned_contract_class = self
.block_state
.as_ref()
.expect(BLOCK_STATE_ACCESS_ERR)
.get_compiled_class(*class_hash)?;
Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?))
Ok((
*class_hash,
RunnableCompiledClass::from(versioned_contract_class)
.get_visited_segments(class_visited_pcs)?,
))
})
.collect::<TransactionExecutorResult<_>>()?;

Expand Down
3 changes: 2 additions & 1 deletion crates/blockifier/src/bouncer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::blockifier::transaction_executor::{
TransactionExecutorResult,
};
use crate::execution::call_info::ExecutionSummary;
use crate::execution::contract_class::RunnableCompiledClass;
use crate::fee::gas_usage::get_onchain_data_segment_length;
use crate::fee::resources::TransactionResources;
use crate::state::cached_state::{StateChangesKeys, StorageEntry};
Expand Down Expand Up @@ -565,7 +566,7 @@ pub fn get_casm_hash_calculation_resources<S: StateReader>(
let mut casm_hash_computation_resources = ExecutionResources::default();

for class_hash in executed_class_hashes {
let class = state_reader.get_compiled_class(*class_hash)?;
let class: RunnableCompiledClass = state_reader.get_compiled_class(*class_hash)?.into();
casm_hash_computation_resources += &class.estimate_casm_hash_computation_resources();
}

Expand Down
4 changes: 2 additions & 2 deletions crates/blockifier/src/concurrency/fee_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use starknet_types_core::felt::Felt;
use crate::context::{BlockContext, TransactionContext};
use crate::execution::call_info::CallInfo;
use crate::fee::fee_utils::get_sequencer_balance_keys;
use crate::state::cached_state::{ContractClassMapping, StateMaps};
use crate::state::cached_state::{StateMaps, VersionedContractClassMapping};
use crate::state::state_api::UpdatableState;
use crate::transaction::objects::TransactionExecutionInfo;

Expand Down Expand Up @@ -118,5 +118,5 @@ pub fn add_fee_to_sequencer_balance(
]),
..StateMaps::default()
};
state.apply_writes(&writes, &ContractClassMapping::default(), &HashMap::default());
state.apply_writes(&writes, &VersionedContractClassMapping::default(), &HashMap::default());
}
12 changes: 7 additions & 5 deletions crates/blockifier/src/concurrency/flow_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::abi::sierra_types::{SierraType, SierraU128};
use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus};
use crate::concurrency::test_utils::{safe_versioned_state_for_testing, DEFAULT_CHUNK_SIZE};
use crate::concurrency::versioned_state::ThreadSafeVersionedState;
use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps};
use crate::state::cached_state::{CachedState, StateMaps, VersionedContractClassMapping};
use crate::state::state_api::UpdatableState;
use crate::test_utils::dict_state_reader::DictStateReader;

Expand Down Expand Up @@ -43,14 +43,15 @@ fn scheduler_flow_test(
get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state);
let reads_valid = state_proxy.validate_reads(&reads);
if !reads_valid {
state_proxy.delete_writes(&writes, &ContractClassMapping::default());
state_proxy
.delete_writes(&writes, &VersionedContractClassMapping::default());
let (_, new_writes) = get_reads_writes_for(
Task::ExecutionTask(tx_index),
&versioned_state,
);
state_proxy.apply_writes(
&new_writes,
&ContractClassMapping::default(),
&VersionedContractClassMapping::default(),
&HashMap::default(),
);
scheduler.finish_execution_during_commit(tx_index);
Expand All @@ -63,7 +64,7 @@ fn scheduler_flow_test(
get_reads_writes_for(Task::ExecutionTask(tx_index), &versioned_state);
versioned_state.pin_version(tx_index).apply_writes(
&writes,
&ContractClassMapping::default(),
&VersionedContractClassMapping::default(),
&HashMap::default(),
);
scheduler.finish_execution(tx_index);
Expand All @@ -76,7 +77,8 @@ fn scheduler_flow_test(
let read_set_valid = state_proxy.validate_reads(&reads);
let aborted = !read_set_valid && scheduler.try_validation_abort(tx_index);
if aborted {
state_proxy.delete_writes(&writes, &ContractClassMapping::default());
state_proxy
.delete_writes(&writes, &VersionedContractClassMapping::default());
scheduler.finish_abort(tx_index)
} else {
Task::AskForTask
Expand Down
40 changes: 24 additions & 16 deletions crates/blockifier/src/concurrency/versioned_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use starknet_types_core::felt::Felt;

use crate::concurrency::versioned_storage::VersionedStorage;
use crate::concurrency::TxIndex;
use crate::execution::contract_class::RunnableCompiledClass;
use crate::state::cached_state::{ContractClassMapping, StateMaps};
use crate::execution::contract_class::VersionedRunnableCompiledClass;
use crate::state::cached_state::{StateMaps, VersionedContractClassMapping};
use crate::state::errors::StateError;
use crate::state::state_api::{StateReader, StateResult, UpdatableState};

Expand All @@ -34,7 +34,8 @@ pub struct VersionedState<S: StateReader> {
// the compiled contract classes mapping. Each key with value false, sohuld not apprear
// in the compiled contract classes mapping.
declared_contracts: VersionedStorage<ClassHash, bool>,
compiled_contract_classes: VersionedStorage<ClassHash, RunnableCompiledClass>,
versioned_compiled_contract_classes:
VersionedStorage<ClassHash, VersionedRunnableCompiledClass>,
}

impl<S: StateReader> VersionedState<S> {
Expand All @@ -45,7 +46,7 @@ impl<S: StateReader> VersionedState<S> {
nonces: VersionedStorage::default(),
class_hashes: VersionedStorage::default(),
compiled_class_hashes: VersionedStorage::default(),
compiled_contract_classes: VersionedStorage::default(),
versioned_compiled_contract_classes: VersionedStorage::default(),
declared_contracts: VersionedStorage::default(),
}
}
Expand Down Expand Up @@ -121,7 +122,7 @@ impl<S: StateReader> VersionedState<S> {
let is_declared = self.declared_contracts.read(tx_index, class_hash).expect(READ_ERR);
assert_eq!(
is_declared,
self.compiled_contract_classes.read(tx_index, class_hash).is_some(),
self.versioned_compiled_contract_classes.read(tx_index, class_hash).is_some(),
"The declared contracts mapping should match the compiled contract classes \
mapping."
);
Expand All @@ -139,7 +140,7 @@ impl<S: StateReader> VersionedState<S> {
&mut self,
tx_index: TxIndex,
writes: &StateMaps,
class_hash_to_class: &ContractClassMapping,
class_hash_to_class: &VersionedContractClassMapping,
) {
for (&key, &value) in &writes.storage {
self.storage.write(tx_index, key, value);
Expand All @@ -154,13 +155,13 @@ impl<S: StateReader> VersionedState<S> {
self.compiled_class_hashes.write(tx_index, key, value);
}
for (&key, value) in class_hash_to_class {
self.compiled_contract_classes.write(tx_index, key, value.clone());
self.versioned_compiled_contract_classes.write(tx_index, key, value.clone());
}
for (&key, &value) in &writes.declared_contracts {
self.declared_contracts.write(tx_index, key, value);
assert_eq!(
value,
self.compiled_contract_classes.read(tx_index, key).is_some(),
self.versioned_compiled_contract_classes.read(tx_index, key).is_some(),
"The declared contracts mapping should match the compiled contract classes \
mapping."
);
Expand All @@ -171,7 +172,7 @@ impl<S: StateReader> VersionedState<S> {
&mut self,
tx_index: TxIndex,
writes: &StateMaps,
class_hash_to_class: &ContractClassMapping,
class_hash_to_class: &VersionedContractClassMapping,
) {
for &key in writes.storage.keys() {
self.storage.delete_write(key, tx_index);
Expand All @@ -189,7 +190,7 @@ impl<S: StateReader> VersionedState<S> {
self.declared_contracts.delete_write(key, tx_index);
}
for &key in class_hash_to_class.keys() {
self.compiled_contract_classes.delete_write(key, tx_index);
self.versioned_compiled_contract_classes.delete_write(key, tx_index);
}
}

Expand All @@ -210,7 +211,7 @@ impl<U: UpdatableState> VersionedState<U> {
let commit_index = n_committed_txs - 1;
let writes = self.get_writes_up_to_index(commit_index);
let class_hash_to_class =
self.compiled_contract_classes.get_writes_up_to_index(commit_index);
self.versioned_compiled_contract_classes.get_writes_up_to_index(commit_index);
let mut state = self.into_initial_state();
state.apply_writes(&writes, &class_hash_to_class, &visited_pcs);
state
Expand Down Expand Up @@ -266,7 +267,11 @@ impl<S: StateReader> VersionedStateProxy<S> {
self.state().validate_reads(self.tx_index, reads)
}

pub fn delete_writes(&self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping) {
pub fn delete_writes(
&self,
writes: &StateMaps,
class_hash_to_class: &VersionedContractClassMapping,
) {
self.state().delete_writes(self.tx_index, writes, class_hash_to_class);
}
}
Expand All @@ -276,7 +281,7 @@ impl<S: StateReader> UpdatableState for VersionedStateProxy<S> {
fn apply_writes(
&mut self,
writes: &StateMaps,
class_hash_to_class: &ContractClassMapping,
class_hash_to_class: &VersionedContractClassMapping,
_visited_pcs: &HashMap<ClassHash, HashSet<usize>>,
) {
self.state().apply_writes(self.tx_index, writes, class_hash_to_class)
Expand Down Expand Up @@ -336,15 +341,18 @@ impl<S: StateReader> StateReader for VersionedStateProxy<S> {
}
}

fn get_compiled_class(&self, class_hash: ClassHash) -> StateResult<RunnableCompiledClass> {
fn get_compiled_class(
&self,
class_hash: ClassHash,
) -> StateResult<VersionedRunnableCompiledClass> {
let mut state = self.state();
match state.compiled_contract_classes.read(self.tx_index, class_hash) {
match state.versioned_compiled_contract_classes.read(self.tx_index, class_hash) {
Some(value) => Ok(value),
None => match state.initial_state.get_compiled_class(class_hash) {
Ok(initial_value) => {
state.declared_contracts.set_initial_value(class_hash, true);
state
.compiled_contract_classes
.versioned_compiled_contract_classes
.set_initial_value(class_hash, initial_value.clone());
Ok(initial_value)
}
Expand Down
51 changes: 31 additions & 20 deletions crates/blockifier/src/concurrency/versioned_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ use crate::concurrency::TxIndex;
use crate::context::BlockContext;
use crate::state::cached_state::{
CachedState,
ContractClassMapping,
StateMaps,
TransactionalState,
VersionedContractClassMapping,
};
use crate::state::errors::StateError;
use crate::state::state_api::{State, StateReader, UpdatableState};
Expand Down Expand Up @@ -72,7 +72,7 @@ fn test_versioned_state_proxy() {
let class_hash = class_hash!(27_u8);
let another_class_hash = class_hash!(28_u8);
let compiled_class_hash = compiled_class_hash!(29_u8);
let contract_class = test_contract.get_runnable_class();
let contract_class = test_contract.get_versioned_runnable_class();

// Create the versioned state
let cached_state = CachedState::from(DictStateReader {
Expand Down Expand Up @@ -117,7 +117,7 @@ fn test_versioned_state_proxy() {
let compiled_class_hash_v18 = compiled_class_hash!(30_u8);
let contract_class_v11 =
FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm))
.get_runnable_class();
.get_versioned_runnable_class();

versioned_state_proxys[3].state().apply_writes(
3,
Expand Down Expand Up @@ -404,10 +404,11 @@ fn test_false_validate_reads_declared_contracts(
..Default::default()
};
let version_state_proxy = safe_versioned_state.pin_version(0);
let compiled_contract_calss =
let versioned_compiled_contract_class =
FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm))
.get_runnable_class();
let class_hash_to_class = HashMap::from([(class_hash!(1_u8), compiled_contract_calss)]);
.get_versioned_runnable_class();
let class_hash_to_class =
HashMap::from([(class_hash!(1_u8), versioned_compiled_contract_class)]);
version_state_proxy.state().apply_writes(0, &tx_0_writes, &class_hash_to_class);
assert!(!safe_versioned_state.pin_version(1).validate_reads(&tx_1_reads));
}
Expand All @@ -431,11 +432,15 @@ fn test_apply_writes(
assert_eq!(transactional_states[0].cache.borrow().writes.class_hashes.len(), 1);

// Transaction 0 contract class.
let contract_class_0 =

let versioned_contract_class_0 =
FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm))
.get_runnable_class();
.get_versioned_runnable_class();

assert!(transactional_states[0].class_hash_to_class.borrow().is_empty());
transactional_states[0].set_contract_class(class_hash, contract_class_0.clone()).unwrap();
transactional_states[0]
.set_contract_class(class_hash, versioned_contract_class_0.clone())
.unwrap();
assert_eq!(transactional_states[0].class_hash_to_class.borrow().len(), 1);

safe_versioned_state.pin_version(0).apply_writes(
Expand All @@ -444,7 +449,10 @@ fn test_apply_writes(
&HashMap::default(),
);
assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0);
assert!(transactional_states[1].get_compiled_class(class_hash).unwrap() == contract_class_0);
assert!(
transactional_states[1].get_compiled_class(class_hash).unwrap()
== versioned_contract_class_0
);
}

#[rstest]
Expand Down Expand Up @@ -513,7 +521,7 @@ fn test_delete_writes(
tx_state
.set_contract_class(
feature_contract.get_class_hash(),
feature_contract.get_runnable_class(),
feature_contract.get_versioned_runnable_class(),
)
.unwrap();
safe_versioned_state.pin_version(i).apply_writes(
Expand Down Expand Up @@ -546,7 +554,7 @@ fn test_delete_writes(
.0
.lock()
.unwrap()
.compiled_contract_classes
.versioned_compiled_contract_classes
.get_writes_of_index(tx_index)
.is_empty(),
should_be_empty
Expand All @@ -573,8 +581,10 @@ fn test_delete_writes_completeness(
)]),
declared_contracts: HashMap::from([(feature_contract.get_class_hash(), true)]),
};
let class_hash_to_class_writes =
HashMap::from([(feature_contract.get_class_hash(), feature_contract.get_runnable_class())]);
let class_hash_to_class_writes = HashMap::from([(
feature_contract.get_class_hash(),
feature_contract.get_versioned_runnable_class(),
)]);

let tx_index = 0;
let mut versioned_state_proxy = safe_versioned_state.pin_version(tx_index);
Expand All @@ -593,7 +603,7 @@ fn test_delete_writes_completeness(
.0
.lock()
.unwrap()
.compiled_contract_classes
.versioned_compiled_contract_classes
.get_writes_of_index(tx_index),
class_hash_to_class_writes
);
Expand All @@ -608,9 +618,9 @@ fn test_delete_writes_completeness(
.0
.lock()
.unwrap()
.compiled_contract_classes
.versioned_compiled_contract_classes
.get_writes_of_index(tx_index),
ContractClassMapping::default()
VersionedContractClassMapping::default()
);
}

Expand All @@ -637,10 +647,11 @@ fn test_versioned_proxy_state_flow(
transactional_states[3].set_class_hash_at(contract_address, class_hash_3).unwrap();

// Clients contract class values.
let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo0).get_runnable_class();
let contract_class_0 =
FeatureContract::TestContract(CairoVersion::Cairo0).get_versioned_runnable_class();
let contract_class_2 =
FeatureContract::AccountWithLongValidate(CairoVersion::Cairo1(RunnableCairo1::Casm))
.get_runnable_class();
FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Casm))
.get_versioned_runnable_class();

transactional_states[0].set_contract_class(class_hash, contract_class_0).unwrap();
transactional_states[2].set_contract_class(class_hash, contract_class_2.clone()).unwrap();
Expand Down
Loading

0 comments on commit c8f8585

Please sign in to comment.