Skip to content

Commit

Permalink
feat(blockifier): iterate over aliases on the state diff (#2535)
Browse files Browse the repository at this point in the history
  • Loading branch information
yoavGrs authored Dec 23, 2024
1 parent 66ad912 commit 28e7cb8
Show file tree
Hide file tree
Showing 12 changed files with 165 additions and 31 deletions.
19 changes: 9 additions & 10 deletions crates/blockifier/src/blockifier/transaction_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::context::BlockContext;
use crate::state::cached_state::{CachedState, CommitmentStateDiff, TransactionalState};
use crate::state::errors::StateError;
use crate::state::state_api::{StateReader, StateResult};
use crate::state::stateful_compression::state_diff_with_alias_allocation;
use crate::transaction::errors::TransactionExecutionError;
use crate::transaction::objects::TransactionExecutionInfo;
use crate::transaction::transaction_execution::Transaction;
Expand Down Expand Up @@ -142,6 +143,7 @@ impl<S: StateReader> TransactionExecutor<S> {

/// Returns the state diff, a list of contract class hash with the corresponding list of
/// visited segment values and the block weights.
// TODO(Yoav): Consume "self".
pub fn finalize(
&mut self,
) -> TransactionExecutorResult<(CommitmentStateDiff, VisitedSegmentsMapping, BouncerWeights)>
Expand All @@ -166,16 +168,13 @@ impl<S: StateReader> TransactionExecutor<S> {
.collect::<TransactionExecutorResult<_>>()?;

log::debug!("Final block weights: {:?}.", self.bouncer.get_accumulated_weights());
Ok((
self.block_state
.as_mut()
.expect(BLOCK_STATE_ACCESS_ERR)
.to_state_diff()?
.state_maps
.into(),
visited_segments,
*self.bouncer.get_accumulated_weights(),
))
let mut block_state = self.block_state.take().expect(BLOCK_STATE_ACCESS_ERR);
let state_diff = if self.block_context.versioned_constants.enable_stateful_compression {
state_diff_with_alias_allocation(&mut block_state)?
} else {
block_state.to_state_diff()?.state_maps
};
Ok((state_diff.into(), visited_segments, *self.bouncer.get_accumulated_weights()))
}
}

Expand Down
1 change: 0 additions & 1 deletion crates/blockifier/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@ pub mod error_format_test;
pub mod errors;
pub mod global_cache;
pub mod state_api;
#[allow(dead_code)]
pub mod stateful_compression;
7 changes: 4 additions & 3 deletions crates/blockifier/src/state/cached_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub type ContractClassMapping = HashMap<ClassHash, RunnableCompiledClass>;
///
/// Writer functionality is builtin, whereas Reader functionality is injected through
/// initialization.
#[cfg_attr(any(test, feature = "reexecution"), derive(Clone))]
#[derive(Debug)]
pub struct CachedState<S: StateReader> {
pub state: S,
Expand Down Expand Up @@ -362,7 +363,7 @@ impl StateMaps {
}
}

pub fn get_modified_contracts(&self) -> HashSet<ContractAddress> {
pub fn get_contract_addresses(&self) -> HashSet<ContractAddress> {
// Storage updates.
let mut modified_contracts: HashSet<ContractAddress> =
self.storage.keys().map(|address_key_pair| address_key_pair.0).collect();
Expand All @@ -376,7 +377,7 @@ impl StateMaps {

pub fn into_keys(self) -> StateChangesKeys {
StateChangesKeys {
modified_contracts: self.get_modified_contracts(),
modified_contracts: self.get_contract_addresses(),
nonce_keys: self.nonces.into_keys().collect(),
class_hash_keys: self.class_hashes.into_keys().collect(),
storage_keys: self.storage.into_keys().collect(),
Expand Down Expand Up @@ -762,7 +763,7 @@ impl StateChanges {
sender_address: Option<ContractAddress>,
fee_token_address: ContractAddress,
) -> StateChangesCountForFee {
let mut modified_contracts = self.state_maps.get_modified_contracts();
let mut modified_contracts = self.state_maps.get_contract_addresses();

// For account transactions, we need to compute the transaction fee before we can execute
// the fee transfer, and the fee should cover the state changes that happen in the
Expand Down
59 changes: 50 additions & 9 deletions crates/blockifier/src/state/stateful_compression.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::collections::HashMap;
use std::sync::LazyLock;
use std::collections::{BTreeSet, HashMap};

use starknet_api::core::{ContractAddress, PatriciaKey};
use starknet_api::state::StorageKey;
use starknet_types_core::felt::Felt;

use super::cached_state::{CachedState, StorageEntry};
use super::cached_state::{CachedState, StateMaps, StorageEntry};
use super::state_api::{StateReader, StateResult};

#[cfg(test)]
Expand All @@ -16,26 +15,68 @@ type Alias = Felt;
type AliasKey = StorageKey;

// The initial alias available for allocation.
const INITIAL_AVAILABLE_ALIAS: Felt = Felt::from_hex_unchecked("0x80");
const INITIAL_AVAILABLE_ALIAS_HEX: &str = "0x80";
const INITIAL_AVAILABLE_ALIAS: Felt = Felt::from_hex_unchecked(INITIAL_AVAILABLE_ALIAS_HEX);

// The address of the alias contract.
const ALIAS_CONTRACT_ADDRESS: ContractAddress = ContractAddress(PatriciaKey::TWO);
// The storage key of the alias counter in the alias contract.
const ALIAS_COUNTER_STORAGE_KEY: StorageKey = StorageKey(PatriciaKey::ZERO);
// The maximal contract address for which aliases are not used and all keys are serialized as is,
// without compression.
const MAX_NON_COMPRESSED_CONTRACT_ADDRESS: ContractAddress =
ContractAddress(PatriciaKey::from_hex_unchecked("0xf"));
// The minimal value for a key to be allocated an alias. Smaller keys are serialized as is (their
// alias is identical to the key).
static MIN_VALUE_FOR_ALIAS_ALLOC: LazyLock<PatriciaKey> =
LazyLock::new(|| PatriciaKey::try_from(INITIAL_AVAILABLE_ALIAS).unwrap());
const MIN_VALUE_FOR_ALIAS_ALLOC: PatriciaKey =
PatriciaKey::from_hex_unchecked(INITIAL_AVAILABLE_ALIAS_HEX);

/// Allocates aliases for the new addresses and storage keys in the alias contract.
/// Iterates over the addresses in ascending order. For each address, sets an alias for the new
/// storage keys (in ascending order) and for the address itself.
pub fn state_diff_with_alias_allocation<S: StateReader>(
state: &mut CachedState<S>,
) -> StateResult<StateMaps> {
let mut state_diff = state.to_state_diff()?.state_maps;

// Collect the contract addresses and the storage keys that need aliases.
let contract_addresses: BTreeSet<ContractAddress> =
state_diff.get_contract_addresses().into_iter().collect();
let mut contract_address_to_sorted_storage_keys = HashMap::new();
for (contract_address, storage_key) in state_diff.storage.keys() {
if contract_address > &MAX_NON_COMPRESSED_CONTRACT_ADDRESS {
contract_address_to_sorted_storage_keys
.entry(contract_address)
.or_insert_with(BTreeSet::new)
.insert(storage_key);
}
}

// Iterate over the addresses and the storage keys and update the aliases.
let mut alias_updater = AliasUpdater::new(state)?;
for contract_address in contract_addresses {
if let Some(storage_keys) = contract_address_to_sorted_storage_keys.get(&contract_address) {
for key in storage_keys {
alias_updater.insert_alias(key)?;
}
}
alias_updater.insert_alias(&StorageKey(contract_address.0))?;
}

let alias_storage_updates = alias_updater.finalize_updates();
state_diff.storage.extend(alias_storage_updates);
Ok(state_diff)
}

/// Generate updates for the alias contract with the new keys.
struct AliasUpdater<'a, S: StateReader> {
state: &'a CachedState<S>,
state: &'a S,
new_aliases: HashMap<AliasKey, Alias>,
next_free_alias: Option<Alias>,
}

impl<'a, S: StateReader> AliasUpdater<'a, S> {
fn new(state: &'a CachedState<S>) -> StateResult<Self> {
fn new(state: &'a S) -> StateResult<Self> {
let stored_counter =
state.get_storage_at(ALIAS_CONTRACT_ADDRESS, ALIAS_COUNTER_STORAGE_KEY)?;
Ok(Self {
Expand All @@ -47,7 +88,7 @@ impl<'a, S: StateReader> AliasUpdater<'a, S> {

/// Inserts the alias key to the updates if it's not already aliased.
fn insert_alias(&mut self, alias_key: &AliasKey) -> StateResult<()> {
if alias_key.0 >= *MIN_VALUE_FOR_ALIAS_ALLOC
if alias_key.0 >= MIN_VALUE_FOR_ALIAS_ALLOC
&& self.state.get_storage_at(ALIAS_CONTRACT_ADDRESS, *alias_key)? == Felt::ZERO
&& !self.new_aliases.contains_key(alias_key)
{
Expand Down
83 changes: 83 additions & 0 deletions crates/blockifier/src/state/stateful_compression_test.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
use std::collections::HashMap;

use rstest::rstest;
use starknet_api::core::{ClassHash, ContractAddress};
use starknet_api::state::StorageKey;
use starknet_types_core::felt::Felt;

use super::{
state_diff_with_alias_allocation,
AliasUpdater,
ALIAS_CONTRACT_ADDRESS,
ALIAS_COUNTER_STORAGE_KEY,
INITIAL_AVAILABLE_ALIAS,
MAX_NON_COMPRESSED_CONTRACT_ADDRESS,
};
use crate::state::cached_state::{CachedState, StorageEntry};
use crate::state::state_api::{State, StateReader};
use crate::test_utils::dict_state_reader::DictStateReader;

fn insert_to_alias_contract(
Expand Down Expand Up @@ -96,3 +100,82 @@ fn test_alias_updater(

assert_eq!(storage_diff, expected_storage_diff);
}

#[test]
fn test_iterate_aliases() {
let mut state = initial_state(0);
state
.set_storage_at(ContractAddress::from(0x201_u16), StorageKey::from(0x307_u16), Felt::ONE)
.unwrap();
state
.set_storage_at(ContractAddress::from(0x201_u16), StorageKey::from(0x309_u16), Felt::TWO)
.unwrap();
state
.set_storage_at(ContractAddress::from(0x201_u16), StorageKey::from(0x304_u16), Felt::THREE)
.unwrap();
state
.set_storage_at(MAX_NON_COMPRESSED_CONTRACT_ADDRESS, StorageKey::from(0x301_u16), Felt::ONE)
.unwrap();
state.get_class_hash_at(ContractAddress::from(0x202_u16)).unwrap();
state.set_class_hash_at(ContractAddress::from(0x202_u16), ClassHash(Felt::ONE)).unwrap();
state.increment_nonce(ContractAddress::from(0x200_u16)).unwrap();

let storage_diff = state_diff_with_alias_allocation(&mut state).unwrap().storage;
assert_eq!(
storage_diff,
vec![
(
(ALIAS_CONTRACT_ADDRESS, ALIAS_COUNTER_STORAGE_KEY),
INITIAL_AVAILABLE_ALIAS + Felt::from(6_u8)
),
((ALIAS_CONTRACT_ADDRESS, StorageKey::from(0x200_u16)), INITIAL_AVAILABLE_ALIAS),
(
(ALIAS_CONTRACT_ADDRESS, StorageKey::from(0x304_u16)),
INITIAL_AVAILABLE_ALIAS + Felt::ONE
),
(
(ALIAS_CONTRACT_ADDRESS, StorageKey::from(0x307_u16)),
INITIAL_AVAILABLE_ALIAS + Felt::TWO
),
(
(ALIAS_CONTRACT_ADDRESS, StorageKey::from(0x309_u16)),
INITIAL_AVAILABLE_ALIAS + Felt::THREE
),
(
(ALIAS_CONTRACT_ADDRESS, StorageKey::from(0x201_u16)),
INITIAL_AVAILABLE_ALIAS + Felt::from(4_u8)
),
(
(ALIAS_CONTRACT_ADDRESS, StorageKey::from(0x202_u16)),
INITIAL_AVAILABLE_ALIAS + Felt::from(5_u8)
),
((ContractAddress::from(0x201_u16), StorageKey::from(0x304_u16)), Felt::THREE),
((ContractAddress::from(0x201_u16), StorageKey::from(0x307_u16)), Felt::ONE),
((ContractAddress::from(0x201_u16), StorageKey::from(0x309_u16)), Felt::TWO),
((MAX_NON_COMPRESSED_CONTRACT_ADDRESS, StorageKey::from(0x301_u16)), Felt::ONE),
]
.into_iter()
.collect()
);
}

#[rstest]
fn test_read_only_state(#[values(0, 2)] n_existing_aliases: u8) {
let mut state = initial_state(n_existing_aliases);
state
.set_storage_at(ContractAddress::from(0x200_u16), StorageKey::from(0x300_u16), Felt::ZERO)
.unwrap();
state.get_nonce_at(ContractAddress::from(0x201_u16)).unwrap();
state.get_class_hash_at(ContractAddress::from(0x202_u16)).unwrap();
let storage_diff = state_diff_with_alias_allocation(&mut state).unwrap().storage;

let expected_storage_diff = if n_existing_aliases == 0 {
HashMap::from([(
(ALIAS_CONTRACT_ADDRESS, ALIAS_COUNTER_STORAGE_KEY),
INITIAL_AVAILABLE_ALIAS,
)])
} else {
HashMap::new()
};
assert_eq!(storage_diff, expected_storage_diff);
}
Original file line number Diff line number Diff line change
Expand Up @@ -1437,7 +1437,7 @@ fn test_count_actual_storage_changes(
n_allocated_keys: 2,
};

assert_eq!(expected_modified_contracts, state_changes_1.state_maps.get_modified_contracts());
assert_eq!(expected_modified_contracts, state_changes_1.state_maps.get_contract_addresses());
assert_eq!(expected_storage_updates_1, state_changes_1.state_maps.storage);
assert_eq!(state_changes_count_1, expected_state_changes_count_1);

Expand Down Expand Up @@ -1477,7 +1477,7 @@ fn test_count_actual_storage_changes(
n_allocated_keys: 0,
};

assert_eq!(expected_modified_contracts_2, state_changes_2.state_maps.get_modified_contracts());
assert_eq!(expected_modified_contracts_2, state_changes_2.state_maps.get_contract_addresses());
assert_eq!(expected_storage_updates_2, state_changes_2.state_maps.storage);
assert_eq!(state_changes_count_2, expected_state_changes_count_2);

Expand Down Expand Up @@ -1529,7 +1529,7 @@ fn test_count_actual_storage_changes(

assert_eq!(
expected_modified_contracts_transfer,
state_changes_transfer.state_maps.get_modified_contracts()
state_changes_transfer.state_maps.get_contract_addresses()
);
assert_eq!(expected_storage_update_transfer, state_changes_transfer.state_maps.storage);
assert_eq!(state_changes_count_3, expected_state_changes_count_3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl From<SerializableOfflineReexecutionData> for OfflineReexecutionData {
}
}

#[derive(Default)]
#[derive(Clone, Default)]
pub struct OfflineStateReader {
pub state_maps: StateMaps,
pub contract_class_mapping: StarknetContractClassMapping,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ pub struct GetTransactionByHashParams {
pub transaction_hash: String,
}

#[derive(Clone)]
pub struct RetryConfig {
pub(crate) n_retries: usize,
pub(crate) retry_interval_milliseconds: u64,
Expand All @@ -86,6 +87,7 @@ impl Default for RetryConfig {
}
}

#[derive(Clone)]
pub struct TestStateReader {
pub(crate) rpc_state_reader: RpcStateReader,
pub(crate) retry_config: RetryConfig,
Expand Down
8 changes: 6 additions & 2 deletions crates/blockifier_reexecution/src/state_reader/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ impl From<CommitmentStateDiff> for ComparableStateDiff {
}

pub fn reexecute_and_verify_correctness<
S: StateReader + Send + Sync,
S: StateReader + Send + Sync + Clone,
T: ConsecutiveReexecutionStateReaders<S>,
>(
consecutive_state_readers: T,
Expand All @@ -232,13 +232,17 @@ pub fn reexecute_and_verify_correctness<
assert_matches!(res, Ok(_));
}

// TODO(Yoav): Return the block state after the modifications in finalize().
// Note that after finalizing, the block state is None.
let block_state = transaction_executor.block_state.clone();

// Finalize block and read actual statediff.
let (actual_state_diff, _, _) =
transaction_executor.finalize().expect("Couldn't finalize block");

assert_eq_state_diff!(expected_state_diff, actual_state_diff);

transaction_executor.block_state
block_state
}

pub fn reexecute_block_for_testing(block_number: u64) {
Expand Down
4 changes: 4 additions & 0 deletions crates/starknet_api/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,10 @@ impl PatriciaKey {
pub fn key(&self) -> &StarkHash {
&self.0
}

pub const fn from_hex_unchecked(val: &str) -> Self {
Self(StarkHash::from_hex_unchecked(val))
}
}

impl From<u128> for PatriciaKey {
Expand Down
1 change: 1 addition & 0 deletions crates/starknet_gateway/src/rpc_state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use crate::rpc_objects::{
};
use crate::state_reader::{MempoolStateReader, StateReaderFactory};

#[derive(Clone)]
pub struct RpcStateReader {
pub config: RpcStateReaderConfig,
pub block_id: BlockId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ async fn end_to_end_flow(mut tx_generator: MultiAccountTransactionGenerator) {
let heights_to_build = next_height.iter_up_to(LAST_HEIGHT.unchecked_next());
let expected_content_ids = [
Felt::from_hex_unchecked(
"0x457e9172b9c70fb4363bb3ff31bf778d8f83828184a9a3f9badadc497f2b954",
"0x58ad05a6987a675eda038663d8e7dcc8e1d91c9057dd57f16d9b3b9602fc840",
),
Felt::from_hex_unchecked(
"0x572373fe992ac8c2413d5e727036316023ed6a2e8a2256b4952e223969e0221",
"0x79b59c5036c9427b5194796ede67bdfffed1f311a77382d715174fcfcc33003",
),
];

Expand Down

0 comments on commit 28e7cb8

Please sign in to comment.