diff --git a/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm b/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm index eda8c4fdd..0c5e5bcde 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm @@ -45,6 +45,11 @@ global init_access_lists: // Store the segment scaled length %increment %mstore_global_metadata(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN) + + // Reset the access lists pointers in the `GenerationState` + PROVER_INPUT(access_lists::reset) + POP // reset pushed a 0 + JUMP %macro init_access_lists diff --git a/evm_arithmetization/src/cpu/kernel/interpreter.rs b/evm_arithmetization/src/cpu/kernel/interpreter.rs index b831a921a..78e11206e 100644 --- a/evm_arithmetization/src/cpu/kernel/interpreter.rs +++ b/evm_arithmetization/src/cpu/kernel/interpreter.rs @@ -5,7 +5,7 @@ //! the future execution and generate nondeterministically the corresponding //! jumpdest table, before the actual CPU carries on with contract execution. -use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::collections::{BTreeSet, HashMap}; use anyhow::anyhow; use ethereum_types::{BigEndianHash, U256}; @@ -19,6 +19,7 @@ use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::generation::debug_inputs; +use crate::generation::linked_list::LinkedListsPtrs; use crate::generation::mpt::{load_linked_lists_and_txn_and_receipt_mpts, TrieRootPtrs}; use crate::generation::rlp::all_rlp_prover_inputs_reversed; use crate::generation::state::{ @@ -115,8 +116,8 @@ pub(crate) struct ExtraSegmentData { pub(crate) ger_prover_inputs: Vec, pub(crate) trie_root_ptrs: TrieRootPtrs, pub(crate) jumpdest_table: Option>>, - pub(crate) accounts: BTreeMap, - pub(crate) storage: BTreeMap<(U256, U256), usize>, + pub(crate) access_lists_ptrs: LinkedListsPtrs, + pub(crate) state_ptrs: LinkedListsPtrs, pub(crate) next_txn_index: usize, } @@ -235,8 +236,8 @@ impl Interpreter { // Initialize the MPT's pointers. let (trie_root_ptrs, state_leaves, storage_leaves, trie_data) = load_linked_lists_and_txn_and_receipt_mpts( - &mut self.generation_state.accounts_pointers, - &mut self.generation_state.storage_pointers, + &mut self.generation_state.state_ptrs.accounts, + &mut self.generation_state.state_ptrs.storage, &inputs.tries, ) .expect("Invalid MPT data for preinitialization"); diff --git a/evm_arithmetization/src/cpu/kernel/tests/account_code.rs b/evm_arithmetization/src/cpu/kernel/tests/account_code.rs index 9f404afd4..8af4ed4a8 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/account_code.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/account_code.rs @@ -32,8 +32,8 @@ pub(crate) fn initialize_mpts( // Load all MPTs. let (mut trie_root_ptrs, state_leaves, storage_leaves, trie_data) = load_linked_lists_and_txn_and_receipt_mpts( - &mut interpreter.generation_state.accounts_pointers, - &mut interpreter.generation_state.storage_pointers, + &mut interpreter.generation_state.state_ptrs.accounts, + &mut interpreter.generation_state.state_ptrs.storage, trie_inputs, ) .expect("Invalid MPT data for preinitialization"); diff --git a/evm_arithmetization/src/cpu/kernel/tests/mpt/linked_list.rs b/evm_arithmetization/src/cpu/kernel/tests/mpt/linked_list.rs index d80baae61..b31c05233 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/mpt/linked_list.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/mpt/linked_list.rs @@ -12,7 +12,7 @@ use rand::{thread_rng, Rng}; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::Interpreter; -use crate::generation::linked_list::LinkedList; +use crate::generation::linked_list::testing::LinkedList; use crate::generation::linked_list::ACCOUNTS_LINKED_LIST_NODE_SIZE; use crate::generation::linked_list::STORAGE_LINKED_LIST_NODE_SIZE; use crate::memory::segments::Segment; diff --git a/evm_arithmetization/src/generation/linked_list.rs b/evm_arithmetization/src/generation/linked_list.rs index 25b735bfa..b2465bb88 100644 --- a/evm_arithmetization/src/generation/linked_list.rs +++ b/evm_arithmetization/src/generation/linked_list.rs @@ -1,42 +1,27 @@ -use std::fmt; -use std::marker::PhantomData; +use std::collections::BTreeMap; -use anyhow::Result; use ethereum_types::U256; +use serde::{Deserialize, Serialize}; use crate::memory::segments::Segment; -use crate::util::u256_to_usize; -use crate::witness::errors::ProgramError; -use crate::witness::errors::ProverInputError::InvalidInput; pub const ACCOUNTS_LINKED_LIST_NODE_SIZE: usize = 4; pub const STORAGE_LINKED_LIST_NODE_SIZE: usize = 5; -pub(crate) trait LinkedListType {} -#[derive(Clone)] -/// A linked list that starts from the first node after the special node and -/// iterates forever. -pub(crate) struct Cyclic; -#[derive(Clone)] -/// A linked list that starts from the special node and iterates until the last -/// node. -pub(crate) struct Bounded; -impl LinkedListType for Cyclic {} -impl LinkedListType for Bounded {} - -// A linked list implemented using a vector `access_list_mem`. -// In this representation, the values of nodes are stored in the range -// `access_list_mem[i..i + node_size - 1]`, and `access_list_mem[i + node_size - -// 1]` holds the address of the next node, where i = node_size * j. -#[derive(Clone)] -pub(crate) struct LinkedList<'a, const N: usize, T = Cyclic> -where - T: LinkedListType, -{ - mem: &'a [Option], - offset: usize, - pos: usize, - _marker: PhantomData, +pub const DUMMYHEAD: (U256, U256) = (U256::MAX, U256::zero()); + +// Provides quick access to pointers that reference the memory location +// of a storage or accounts linked list node, containing a specific key. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub(crate) struct LinkedListsPtrs { + /// Each entry contains the pair (key, ptr) where key is the (hashed) key + /// of an account in the accounts linked list, and ptr is the respective + /// node address in memory. + pub(crate) accounts: BTreeMap, + /// Each entry contains the pair ((account_key, slot_key), ptr) where + /// account_key is the (hashed) key of an account, slot_key is the slot + /// key, and ptr is the respective node address in memory. + pub(crate) storage: BTreeMap<(U256, U256), usize>, } pub(crate) fn empty_list_mem(segment: Segment) -> [Option; N] { @@ -51,76 +36,99 @@ pub(crate) fn empty_list_mem(segment: Segment) -> [Option; }) } -impl<'a, const N: usize, T: LinkedListType> LinkedList<'a, N, T> { - pub fn from_mem_and_segment( +#[cfg(test)] +pub(crate) mod testing { + use std::fmt; + use std::marker::PhantomData; + + use anyhow::Result; + + use super::*; + use crate::util::u256_to_usize; + use crate::witness::errors::ProgramError; + use crate::witness::errors::ProverInputError::InvalidInput; + + pub const ADDRESSES_ACCESS_LIST_LEN: usize = 2; + pub(crate) trait LinkedListType {} + #[derive(Clone)] + /// A linked list that starts from the first node after the special node and + /// iterates forever. + pub(crate) struct Cyclic; + #[derive(Clone)] + /// A linked list that starts from the special node and iterates until the + /// last node. + pub(crate) struct Bounded; + impl LinkedListType for Cyclic {} + impl LinkedListType for Bounded {} + + // A linked list implemented using a vector `access_list_mem`. + // In this representation, the values of nodes are stored in the range + // `access_list_mem[i..i + node_size - 1]`, and `access_list_mem[i + node_size - + // 1]` holds the address of the next node, where i = node_size * j. + #[derive(Clone)] + pub(crate) struct LinkedList<'a, const N: usize, T = Cyclic> + where + T: LinkedListType, + { mem: &'a [Option], - segment: Segment, - ) -> Result { - Self::from_mem_len_and_segment(mem, segment) + offset: usize, + pos: usize, + _marker: PhantomData, } - pub fn from_mem_len_and_segment( - mem: &'a [Option], - segment: Segment, - ) -> Result { - if mem.is_empty() { - return Err(ProgramError::ProverInputError(InvalidInput)); + impl<'a, const N: usize, T: LinkedListType> LinkedList<'a, N, T> { + pub fn from_mem_and_segment( + mem: &'a [Option], + segment: Segment, + ) -> Result { + Self::from_mem_len_and_segment(mem, segment) } - Ok(Self { - mem, - offset: segment as usize, - pos: 0, - _marker: PhantomData, - }) - } -} -impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "Linked List {{")?; - let cloned_list = self.clone(); - for (i, node) in cloned_list.enumerate() { - if i > 0 && node[0] == U256::MAX { - break; + pub fn from_mem_len_and_segment( + mem: &'a [Option], + segment: Segment, + ) -> Result { + if mem.len() % N != 0 { + return Err(ProgramError::ProverInputError(InvalidInput)); } - writeln!(f, "{:?} ->", node)?; + Ok(Self { + mem, + offset: segment as usize, + pos: 0, + _marker: PhantomData, + }) } - write!(f, "}}") } -} -impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N, Bounded> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "Linked List {{")?; - let cloned_list = self.clone(); - for node in cloned_list { - writeln!(f, "{:?} ->", node)?; + impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "Linked List {{")?; + let cloned_list = self.clone(); + for (i, node) in cloned_list.enumerate() { + if i > 0 && node[0] == U256::MAX { + break; + } + writeln!(f, "{:?} ->", node)?; + } + write!(f, "}}") } - write!(f, "}}") } -} - -impl<'a, const N: usize> Iterator for LinkedList<'a, N> { - type Item = [U256; N]; - fn next(&mut self) -> Option { - let node = Some(std::array::from_fn(|i| { - self.mem[self.pos + i].unwrap_or_default() - })); - if let Ok(new_pos) = u256_to_usize(self.mem[self.pos + N - 1].unwrap_or_default()) { - self.pos = new_pos - self.offset; - node - } else { - None + impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N, Bounded> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "Linked List {{")?; + let cloned_list = self.clone(); + for node in cloned_list { + writeln!(f, "{:?} ->", node)?; + } + write!(f, "}}") } } -} -impl<'a, const N: usize> Iterator for LinkedList<'a, N, Bounded> { - type Item = [U256; N]; + impl<'a, const N: usize> Iterator for LinkedList<'a, N> { + type Item = [U256; N]; - fn next(&mut self) -> Option { - if self.mem[self.pos] != Some(U256::MAX) { + fn next(&mut self) -> Option { let node = Some(std::array::from_fn(|i| { self.mem[self.pos + i].unwrap_or_default() })); @@ -130,8 +138,26 @@ impl<'a, const N: usize> Iterator for LinkedList<'a, N, Bounded> { } else { None } - } else { - None + } + } + + impl<'a, const N: usize> Iterator for LinkedList<'a, N, Bounded> { + type Item = [U256; N]; + + fn next(&mut self) -> Option { + if self.mem[self.pos] != Some(U256::MAX) { + let node = Some(std::array::from_fn(|i| { + self.mem[self.pos + i].unwrap_or_default() + })); + if let Ok(new_pos) = u256_to_usize(self.mem[self.pos + N - 1].unwrap_or_default()) { + self.pos = new_pos - self.offset; + node + } else { + None + } + } else { + None + } } } } diff --git a/evm_arithmetization/src/generation/prover_input.rs b/evm_arithmetization/src/generation/prover_input.rs index 366b75a35..704e2f4c6 100644 --- a/evm_arithmetization/src/generation/prover_input.rs +++ b/evm_arithmetization/src/generation/prover_input.rs @@ -10,8 +10,10 @@ use num_bigint::BigUint; use plonky2::hash::hash_types::RichField; use serde::{Deserialize, Serialize}; +#[cfg(test)] +use super::linked_list::testing::{LinkedList, ADDRESSES_ACCESS_LIST_LEN}; use super::linked_list::{ - LinkedList, ACCOUNTS_LINKED_LIST_NODE_SIZE, STORAGE_LINKED_LIST_NODE_SIZE, + LinkedListsPtrs, ACCOUNTS_LINKED_LIST_NODE_SIZE, DUMMYHEAD, STORAGE_LINKED_LIST_NODE_SIZE, }; use super::mpt::load_state_mpt; use crate::cpu::kernel::cancun_constants::KZG_VERSIONED_HASH; @@ -44,9 +46,6 @@ use crate::witness::util::{current_context_peek, stack_peek}; #[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)] pub struct ProverInputFn(Vec); -pub const ADDRESSES_ACCESS_LIST_LEN: usize = 2; -pub const STORAGE_KEYS_ACCESS_LIST_LEN: usize = 4; - impl From> for ProverInputFn { fn from(v: Vec) -> Self { Self(v) @@ -327,6 +326,7 @@ impl GenerationState { "storage_insert" => self.run_next_storage_insert(), "address_remove" => self.run_next_addresses_remove(), "storage_remove" => self.run_next_storage_remove(), + "reset" => self.run_reset(), _ => Err(ProgramError::ProverInputError(InvalidInput)), } } @@ -410,77 +410,95 @@ impl GenerationState { /// Returns a pointer to an element in the list whose value is such that /// `value <= addr < next_value` and `addr` is the top of the stack. - fn run_next_addresses_insert(&self) -> Result { + fn run_next_addresses_insert(&mut self) -> Result { let addr = stack_peek(self, 0)?; - if let Some((([_, ptr], _), _)) = self - .get_addresses_access_list()? - .zip(self.get_addresses_access_list()?.skip(1)) - .zip(self.get_addresses_access_list()?.skip(2)) - .find(|&((_, [prev_addr, _]), [next_addr, _])| { - (prev_addr <= addr || prev_addr == U256::MAX) && addr < next_addr - }) - { - Ok(ptr / U256::from(2)) - } else { - Ok((Segment::AccessedAddresses as usize).into()) + + let (&pred_addr, &ptr) = self + .access_lists_ptrs + .accounts + .range(..=addr) + .next_back() + .unwrap_or((&U256::MAX, &(Segment::AccessedAddresses as usize))); + + if pred_addr != addr { + self.access_lists_ptrs.accounts.insert( + addr, + u256_to_usize( + self.memory + .read_global_metadata(GlobalMetadata::AccessedAddressesLen), + )?, + ); } + Ok(U256::from(ptr / 2)) } /// Returns a pointer to an element in the list whose value is such that /// `value < addr == next_value` and addr is the top of the stack. /// If the element is not in the list, it loops forever - fn run_next_addresses_remove(&self) -> Result { + fn run_next_addresses_remove(&mut self) -> Result { let addr = stack_peek(self, 0)?; - if let Some(([_, ptr], _)) = self - .get_addresses_access_list()? - .zip(self.get_addresses_access_list()?.skip(2)) - .find(|&(_, [next_addr, _])| next_addr == addr) - { - Ok(ptr / U256::from(2)) - } else { - Ok((Segment::AccessedAddresses as usize).into()) - } + + let (_, &ptr) = self + .access_lists_ptrs + .accounts + .range(..addr) + .next_back() + .unwrap_or((&U256::MAX, &(Segment::AccessedAddresses as usize))); + self.access_lists_ptrs + .accounts + .remove(&addr) + .ok_or(ProgramError::ProverInputError(InvalidInput))?; + + Ok(U256::from(ptr / 2)) } /// Returns a pointer to the predecessor of the top of the stack in the /// accessed storage keys list. - fn run_next_storage_insert(&self) -> Result { + fn run_next_storage_insert(&mut self) -> Result { let addr = stack_peek(self, 0)?; let key = stack_peek(self, 1)?; - if let Some((([.., ptr], _), _)) = self - .get_storage_keys_access_list()? - .zip(self.get_storage_keys_access_list()?.skip(1)) - .zip(self.get_storage_keys_access_list()?.skip(2)) - .find( - |&((_, [prev_addr, prev_key, ..]), [next_addr, next_key, ..])| { - let prev_is_less_or_equal = (prev_addr < addr || prev_addr == U256::MAX) - || (prev_addr == addr && prev_key <= key); - let next_is_strictly_larger = - next_addr > addr || (next_addr == addr && next_key > key); - prev_is_less_or_equal && next_is_strictly_larger - }, - ) - { - Ok(ptr / U256::from(4)) - } else { - Ok((Segment::AccessedStorageKeys as usize).into()) + + let (&(pred_addr, pred_slot_key), &ptr) = self + .access_lists_ptrs + .storage + .range(..=(addr, key)) + .next_back() + .unwrap_or((&DUMMYHEAD, &(Segment::AccessedStorageKeys as usize))); + if pred_addr != addr || pred_slot_key != key { + self.access_lists_ptrs.storage.insert( + (addr, key), + u256_to_usize( + self.memory + .read_global_metadata(GlobalMetadata::AccessedStorageKeysLen), + )?, + ); } + Ok(U256::from(ptr / 4)) } /// Returns a pointer to the predecessor of the top of the stack in the /// accessed storage keys list. - fn run_next_storage_remove(&self) -> Result { + fn run_next_storage_remove(&mut self) -> Result { let addr = stack_peek(self, 0)?; let key = stack_peek(self, 1)?; - if let Some(([.., ptr], _)) = self - .get_storage_keys_access_list()? - .zip(self.get_storage_keys_access_list()?.skip(2)) - .find(|&(_, [next_addr, next_key, ..])| (next_addr == addr && next_key == key)) - { - Ok(ptr / U256::from(4)) - } else { - Ok((Segment::AccessedStorageKeys as usize).into()) - } + + let (_, &ptr) = self + .access_lists_ptrs + .storage + .range(..(addr, key)) + .next_back() + .unwrap_or((&DUMMYHEAD, &(Segment::AccessedStorageKeys as usize))); + self.access_lists_ptrs + .storage + .remove(&(addr, key)) + .ok_or(ProgramError::ProverInputError(InvalidInput))?; + + Ok(U256::from(ptr / 4)) + } + + fn run_reset(&mut self) -> Result { + self.access_lists_ptrs = LinkedListsPtrs::default(); + Ok(U256::zero()) } /// Returns a pointer to a node in the list such that @@ -489,13 +507,14 @@ impl GenerationState { let addr = stack_peek(self, 0)?; let (&pred_addr, &pred_ptr) = self - .accounts_pointers + .state_ptrs + .accounts .range(..=addr) .next_back() .unwrap_or((&U256::MAX, &(Segment::AccountsLinkedList as usize))); if pred_addr != addr && input_fn.0[1].as_str() == "insert_account" { - self.accounts_pointers.insert( + self.state_ptrs.accounts.insert( addr, u256_to_usize( self.memory @@ -516,15 +535,13 @@ impl GenerationState { let key = stack_peek(self, 1)?; let (&(pred_addr, pred_slot_key), &pred_ptr) = self - .storage_pointers + .state_ptrs + .storage .range(..=(addr, key)) .next_back() - .unwrap_or(( - &(U256::MAX, U256::zero()), - &(Segment::StorageLinkedList as usize), - )); + .unwrap_or((&DUMMYHEAD, &(Segment::StorageLinkedList as usize))); if (pred_addr != addr || pred_slot_key != key) && input_fn.0[1] == "insert_slot" { - self.storage_pointers.insert( + self.state_ptrs.storage.insert( (addr, key), u256_to_usize( self.memory @@ -544,11 +561,13 @@ impl GenerationState { let addr = stack_peek(self, 0)?; let (_, &ptr) = self - .accounts_pointers + .state_ptrs + .accounts .range(..addr) .next_back() .unwrap_or((&U256::MAX, &(Segment::AccountsLinkedList as usize))); - self.accounts_pointers + self.state_ptrs + .accounts .remove(&addr) .ok_or(ProgramError::ProverInputError(InvalidInput))?; @@ -564,14 +583,13 @@ impl GenerationState { let key = stack_peek(self, 1)?; let (_, &ptr) = self - .storage_pointers + .state_ptrs + .storage .range(..(addr, key)) .next_back() - .unwrap_or(( - &(U256::MAX, U256::zero()), - &(Segment::StorageLinkedList as usize), - )); - self.storage_pointers + .unwrap_or((&DUMMYHEAD, &(Segment::StorageLinkedList as usize))); + self.state_ptrs + .storage .remove(&(addr, key)) .ok_or(ProgramError::ProverInputError(InvalidInput))?; @@ -588,19 +606,30 @@ impl GenerationState { let addr = stack_peek(self, 0)?; let (_, &pred_ptr) = self - .storage_pointers + .state_ptrs + .storage .range(..(addr, U256::zero())) .next_back() - .unwrap_or(( - &(U256::MAX, U256::zero()), - &(Segment::StorageLinkedList as usize), - )); + .unwrap_or((&DUMMYHEAD, &(Segment::StorageLinkedList as usize))); Ok(U256::from( (pred_ptr - Segment::StorageLinkedList as usize) / STORAGE_LINKED_LIST_NODE_SIZE, )) } + #[cfg(test)] + pub(crate) fn get_addresses_access_list( + &self, + ) -> Result, ProgramError> { + // `GlobalMetadata::AccessedAddressesLen` stores the value of the next available + // virtual address in the segment. In order to get the length we need + // to substract `Segment::AccessedAddresses` as usize. + LinkedList::from_mem_and_segment( + &self.memory.contexts[0].segments[Segment::AccessedAddresses.unscale()].content, + Segment::AccessedAddresses, + ) + } + /// Returns the first part of the KZG precompile output. fn run_kzg_point_eval(&mut self) -> Result { let versioned_hash = stack_peek(self, 0)?; @@ -819,30 +848,6 @@ impl GenerationState { } } } - - pub(crate) fn get_addresses_access_list( - &self, - ) -> Result, ProgramError> { - // `GlobalMetadata::AccessedAddressesLen` stores the value of the next available - // virtual address in the segment. In order to get the length we need - // to substract `Segment::AccessedAddresses` as usize. - LinkedList::from_mem_and_segment( - &self.memory.contexts[0].segments[Segment::AccessedAddresses.unscale()].content, - Segment::AccessedAddresses, - ) - } - - pub(crate) fn get_storage_keys_access_list( - &self, - ) -> Result, ProgramError> { - // GlobalMetadata::AccessedStorageKeysLen stores the value of the next available - // virtual address in the segment. In order to get the length we need - // to substract `Segment::AccessedStorageKeys` as usize. - LinkedList::from_mem_and_segment( - &self.memory.contexts[0].segments[Segment::AccessedStorageKeys.unscale()].content, - Segment::AccessedStorageKeys, - ) - } } /// For all address in `jumpdest_table` smaller than `largest_address`, diff --git a/evm_arithmetization/src/generation/segments.rs b/evm_arithmetization/src/generation/segments.rs index 7f123aa10..fc7718614 100644 --- a/evm_arithmetization/src/generation/segments.rs +++ b/evm_arithmetization/src/generation/segments.rs @@ -74,8 +74,8 @@ fn build_segment_data( trie_root_ptrs: interpreter.generation_state.trie_root_ptrs.clone(), jumpdest_table: interpreter.generation_state.jumpdest_table.clone(), next_txn_index: interpreter.generation_state.next_txn_index, - accounts: interpreter.generation_state.accounts_pointers.clone(), - storage: interpreter.generation_state.storage_pointers.clone(), + access_lists_ptrs: interpreter.generation_state.access_lists_ptrs.clone(), + state_ptrs: interpreter.generation_state.state_ptrs.clone(), }, } } diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index 1659055f7..1ea87bd0c 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeMap, HashMap}; +use std::collections::HashMap; use std::mem::size_of; use anyhow::{anyhow, bail}; @@ -8,6 +8,7 @@ use keccak_hash::keccak; use log::Level; use plonky2::hash::hash_types::RichField; +use super::linked_list::LinkedListsPtrs; use super::mpt::TrieRootPtrs; use super::segments::GenerationSegmentData; use super::{TrieInputs, TrimmedGenerationInputs, NUM_EXTRA_CYCLES_AFTER}; @@ -375,15 +376,13 @@ pub struct GenerationState { /// j in [i, i+32] it holds that code[j] < 0x7f - j + i. pub(crate) jumpdest_table: Option>>, - /// Each entry contains the pair (key, ptr) where key is the (hashed) key - /// of an account in the accounts linked list, and ptr is the respective - /// node address in memory. - pub(crate) accounts_pointers: BTreeMap, + /// Provides quick access to pointers that reference the location + /// of either and account or a slot in the respective access list. + pub(crate) access_lists_ptrs: LinkedListsPtrs, - /// Each entry contains the pair ((account_key, slot_key), ptr) where - /// account_key is the (hashed) key of an account, slot_key is the slot - /// key, and ptr is the respective node address in memory. - pub(crate) storage_pointers: BTreeMap<(U256, U256), usize>, + /// Provides quick access to pointers that reference the memory location of + /// either and account or a slot in the respective access list. + pub(crate) state_ptrs: LinkedListsPtrs, } impl GenerationState { @@ -394,8 +393,8 @@ impl GenerationState { let generation_state = self.get_mut_generation_state(); let (trie_roots_ptrs, state_leaves, storage_leaves, trie_data) = load_linked_lists_and_txn_and_receipt_mpts( - &mut generation_state.accounts_pointers, - &mut generation_state.storage_pointers, + &mut generation_state.state_ptrs.accounts, + &mut generation_state.state_ptrs.storage, trie_inputs, ) .expect("Invalid MPT data for preinitialization"); @@ -446,8 +445,8 @@ impl GenerationState { receipt_root_ptr: 0, }, jumpdest_table: None, - accounts_pointers: BTreeMap::new(), - storage_pointers: BTreeMap::new(), + access_lists_ptrs: LinkedListsPtrs::default(), + state_ptrs: LinkedListsPtrs::default(), ger_prover_inputs, }; let trie_root_ptrs = @@ -463,6 +462,8 @@ impl GenerationState { ) -> Result { let mut state = Self { inputs: trimmed_inputs.clone(), + state_ptrs: segment_data.extra_data.state_ptrs.clone(), + access_lists_ptrs: segment_data.extra_data.access_lists_ptrs.clone(), ..Default::default() }; @@ -560,8 +561,8 @@ impl GenerationState { receipt_root_ptr: 0, }, jumpdest_table: None, - accounts_pointers: self.accounts_pointers.clone(), - storage_pointers: self.storage_pointers.clone(), + access_lists_ptrs: self.access_lists_ptrs.clone(), + state_ptrs: self.state_ptrs.clone(), } } @@ -578,10 +579,10 @@ impl GenerationState { .clone_from(&segment_data.extra_data.trie_root_ptrs); self.jumpdest_table .clone_from(&segment_data.extra_data.jumpdest_table); - self.accounts_pointers - .clone_from(&segment_data.extra_data.accounts); - self.storage_pointers - .clone_from(&segment_data.extra_data.storage); + self.state_ptrs + .clone_from(&segment_data.extra_data.state_ptrs); + self.access_lists_ptrs + .clone_from(&segment_data.extra_data.access_lists_ptrs); self.next_txn_index = segment_data.extra_data.next_txn_index; self.registers = RegistersState { program_counter: self.registers.program_counter,