From 2beb6fd281ecb56ac6ce79111ebf20b347abc2f4 Mon Sep 17 00:00:00 2001 From: avi-starkware Date: Wed, 20 Nov 2024 14:54:48 +0200 Subject: [PATCH] feat(cairo_native): add a struct containing all global contract caches (#2186) --- crates/blockifier/src/state/global_cache.rs | 77 ++++++++++++++++--- .../src/py_block_executor.rs | 3 +- .../papyrus_state_reader/src/papyrus_state.rs | 4 +- crates/starknet_batcher/src/block_builder.rs | 3 +- 4 files changed, 74 insertions(+), 13 deletions(-) diff --git a/crates/blockifier/src/state/global_cache.rs b/crates/blockifier/src/state/global_cache.rs index 670045fe5f..0a21c3759d 100644 --- a/crates/blockifier/src/state/global_cache.rs +++ b/crates/blockifier/src/state/global_cache.rs @@ -1,33 +1,44 @@ use std::sync::{Arc, Mutex, MutexGuard}; use cached::{Cached, SizedCache}; +#[cfg(feature = "cairo_native")] +use cairo_native::executor::AotContractExecutor; use starknet_api::core::ClassHash; +#[cfg(feature = "cairo_native")] +use starknet_api::state::ContractClass as SierraContractClass; +#[cfg(feature = "cairo_native")] use crate::execution::contract_class::RunnableContractClass; -// Note: `ContractClassLRUCache` key-value types must align with `ContractClassMapping`. -type ContractClassLRUCache = SizedCache; -pub type LockedContractClassCache<'a> = MutexGuard<'a, ContractClassLRUCache>; +type ContractClassLRUCache = SizedCache; +pub type LockedContractClassCache<'a, T> = MutexGuard<'a, ContractClassLRUCache>; #[derive(Debug, Clone)] // Thread-safe LRU cache for contract classes, optimized for inter-language sharing when // `blockifier` compiles as a shared library. // TODO(Yoni, 1/1/2025): consider defining CachedStateReader. -pub struct GlobalContractCache(pub Arc>); +pub struct GlobalContractCache(pub Arc>>); + +#[cfg(feature = "cairo_native")] +#[derive(Debug, Clone)] +pub enum CachedCairoNative { + Compiled(AotContractExecutor), + CompilationFailed, +} pub const GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST: usize = 100; -impl GlobalContractCache { +impl GlobalContractCache { /// Locks the cache for atomic access. Although conceptually shared, writing to this cache is /// only possible for one writer at a time. - pub fn lock(&self) -> LockedContractClassCache<'_> { + pub fn lock(&self) -> LockedContractClassCache<'_, T> { self.0.lock().expect("Global contract cache is poisoned.") } - pub fn get(&self, class_hash: &ClassHash) -> Option { + pub fn get(&self, class_hash: &ClassHash) -> Option { self.lock().cache_get(class_hash).cloned() } - pub fn set(&self, class_hash: ClassHash, contract_class: RunnableContractClass) { + pub fn set(&self, class_hash: ClassHash, contract_class: T) { self.lock().cache_set(class_hash, contract_class); } @@ -36,6 +47,54 @@ impl GlobalContractCache { } pub fn new(cache_size: usize) -> Self { - Self(Arc::new(Mutex::new(ContractClassLRUCache::with_size(cache_size)))) + Self(Arc::new(Mutex::new(ContractClassLRUCache::::with_size(cache_size)))) + } +} + +#[cfg(feature = "cairo_native")] +pub struct GlobalContractCacheManager { + pub casm_cache: GlobalContractCache, + pub native_cache: GlobalContractCache, + pub sierra_cache: GlobalContractCache>, +} + +#[cfg(feature = "cairo_native")] +impl GlobalContractCacheManager { + pub fn get_casm(&self, class_hash: &ClassHash) -> Option { + self.casm_cache.get(class_hash) + } + + pub fn set_casm(&self, class_hash: ClassHash, contract_class: RunnableContractClass) { + self.casm_cache.set(class_hash, contract_class); + } + + pub fn get_native(&self, class_hash: &ClassHash) -> Option { + self.native_cache.get(class_hash) + } + + pub fn set_native(&self, class_hash: ClassHash, contract_executor: CachedCairoNative) { + self.native_cache.set(class_hash, contract_executor); + } + + pub fn get_sierra(&self, class_hash: &ClassHash) -> Option> { + self.sierra_cache.get(class_hash) + } + + pub fn set_sierra(&self, class_hash: ClassHash, contract_class: Arc) { + self.sierra_cache.set(class_hash, contract_class); + } + + pub fn new(cache_size: usize) -> Self { + Self { + casm_cache: GlobalContractCache::new(cache_size), + native_cache: GlobalContractCache::new(cache_size), + sierra_cache: GlobalContractCache::new(cache_size), + } + } + + pub fn clear(&mut self) { + self.casm_cache.clear(); + self.native_cache.clear(); + self.sierra_cache.clear(); } } diff --git a/crates/native_blockifier/src/py_block_executor.rs b/crates/native_blockifier/src/py_block_executor.rs index 23cf2c828b..17f8604bb1 100644 --- a/crates/native_blockifier/src/py_block_executor.rs +++ b/crates/native_blockifier/src/py_block_executor.rs @@ -6,6 +6,7 @@ use blockifier::blockifier::transaction_executor::{TransactionExecutor, Transact use blockifier::bouncer::BouncerConfig; use blockifier::context::{BlockContext, ChainInfo, FeeTokenAddresses}; use blockifier::execution::call_info::CallInfo; +use blockifier::execution::contract_class::RunnableContractClass; use blockifier::fee::receipt::TransactionReceipt; use blockifier::state::global_cache::GlobalContractCache; use blockifier::transaction::objects::{ExecutionResourcesTraits, TransactionExecutionInfo}; @@ -129,7 +130,7 @@ pub struct PyBlockExecutor { pub tx_executor: Option>, /// `Send` trait is required for `pyclass` compatibility as Python objects must be threadsafe. pub storage: Box, - pub global_contract_cache: GlobalContractCache, + pub global_contract_cache: GlobalContractCache, } #[pymethods] diff --git a/crates/papyrus_state_reader/src/papyrus_state.rs b/crates/papyrus_state_reader/src/papyrus_state.rs index b957a170b0..3ff2cdc46e 100644 --- a/crates/papyrus_state_reader/src/papyrus_state.rs +++ b/crates/papyrus_state_reader/src/papyrus_state.rs @@ -23,14 +23,14 @@ type RawPapyrusReader<'env> = papyrus_storage::StorageTxn<'env, RO>; pub struct PapyrusReader { storage_reader: StorageReader, latest_block: BlockNumber, - global_class_hash_to_class: GlobalContractCache, + global_class_hash_to_class: GlobalContractCache, } impl PapyrusReader { pub fn new( storage_reader: StorageReader, latest_block: BlockNumber, - global_class_hash_to_class: GlobalContractCache, + global_class_hash_to_class: GlobalContractCache, ) -> Self { Self { storage_reader, latest_block, global_class_hash_to_class } } diff --git a/crates/starknet_batcher/src/block_builder.rs b/crates/starknet_batcher/src/block_builder.rs index 2c616ef528..2c862bb4ba 100644 --- a/crates/starknet_batcher/src/block_builder.rs +++ b/crates/starknet_batcher/src/block_builder.rs @@ -11,6 +11,7 @@ use blockifier::blockifier::transaction_executor::{ }; use blockifier::bouncer::{BouncerConfig, BouncerWeights}; use blockifier::context::{BlockContext, ChainInfo}; +use blockifier::execution::contract_class::RunnableContractClass; use blockifier::state::cached_state::CommitmentStateDiff; use blockifier::state::errors::StateError; use blockifier::state::global_cache::GlobalContractCache; @@ -295,7 +296,7 @@ impl SerializeConfig for BlockBuilderConfig { pub struct BlockBuilderFactory { pub block_builder_config: BlockBuilderConfig, pub storage_reader: StorageReader, - pub global_class_hash_to_class: GlobalContractCache, + pub global_class_hash_to_class: GlobalContractCache, } impl BlockBuilderFactory {