diff --git a/crates/blockifier/src/state.rs b/crates/blockifier/src/state.rs index 8aa857c963a..179737b9f36 100644 --- a/crates/blockifier/src/state.rs +++ b/crates/blockifier/src/state.rs @@ -6,3 +6,4 @@ pub mod error_format_test; pub mod errors; pub mod global_cache; pub mod state_api; +pub mod stateful_compression; diff --git a/crates/blockifier/src/state/stateful_compression.rs b/crates/blockifier/src/state/stateful_compression.rs new file mode 100644 index 00000000000..95f02fd9c81 --- /dev/null +++ b/crates/blockifier/src/state/stateful_compression.rs @@ -0,0 +1,82 @@ +use std::collections::{BTreeSet, HashMap}; + +use starknet_api::core::{ContractAddress, PatriciaKey}; +use starknet_api::state::StorageKey; +use starknet_types_core::felt::Felt; + +use super::cached_state::CachedState; +use super::state_api::{State, StateReader, StateResult}; + +#[cfg(test)] +#[path = "stateful_compression_test.rs"] +pub mod stateful_compression_test; + +type Alias = Felt; +type AliasKey = StorageKey; + +// The address of the alias contract. +const ALIAS_CONTRACT_ADDRESS: u8 = 2; +// The storage key of the alias counter in the alias contract. +const ALIAS_COUNTER_STORAGE_KEY: u8 = 0; +// The minimal value for a key to be allocated an alias. Smaller keys are serialized as is (their +// alias is identical to the key). +const MIN_VALUE_FOR_ALIAS_ALLOC: Felt = Felt::from_hex_unchecked("0x80"); + +pub fn get_alias_contract_address() -> ContractAddress { + ContractAddress::from(ALIAS_CONTRACT_ADDRESS) +} +pub fn get_alias_counter_storage_key() -> StorageKey { + StorageKey::from(ALIAS_COUNTER_STORAGE_KEY) +} + +/// Updates the alias contract with the new keys. +struct AliasUpdater<'a, S: StateReader> { + state: &'a mut CachedState, + next_free_alias: Alias, + was_updated: bool, +} + +impl<'a, S: StateReader> AliasUpdater<'a, S> { + fn new(state: &'a mut CachedState) -> StateResult { + let next_free_alias = + state.get_storage_at(get_alias_contract_address(), get_alias_counter_storage_key())?; + Ok(Self { + state, + next_free_alias: if next_free_alias == Felt::ZERO { + // Aliasing first time. + MIN_VALUE_FOR_ALIAS_ALLOC + } else { + next_free_alias + }, + was_updated: false, + }) + } + + /// Inserts the alias key to the updates if it's not already aliased. + fn set_alias(&mut self, alias_key: &AliasKey) -> StateResult<()> { + if alias_key.0 >= PatriciaKey::try_from(MIN_VALUE_FOR_ALIAS_ALLOC)? + && self.state.get_storage_at(get_alias_contract_address(), *alias_key)? == Felt::ZERO + { + self.state.set_storage_at( + get_alias_contract_address(), + *alias_key, + self.next_free_alias, + )?; + self.was_updated = true; + self.next_free_alias += Felt::ONE; + } + Ok(()) + } + + /// Finalizes the updates and inserts them to the state changes. + fn finalize_updates(self) -> StateResult<()> { + if self.was_updated { + self.state.set_storage_at( + get_alias_contract_address(), + get_alias_counter_storage_key(), + self.next_free_alias, + )?; + } + Ok(()) + } +} diff --git a/crates/blockifier/src/state/stateful_compression_test.rs b/crates/blockifier/src/state/stateful_compression_test.rs new file mode 100644 index 00000000000..97d72dd0abd --- /dev/null +++ b/crates/blockifier/src/state/stateful_compression_test.rs @@ -0,0 +1,99 @@ +use std::collections::HashMap; + +use rstest::rstest; +use starknet_api::core::ContractAddress; +use starknet_api::state::StorageKey; +use starknet_types_core::felt::Felt; + +use super::{ + get_alias_contract_address, + get_alias_counter_storage_key, + AliasUpdater, + MIN_VALUE_FOR_ALIAS_ALLOC, +}; +use crate::state::cached_state::CachedState; +use crate::test_utils::dict_state_reader::DictStateReader; + +fn insert_to_alias_contract( + storage: &mut HashMap<(ContractAddress, StorageKey), Felt>, + key: StorageKey, + value: Felt, +) { + storage.insert((get_alias_contract_address(), key), value); +} + +fn initial_state(n_exist_aliases: u8) -> CachedState { + let mut state_reader = DictStateReader::default(); + if n_exist_aliases > 0 { + let high_alias_key = MIN_VALUE_FOR_ALIAS_ALLOC * Felt::TWO; + insert_to_alias_contract( + &mut state_reader.storage_view, + get_alias_counter_storage_key(), + MIN_VALUE_FOR_ALIAS_ALLOC + Felt::from(n_exist_aliases), + ); + for i in 0..n_exist_aliases { + insert_to_alias_contract( + &mut state_reader.storage_view, + (high_alias_key + Felt::from(i)).try_into().unwrap(), + MIN_VALUE_FOR_ALIAS_ALLOC + Felt::from(i), + ); + } + } + + CachedState::new(state_reader) +} + +/// Tests the alias contract updater with an empty state. +#[rstest] +#[case::no_update(vec![], vec![])] +#[case::low_update(vec![MIN_VALUE_FOR_ALIAS_ALLOC - 1], vec![])] +#[case::single_update(vec![MIN_VALUE_FOR_ALIAS_ALLOC], vec![MIN_VALUE_FOR_ALIAS_ALLOC])] +#[case::some_update( + vec![ + MIN_VALUE_FOR_ALIAS_ALLOC + 1, + MIN_VALUE_FOR_ALIAS_ALLOC - 1, + MIN_VALUE_FOR_ALIAS_ALLOC, + MIN_VALUE_FOR_ALIAS_ALLOC + 2, + MIN_VALUE_FOR_ALIAS_ALLOC, + ], + vec![ + MIN_VALUE_FOR_ALIAS_ALLOC + 1, + MIN_VALUE_FOR_ALIAS_ALLOC, + MIN_VALUE_FOR_ALIAS_ALLOC + 2, + ] +)] +fn test_alias_updater( + #[case] keys: Vec, + #[case] expected_alias_keys: Vec, + #[values(0, 2)] n_exist_aliases: u8, +) { + let mut state = initial_state(n_exist_aliases); + + // Insert the keys into the alias contract updater and finalize the updates. + let mut alias_contract_updater = AliasUpdater::new(&mut state).unwrap(); + for key in keys { + alias_contract_updater.set_alias(&StorageKey::try_from(key).unwrap()).unwrap(); + } + alias_contract_updater.finalize_updates().unwrap(); + let storage_diff = state.to_state_diff().unwrap().state_maps.storage; + + // Test the new aliases. + let mut expeceted_storage_diff = HashMap::new(); + if !expected_alias_keys.is_empty() { + let mut next_alias = MIN_VALUE_FOR_ALIAS_ALLOC + Felt::from(n_exist_aliases); + for key in expected_alias_keys { + insert_to_alias_contract( + &mut expeceted_storage_diff, + StorageKey::try_from(key).unwrap(), + next_alias, + ); + next_alias += Felt::ONE; + } + insert_to_alias_contract( + &mut expeceted_storage_diff, + get_alias_counter_storage_key(), + next_alias, + ); + } + assert_eq!(storage_diff, expeceted_storage_diff); +}