diff --git a/crates/blockifier/src/blockifier/transaction_executor.rs b/crates/blockifier/src/blockifier/transaction_executor.rs index 79bba304be..55bea5add7 100644 --- a/crates/blockifier/src/blockifier/transaction_executor.rs +++ b/crates/blockifier/src/blockifier/transaction_executor.rs @@ -16,7 +16,7 @@ use crate::context::BlockContext; use crate::state::cached_state::{CachedState, CommitmentStateDiff, TransactionalState}; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult}; -use crate::state::stateful_compression::state_diff_with_alias_allocation; +use crate::state::stateful_compression::allocate_aliases_in_storage; use crate::transaction::errors::TransactionExecutionError; use crate::transaction::objects::TransactionExecutionInfo; use crate::transaction::transaction_execution::Transaction; @@ -168,19 +168,18 @@ impl TransactionExecutor { .collect::>()?; log::debug!("Final block weights: {:?}.", self.bouncer.get_accumulated_weights()); - let mut block_state = self.block_state.take().expect(BLOCK_STATE_ACCESS_ERR); - let state_diff = if self.block_context.versioned_constants.enable_stateful_compression { - state_diff_with_alias_allocation( - &mut block_state, + let block_state = self.block_state.as_mut().expect(BLOCK_STATE_ACCESS_ERR); + if self.block_context.versioned_constants.enable_stateful_compression { + allocate_aliases_in_storage( + block_state, self.block_context .versioned_constants .os_constants .os_contract_addresses .alias_contract_address(), - )? - } else { - block_state.to_state_diff()?.state_maps - }; + )?; + } + let state_diff = block_state.to_state_diff()?.state_maps; Ok((state_diff.into(), visited_segments, *self.bouncer.get_accumulated_weights())) } } diff --git a/crates/blockifier/src/state/stateful_compression.rs b/crates/blockifier/src/state/stateful_compression.rs index 4136278628..44f4937dc9 100644 --- a/crates/blockifier/src/state/stateful_compression.rs +++ b/crates/blockifier/src/state/stateful_compression.rs @@ -6,9 +6,9 @@ use starknet_api::StarknetApiError; use starknet_types_core::felt::Felt; use thiserror::Error; -use super::cached_state::{CachedState, StateMaps, StorageEntry}; +use super::cached_state::{CachedState, StateMaps}; use super::errors::StateError; -use super::state_api::{StateReader, StateResult}; +use super::state_api::{State, StateReader, StateResult}; #[cfg(test)] #[path = "stateful_compression_test.rs"] @@ -46,11 +46,11 @@ pub const MIN_VALUE_FOR_ALIAS_ALLOC: PatriciaKey = /// Allocates aliases for the new addresses and storage keys in the alias contract. /// Iterates over the addresses in ascending order. For each address, sets an alias for the new /// storage keys (in ascending order) and for the address itself. -pub fn state_diff_with_alias_allocation( +pub fn allocate_aliases_in_storage( state: &mut CachedState, alias_contract_address: ContractAddress, -) -> StateResult { - let mut state_diff = state.to_state_diff()?.state_maps; +) -> StateResult<()> { + let state_diff = state.to_state_diff()?.state_maps; // Collect the contract addresses and the storage keys that need aliases. let contract_addresses: BTreeSet = @@ -76,65 +76,62 @@ pub fn state_diff_with_alias_allocation( alias_updater.insert_alias(&StorageKey(contract_address.0))?; } - let alias_storage_updates = alias_updater.finalize_updates(); - state_diff.storage.extend(alias_storage_updates); - Ok(state_diff) + alias_updater.finalize_updates() } -/// Generate updates for the alias contract with the new keys. -struct AliasUpdater<'a, S: StateReader> { - state: &'a S, - new_aliases: HashMap, +/// Updates the alias contract with the new keys. +struct AliasUpdater<'a, S: State> { + state: &'a mut S, + is_alias_inserted: bool, next_free_alias: Option, alias_contract_address: ContractAddress, } -impl<'a, S: StateReader> AliasUpdater<'a, S> { - fn new(state: &'a S, alias_contract_address: ContractAddress) -> StateResult { +impl<'a, S: State> AliasUpdater<'a, S> { + fn new(state: &'a mut S, alias_contract_address: ContractAddress) -> StateResult { let stored_counter = state.get_storage_at(alias_contract_address, ALIAS_COUNTER_STORAGE_KEY)?; Ok(Self { state, - new_aliases: HashMap::new(), + is_alias_inserted: false, next_free_alias: if stored_counter == Felt::ZERO { None } else { Some(stored_counter) }, alias_contract_address, }) } + fn set_alias_in_storage(&mut self, alias_key: AliasKey, alias: Alias) -> StateResult<()> { + self.state.set_storage_at(self.alias_contract_address, alias_key, alias) + } + /// Inserts the alias key to the updates if it's not already aliased. fn insert_alias(&mut self, alias_key: &AliasKey) -> StateResult<()> { if alias_key.0 >= MIN_VALUE_FOR_ALIAS_ALLOC && self.state.get_storage_at(self.alias_contract_address, *alias_key)? == Felt::ZERO - && !self.new_aliases.contains_key(alias_key) { let alias_to_allocate = match self.next_free_alias { Some(alias) => alias, None => INITIAL_AVAILABLE_ALIAS, }; - self.new_aliases.insert(*alias_key, alias_to_allocate); + self.set_alias_in_storage(*alias_key, alias_to_allocate)?; + self.is_alias_inserted = true; self.next_free_alias = Some(alias_to_allocate + Felt::ONE); } Ok(()) } - /// Inserts the counter of the alias contract. Returns the storage updates for the alias - /// contract. - fn finalize_updates(mut self) -> HashMap { + /// Inserts the counter of the alias contract. + fn finalize_updates(mut self) -> StateResult<()> { match self.next_free_alias { None => { - self.new_aliases.insert(ALIAS_COUNTER_STORAGE_KEY, INITIAL_AVAILABLE_ALIAS); + self.set_alias_in_storage(ALIAS_COUNTER_STORAGE_KEY, INITIAL_AVAILABLE_ALIAS)?; } Some(alias) => { - if !self.new_aliases.is_empty() { - self.new_aliases.insert(ALIAS_COUNTER_STORAGE_KEY, alias); + if self.is_alias_inserted { + self.set_alias_in_storage(ALIAS_COUNTER_STORAGE_KEY, alias)?; } } } - - self.new_aliases - .into_iter() - .map(|(key, alias)| ((self.alias_contract_address, key), alias)) - .collect() + Ok(()) } } diff --git a/crates/blockifier/src/state/stateful_compression_test.rs b/crates/blockifier/src/state/stateful_compression_test.rs index 72898f112b..c6bc7ea487 100644 --- a/crates/blockifier/src/state/stateful_compression_test.rs +++ b/crates/blockifier/src/state/stateful_compression_test.rs @@ -9,8 +9,8 @@ use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; use super::{ + allocate_aliases_in_storage, compress, - state_diff_with_alias_allocation, Alias, AliasKey, AliasUpdater, @@ -164,14 +164,16 @@ fn test_alias_updater( #[case] expected_alias_keys: Vec, #[values(0, 2)] n_existing_aliases: u8, ) { - let state = initial_state(n_existing_aliases); + let mut state = initial_state(n_existing_aliases); // Insert the keys into the alias contract updater and finalize the updates. - let mut alias_contract_updater = AliasUpdater::new(&state, *ALIAS_CONTRACT_ADDRESS).unwrap(); + let mut alias_contract_updater = + AliasUpdater::new(&mut state, *ALIAS_CONTRACT_ADDRESS).unwrap(); for key in keys { alias_contract_updater.insert_alias(&StorageKey::try_from(key).unwrap()).unwrap(); } - let storage_diff = alias_contract_updater.finalize_updates(); + alias_contract_updater.finalize_updates().unwrap(); + let storage_diff = state.to_state_diff().unwrap().state_maps.storage; // Test the new aliases. let mut expected_storage_diff = HashMap::new(); @@ -214,8 +216,9 @@ fn test_iterate_aliases() { state.set_class_hash_at(ContractAddress::from(0x202_u16), ClassHash(Felt::ONE)).unwrap(); state.increment_nonce(ContractAddress::from(0x200_u16)).unwrap(); - let storage_diff = - state_diff_with_alias_allocation(&mut state, *ALIAS_CONTRACT_ADDRESS).unwrap().storage; + allocate_aliases_in_storage(&mut state, *ALIAS_CONTRACT_ADDRESS).unwrap(); + let storage_diff = state.to_state_diff().unwrap().state_maps.storage; + assert_eq!( storage_diff, vec![ @@ -262,8 +265,8 @@ fn test_read_only_state(#[values(0, 2)] n_existing_aliases: u8) { .unwrap(); state.get_nonce_at(ContractAddress::from(0x201_u16)).unwrap(); state.get_class_hash_at(ContractAddress::from(0x202_u16)).unwrap(); - let storage_diff = - state_diff_with_alias_allocation(&mut state, *ALIAS_CONTRACT_ADDRESS).unwrap().storage; + allocate_aliases_in_storage(&mut state, *ALIAS_CONTRACT_ADDRESS).unwrap(); + let storage_diff = state.to_state_diff().unwrap().state_maps.storage; let expected_storage_diff = if n_existing_aliases == 0 { HashMap::from([(