From 3ef67fdb002d67091a3434f13af749c573ed2531 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alonso=20Gonz=C3=A1lez?= Date: Fri, 13 Sep 2024 18:53:56 +0200 Subject: [PATCH] Search in linked lists using a BTree (#603) * Search in linked lists using a Btree * Minor * Remove unused * Minor * Apply suggestions from code review Co-authored-by: Robin Salen <30937548+Nashtare@users.noreply.github.com> * Get accounts and storage segment lengths from metadata * Remove linked lists len prover input * Rustfmt --------- Co-authored-by: Robin Salen <30937548+Nashtare@users.noreply.github.com> Co-authored-by: Robin Salen --- .../asm/mpt/linked_list/linked_list.asm | 4 +- .../src/cpu/kernel/interpreter.rs | 16 +- .../src/cpu/kernel/tests/account_code.rs | 11 +- .../src/cpu/kernel/tests/mpt/linked_list.rs | 76 ++++-- .../src/generation/linked_list.rs | 75 ++++- evm_arithmetization/src/generation/mpt.rs | 73 +++-- .../src/generation/prover_input.rs | 256 ++++++++---------- .../src/generation/segments.rs | 2 + evm_arithmetization/src/generation/state.rs | 78 +++++- 9 files changed, 380 insertions(+), 211 deletions(-) diff --git a/evm_arithmetization/src/cpu/kernel/asm/mpt/linked_list/linked_list.asm b/evm_arithmetization/src/cpu/kernel/asm/mpt/linked_list/linked_list.asm index ef1213aae..39e4604d3 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/mpt/linked_list/linked_list.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/mpt/linked_list/linked_list.asm @@ -224,7 +224,7 @@ insert_new_account: /// Returns 0 if the account was not found or `original_ptr` if it was already present. global search_account: // stack: addr_key, retdest - PROVER_INPUT(linked_list::insert_account) + PROVER_INPUT(linked_list::search_account) // stack: pred_ptr/4, addr_key, retdest %get_valid_account_ptr // stack: pred_ptr, addr_key, retdest @@ -685,7 +685,7 @@ next_node_ok: /// Returns `value` if the storage key was inserted, `old_value` if it was already present. 
global search_slot: // stack: addr_key, key, value, retdest - PROVER_INPUT(linked_list::insert_slot) + PROVER_INPUT(linked_list::search_slot) // stack: pred_ptr/5, addr_key, key, value, retdest %get_valid_slot_ptr diff --git a/evm_arithmetization/src/cpu/kernel/interpreter.rs b/evm_arithmetization/src/cpu/kernel/interpreter.rs index 8a1f6471c..10af8d495 100644 --- a/evm_arithmetization/src/cpu/kernel/interpreter.rs +++ b/evm_arithmetization/src/cpu/kernel/interpreter.rs @@ -5,7 +5,7 @@ //! the future execution and generate nondeterministically the corresponding //! jumpdest table, before the actual CPU carries on with contract execution. -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; use anyhow::anyhow; use ethereum_types::{BigEndianHash, U256}; @@ -115,6 +115,8 @@ pub(crate) struct ExtraSegmentData { pub(crate) ger_prover_inputs: Vec, pub(crate) trie_root_ptrs: TrieRootPtrs, pub(crate) jumpdest_table: Option>>, + pub(crate) accounts: BTreeMap, + pub(crate) storage: BTreeMap<(U256, U256), usize>, pub(crate) next_txn_index: usize, } @@ -232,8 +234,12 @@ impl Interpreter { // Initialize the MPT's pointers. let (trie_root_ptrs, state_leaves, storage_leaves, trie_data) = - load_linked_lists_and_txn_and_receipt_mpts(&inputs.tries) - .expect("Invalid MPT data for preinitialization"); + load_linked_lists_and_txn_and_receipt_mpts( + &mut self.generation_state.accounts_pointers, + &mut self.generation_state.storage_pointers, + &inputs.tries, + ) + .expect("Invalid MPT data for preinitialization"); let trie_roots_after = &inputs.trie_roots_after; self.generation_state.trie_root_ptrs = trie_root_ptrs; @@ -253,6 +259,10 @@ impl Interpreter { ); self.insert_preinitialized_segment(Segment::StorageLinkedList, preinit_storage_ll_segment); + // Initialize the accounts and storage BTrees. 
+ self.generation_state.insert_all_slots_in_memory(); + self.generation_state.insert_all_accounts_in_memory(); + // Update the RLP and withdrawal prover inputs. let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns); let withdrawal_prover_inputs = all_withdrawals_prover_inputs_reversed(&inputs.withdrawals); diff --git a/evm_arithmetization/src/cpu/kernel/tests/account_code.rs b/evm_arithmetization/src/cpu/kernel/tests/account_code.rs index a69d29c83..ff4dba48e 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/account_code.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/account_code.rs @@ -31,8 +31,12 @@ pub(crate) fn initialize_mpts( ) { // Load all MPTs. let (mut trie_root_ptrs, state_leaves, storage_leaves, trie_data) = - load_linked_lists_and_txn_and_receipt_mpts(trie_inputs) - .expect("Invalid MPT data for preinitialization"); + load_linked_lists_and_txn_and_receipt_mpts( + &mut interpreter.generation_state.accounts_pointers, + &mut interpreter.generation_state.storage_pointers, + trie_inputs, + ) + .expect("Invalid MPT data for preinitialization"); interpreter.generation_state.memory.contexts[0].segments [Segment::AccountsLinkedList.unscale()] @@ -44,6 +48,9 @@ pub(crate) fn initialize_mpts( trie_data.clone(); interpreter.generation_state.trie_root_ptrs = trie_root_ptrs.clone(); + interpreter.generation_state.insert_all_slots_in_memory(); + interpreter.generation_state.insert_all_accounts_in_memory(); + if trie_root_ptrs.state_root_ptr.is_none() { trie_root_ptrs.state_root_ptr = Some( load_state_mpt( diff --git a/evm_arithmetization/src/cpu/kernel/tests/mpt/linked_list.rs b/evm_arithmetization/src/cpu/kernel/tests/mpt/linked_list.rs index 69f7061bc..a6e6fa513 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/mpt/linked_list.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/mpt/linked_list.rs @@ -12,9 +12,11 @@ use rand::{thread_rng, Rng}; use crate::cpu::kernel::aggregator::KERNEL; use 
crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::interpreter::Interpreter; -use crate::generation::linked_list::LinkedList; +use crate::generation::linked_list::AccountsLinkedList; +use crate::generation::linked_list::StorageLinkedList; use crate::memory::segments::Segment; use crate::witness::memory::MemoryAddress; +use crate::witness::memory::MemorySegmentState; fn init_logger() { let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug")); @@ -80,7 +82,8 @@ fn test_list_iterator() -> Result<()> { .memory .get_preinit_memory(Segment::AccountsLinkedList); let mut accounts_list = - LinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList).unwrap(); + AccountsLinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList) + .unwrap(); let Some([addr, ptr, ptr_cpy, scaled_pos_1]) = accounts_list.next() else { return Err(anyhow::Error::msg("Couldn't get value")); @@ -102,7 +105,7 @@ fn test_list_iterator() -> Result<()> { .memory .get_preinit_memory(Segment::StorageLinkedList); let mut storage_list = - LinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap(); + StorageLinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap(); let Some([addr, key, ptr, ptr_cpy, scaled_pos_1]) = storage_list.next() else { return Err(anyhow::Error::msg("Couldn't get value")); }; @@ -171,8 +174,16 @@ fn test_insert_account() -> Result<()> { .memory .get_preinit_memory(Segment::AccountsLinkedList); let mut list = - LinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList).unwrap(); + AccountsLinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList) + .unwrap(); + let Some([addr, ptr, ptr_cpy, _]) = list.next() else { + return Err(anyhow::Error::msg("Couldn't get value")); + }; + // This is the dummy node + assert_eq!(addr, U256::MAX); + assert_eq!(ptr, U256::zero()); + assert_eq!(ptr_cpy, U256::zero()); let 
Some([addr, ptr, ptr_cpy, scaled_next_pos]) = list.next() else { return Err(anyhow::Error::msg("Couldn't get value")); }; @@ -251,7 +262,16 @@ fn test_insert_storage() -> Result<()> { .memory .get_preinit_memory(Segment::StorageLinkedList); let mut list = - LinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap(); + StorageLinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap(); + + let Some([inserted_addr, inserted_key, ptr, ptr_cpy, _]) = list.next() else { + return Err(anyhow::Error::msg("Couldn't get value")); + }; + // This is the dummy node. + assert_eq!(inserted_addr, U256::MAX); + assert_eq!(inserted_key, U256::zero()); + assert_eq!(ptr, U256::zero()); + assert_eq!(ptr_cpy, U256::zero()); let Some([inserted_addr, inserted_key, ptr, ptr_cpy, scaled_next_pos]) = list.next() else { return Err(anyhow::Error::msg("Couldn't get value")); @@ -292,9 +312,17 @@ fn test_insert_and_delete_accounts() -> Result<()> { Some((Segment::AccountsLinkedList as usize).into()), ]; let init_len = init_accounts_ll.len(); - interpreter.generation_state.memory.contexts[0].segments - [Segment::AccountsLinkedList.unscale()] - .content = init_accounts_ll; + + interpreter + .generation_state + .memory + .insert_preinitialized_segment( + Segment::AccountsLinkedList, + MemorySegmentState { + content: init_accounts_ll, + }, + ); + interpreter.set_global_metadata_field( GlobalMetadata::AccountsLinkedListNextAvailable, (Segment::AccountsLinkedList as usize + init_len).into(), @@ -433,19 +461,22 @@ fn test_insert_and_delete_accounts() -> Result<()> { .generation_state .memory .get_preinit_memory(Segment::AccountsLinkedList); - let list = - LinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList).unwrap(); + let list = AccountsLinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList) + .unwrap(); for (i, [addr, ptr, ptr_cpy, _]) in list.enumerate() { if addr == U256::MAX { assert_eq!(addr, 
U256::MAX); assert_eq!(ptr, U256::zero()); assert_eq!(ptr_cpy, U256::zero()); - break; + if i > 0 { + break; + } + } else { + let addr_in_list = U256::from(new_addresses[i - 1].0.as_slice()); + assert_eq!(addr, addr_in_list); + assert_eq!(ptr, addr + delta_ptr); } - let addr_in_list = U256::from(new_addresses[i].0.as_slice()); - assert_eq!(addr, addr_in_list); - assert_eq!(ptr, addr + delta_ptr); } Ok(()) @@ -640,7 +671,8 @@ fn test_insert_and_delete_storage() -> Result<()> { .generation_state .memory .get_preinit_memory(Segment::StorageLinkedList); - let list = LinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap(); + let list = + StorageLinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap(); for (i, [addr, key, ptr, ptr_cpy, _]) in list.enumerate() { if addr == U256::MAX { @@ -648,12 +680,16 @@ fn test_insert_and_delete_storage() -> Result<()> { assert_eq!(key, U256::zero()); assert_eq!(ptr, U256::zero()); assert_eq!(ptr_cpy, U256::zero()); - break; + if i > 0 { + break; + } + } else { + let [addr_in_list, key_in_list] = + new_addresses[i - 1].map(|x| U256::from(x.0.as_slice())); + assert_eq!(addr, addr_in_list); + assert_eq!(key, key_in_list); + assert_eq!(ptr, addr + delta_ptr); } - let [addr_in_list, key_in_list] = new_addresses[i].map(|x| U256::from(x.0.as_slice())); - assert_eq!(addr, addr_in_list); - assert_eq!(key, key_in_list); - assert_eq!(ptr, addr + delta_ptr); } Ok(()) diff --git a/evm_arithmetization/src/generation/linked_list.rs b/evm_arithmetization/src/generation/linked_list.rs index d8a313c90..fbe0e4965 100644 --- a/evm_arithmetization/src/generation/linked_list.rs +++ b/evm_arithmetization/src/generation/linked_list.rs @@ -1,4 +1,5 @@ use std::fmt; +use std::marker::PhantomData; use anyhow::Result; use ethereum_types::U256; @@ -8,15 +9,37 @@ use crate::util::u256_to_usize; use crate::witness::errors::ProgramError; use crate::witness::errors::ProverInputError::InvalidInput; +pub const 
ACCOUNTS_LINKED_LIST_NODE_SIZE: usize = 4; +pub const STORAGE_LINKED_LIST_NODE_SIZE: usize = 5; + +pub(crate) trait LinkedListType {} +#[derive(Clone)] +/// A linked list that starts from the first node after the special node and +/// iterates forever. +pub(crate) struct Cyclic; +#[derive(Clone)] +/// A linked list that starts from the special node and iterates until the last +/// node. +pub(crate) struct Bounded; +impl LinkedListType for Cyclic {} +impl LinkedListType for Bounded {} + +pub(crate) type AccountsLinkedList<'a> = LinkedList<'a, ACCOUNTS_LINKED_LIST_NODE_SIZE>; +pub(crate) type StorageLinkedList<'a> = LinkedList<'a, STORAGE_LINKED_LIST_NODE_SIZE>; + // A linked list implemented using a vector `access_list_mem`. // In this representation, the values of nodes are stored in the range // `access_list_mem[i..i + node_size - 1]`, and `access_list_mem[i + node_size - // 1]` holds the address of the next node, where i = node_size * j. #[derive(Clone)] -pub(crate) struct LinkedList<'a, const N: usize> { +pub(crate) struct LinkedList<'a, const N: usize, T = Cyclic> +where + T: LinkedListType, +{ mem: &'a [Option], offset: usize, pos: usize, + _marker: PhantomData, } pub(crate) fn empty_list_mem(segment: Segment) -> [Option; N] { @@ -31,15 +54,15 @@ pub(crate) fn empty_list_mem(segment: Segment) -> [Option; }) } -impl<'a, const N: usize> LinkedList<'a, N> { - pub const fn from_mem_and_segment( +impl<'a, const N: usize, T: LinkedListType> LinkedList<'a, N, T> { + pub fn from_mem_and_segment( mem: &'a [Option], segment: Segment, ) -> Result { Self::from_mem_len_and_segment(mem, segment) } - pub const fn from_mem_len_and_segment( + pub fn from_mem_len_and_segment( mem: &'a [Option], segment: Segment, ) -> Result { @@ -50,6 +73,7 @@ impl<'a, const N: usize> LinkedList<'a, N> { mem, offset: segment as usize, pos: 0, + _marker: PhantomData, }) } } @@ -58,9 +82,8 @@ impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> 
fmt::Result { writeln!(f, "Linked List {{")?; let cloned_list = self.clone(); - for node in cloned_list { - if node[0] == U256::MAX { - writeln!(f, "{:?}", node)?; + for (i, node) in cloned_list.enumerate() { + if i > 0 && node[0] == U256::MAX { break; } writeln!(f, "{:?} ->", node)?; @@ -69,17 +92,47 @@ impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N> { } } +impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N, Bounded> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "Linked List {{")?; + let cloned_list = self.clone(); + for node in cloned_list { + writeln!(f, "{:?} ->", node)?; + } + write!(f, "}}") + } +} + impl<'a, const N: usize> Iterator for LinkedList<'a, N> { type Item = [U256; N]; fn next(&mut self) -> Option { - // The first node is always the special node, so we skip it in the first - // iteration. + let node = Some(std::array::from_fn(|i| { + self.mem[self.pos + i].unwrap_or_default() + })); if let Ok(new_pos) = u256_to_usize(self.mem[self.pos + N - 1].unwrap_or_default()) { self.pos = new_pos - self.offset; - Some(std::array::from_fn(|i| { + node + } else { + None + } + } +} + +impl<'a, const N: usize> Iterator for LinkedList<'a, N, Bounded> { + type Item = [U256; N]; + + fn next(&mut self) -> Option { + if self.mem[self.pos] != Some(U256::MAX) { + let node = Some(std::array::from_fn(|i| { self.mem[self.pos + i].unwrap_or_default() - })) + })); + if let Ok(new_pos) = u256_to_usize(self.mem[self.pos + N - 1].unwrap_or_default()) { + self.pos = new_pos - self.offset; + node + } else { + None + } } else { None } diff --git a/evm_arithmetization/src/generation/mpt.rs b/evm_arithmetization/src/generation/mpt.rs index adb89b5a6..85f85e75f 100644 --- a/evm_arithmetization/src/generation/mpt.rs +++ b/evm_arithmetization/src/generation/mpt.rs @@ -1,5 +1,5 @@ use core::ops::Deref; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use bytes::Bytes; use ethereum_types::{Address, BigEndianHash, 
H256, U256}; @@ -10,8 +10,9 @@ use rlp::{Decodable, DecoderError, Encodable, PayloadInfo, Rlp, RlpStream}; use rlp_derive::{RlpDecodable, RlpEncodable}; use serde::{Deserialize, Serialize}; -use super::linked_list::empty_list_mem; -use super::prover_input::{ACCOUNTS_LINKED_LIST_NODE_SIZE, STORAGE_LINKED_LIST_NODE_SIZE}; +use super::linked_list::{ + empty_list_mem, ACCOUNTS_LINKED_LIST_NODE_SIZE, STORAGE_LINKED_LIST_NODE_SIZE, +}; use super::TrimmedTrieInputs; use crate::cpu::kernel::constants::trie_type::PartialTrieType; use crate::generation::TrieInputs; @@ -133,7 +134,7 @@ fn parse_storage_value(value_rlp: &[u8]) -> Result, ProgramError> { Ok(vec![value]) } -const fn parse_storage_value_no_return(_value_rlp: &[u8]) -> Result, ProgramError> { +fn parse_storage_value_no_return(_value_rlp: &[u8]) -> Result, ProgramError> { Ok(vec![]) } @@ -335,6 +336,8 @@ fn get_state_and_storage_leaves( state_leaves: &mut Vec>, storage_leaves: &mut Vec>, trie_data: &mut Vec>, + accounts_pointers: &mut BTreeMap, + storage_pointers: &mut BTreeMap<(U256, U256), usize>, storage_tries_by_state_key: &HashMap, ) -> Result<(), ProgramError> { match trie.deref() { @@ -357,6 +360,8 @@ fn get_state_and_storage_leaves( state_leaves, storage_leaves, trie_data, + accounts_pointers, + storage_pointers, storage_tries_by_state_key, )?; } @@ -371,6 +376,8 @@ fn get_state_and_storage_leaves( state_leaves, storage_leaves, trie_data, + accounts_pointers, + storage_pointers, storage_tries_by_state_key, )?; @@ -400,14 +407,12 @@ fn get_state_and_storage_leaves( // The last leaf must point to the new one. let len = state_leaves.len(); - state_leaves[len - 1] = Some(U256::from( - Segment::AccountsLinkedList as usize + state_leaves.len(), - )); + state_leaves[len - 1] = Some(U256::from(Segment::AccountsLinkedList as usize + len)); // The nibbles are the address. 
- let address = merged_key + let addr_key = merged_key .try_into() .map_err(|_| ProgramError::IntegerTooLarge)?; - state_leaves.push(Some(address)); + state_leaves.push(Some(addr_key)); // Set `value_ptr_ptr`. state_leaves.push(Some(trie_data.len().into())); // Set counter. @@ -422,13 +427,16 @@ fn get_state_and_storage_leaves( trie_data.push(Some(0.into())); trie_data.push(Some(code_hash.into_uint())); get_storage_leaves( - address, + addr_key, empty_nibbles(), storage_trie, storage_leaves, + storage_pointers, &parse_storage_value, )?; + accounts_pointers.insert(addr_key, Segment::AccountsLinkedList as usize + len); + Ok(()) } _ => Ok(()), @@ -436,10 +444,11 @@ fn get_state_and_storage_leaves( } pub(crate) fn get_storage_leaves( - address: U256, + addr_key: U256, key: Nibbles, trie: &HashedPartialTrie, storage_leaves: &mut Vec>, + storage_pointers: &mut BTreeMap<(U256, U256), usize>, parse_value: &F, ) -> Result<(), ProgramError> where @@ -453,7 +462,14 @@ where count: 1, packed: i.into(), }); - get_storage_leaves(address, extended_key, child, storage_leaves, parse_value)?; + get_storage_leaves( + addr_key, + extended_key, + child, + storage_leaves, + storage_pointers, + parse_value, + )?; } Ok(()) @@ -461,7 +477,14 @@ where Node::Extension { nibbles, child } => { let extended_key = key.merge_nibbles(nibbles); - get_storage_leaves(address, extended_key, child, storage_leaves, parse_value)?; + get_storage_leaves( + addr_key, + extended_key, + child, + storage_leaves, + storage_pointers, + parse_value, + )?; Ok(()) } @@ -469,17 +492,14 @@ where // The last leaf must point to the new one. let len = storage_leaves.len(); let merged_key = key.merge_nibbles(nibbles); - storage_leaves[len - 1] = Some(U256::from( - Segment::StorageLinkedList as usize + storage_leaves.len(), - )); + storage_leaves[len - 1] = Some(U256::from(Segment::StorageLinkedList as usize + len)); // Write the address. 
- storage_leaves.push(Some(address)); + storage_leaves.push(Some(addr_key)); // Write the key. - storage_leaves.push(Some( - merged_key - .try_into() - .map_err(|_| ProgramError::IntegerTooLarge)?, - )); + let slot_key = merged_key + .try_into() + .map_err(|_| ProgramError::IntegerTooLarge)?; + storage_leaves.push(Some(slot_key)); // Write `value_ptr_ptr`. let leaves = parse_value(value)? .into_iter() @@ -495,6 +515,11 @@ where // Set the next node as the initial node. storage_leaves.push(Some((Segment::StorageLinkedList as usize).into())); + storage_pointers.insert( + (addr_key, slot_key), + Segment::StorageLinkedList as usize + len, + ); + Ok(()) } _ => Ok(()), @@ -514,6 +539,8 @@ type TriePtrsLinkedLists = ( ); pub(crate) fn load_linked_lists_and_txn_and_receipt_mpts( + accounts_pointers: &mut BTreeMap, + storage_pointers: &mut BTreeMap<(U256, U256), usize>, trie_inputs: &TrieInputs, ) -> Result { let mut state_leaves = @@ -546,6 +573,8 @@ pub(crate) fn load_linked_lists_and_txn_and_receipt_mpts( &mut state_leaves, &mut storage_leaves, &mut trie_data, + accounts_pointers, + storage_pointers, &storage_tries_by_state_key, )?; diff --git a/evm_arithmetization/src/generation/prover_input.rs b/evm_arithmetization/src/generation/prover_input.rs index 4a5043d92..16c1e5310 100644 --- a/evm_arithmetization/src/generation/prover_input.rs +++ b/evm_arithmetization/src/generation/prover_input.rs @@ -10,7 +10,9 @@ use num_bigint::BigUint; use plonky2::hash::hash_types::RichField; use serde::{Deserialize, Serialize}; -use super::linked_list::LinkedList; +use super::linked_list::{ + LinkedList, ACCOUNTS_LINKED_LIST_NODE_SIZE, STORAGE_LINKED_LIST_NODE_SIZE, +}; use super::mpt::load_state_mpt; use crate::cpu::kernel::cancun_constants::KZG_VERSIONED_HASH; use crate::cpu::kernel::constants::cancun_constants::{ @@ -26,6 +28,7 @@ use crate::generation::prover_input::EvmField::{ }; use crate::generation::prover_input::FieldOp::{Inverse, Sqrt}; use 
crate::generation::state::GenerationState; +use crate::generation::GlobalMetadata; use crate::memory::segments::Segment; use crate::memory::segments::Segment::BnPairing; use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, sha2, u256_to_u8, u256_to_usize}; @@ -43,8 +46,6 @@ pub struct ProverInputFn(Vec); pub const ADDRESSES_ACCESS_LIST_LEN: usize = 2; pub const STORAGE_KEYS_ACCESS_LIST_LEN: usize = 4; -pub const ACCOUNTS_LINKED_LIST_NODE_SIZE: usize = 4; -pub const STORAGE_LINKED_LIST_NODE_SIZE: usize = 5; impl From> for ProverInputFn { fn from(v: Vec) -> Self { @@ -92,23 +93,26 @@ impl GenerationState { fn run_trie_ptr(&mut self, input_fn: &ProverInputFn) -> Result { let trie = input_fn.0[1].as_str(); match trie { - "state" => match self.trie_root_ptrs.state_root_ptr { - Some(state_root_ptr) => Ok(state_root_ptr), - None => { - let mut new_content = self.memory.get_preinit_memory(Segment::TrieData); - - let n = load_state_mpt(&self.inputs.trimmed_tries, &mut new_content)?; - - self.memory.insert_preinitialized_segment( - Segment::TrieData, - crate::witness::memory::MemorySegmentState { - content: new_content, - }, - ); - Ok(n) - } - } - .map(U256::from), + "state" => self + .trie_root_ptrs + .state_root_ptr + .map_or_else( + || { + let mut new_content = self.memory.get_preinit_memory(Segment::TrieData); + + let n = load_state_mpt(&self.inputs.trimmed_tries, &mut new_content)?; + + self.memory.insert_preinitialized_segment( + Segment::TrieData, + crate::witness::memory::MemorySegmentState { + content: new_content, + }, + ); + Ok(n) + }, + Ok, + ) + .map(U256::from), "txn" => Ok(U256::from(self.trie_root_ptrs.txn_root_ptr)), "receipt" => Ok(U256::from(self.trie_root_ptrs.receipt_root_ptr)), "trie_data_size" => Ok(self @@ -331,43 +335,11 @@ impl GenerationState { /// jump address. 
fn run_linked_list(&mut self, input_fn: &ProverInputFn) -> Result { match input_fn.0[1].as_str() { - "insert_account" => self.run_next_insert_account(), + "insert_account" | "search_account" => self.run_next_insert_account(input_fn), "remove_account" => self.run_next_remove_account(), - "insert_slot" => self.run_next_insert_slot(), + "insert_slot" | "search_slot" => self.run_next_insert_slot(input_fn), "remove_slot" => self.run_next_remove_slot(), "remove_address_slots" => self.run_next_remove_address_slots(), - "accounts_linked_list_len" => { - let len = self - .memory - .preinitialized_segments - .get(&Segment::AccountsLinkedList) - .unwrap_or(&crate::witness::memory::MemorySegmentState { content: vec![] }) - .content - .len() - .max( - self.memory.contexts[0].segments[Segment::AccountsLinkedList.unscale()] - .content - .len(), - ); - - Ok((Segment::AccountsLinkedList as usize + len).into()) - } - "storage_linked_list_len" => { - let len = self - .memory - .preinitialized_segments - .get(&Segment::StorageLinkedList) - .unwrap_or(&crate::witness::memory::MemorySegmentState { content: vec![] }) - .content - .len() - .max( - self.memory.contexts[0].segments[Segment::StorageLinkedList.unscale()] - .content - .len(), - ); - - Ok((Segment::StorageLinkedList as usize + len).into()) - } _ => Err(ProgramError::ProverInputError(InvalidInput)), } } @@ -513,103 +485,97 @@ impl GenerationState { /// Returns a pointer to a node in the list such that /// `node[0] <= addr < next_node[0]` and `addr` is the top of the stack. 
- fn run_next_insert_account(&self) -> Result { + fn run_next_insert_account(&mut self, input_fn: &ProverInputFn) -> Result { let addr = stack_peek(self, 0)?; - let accounts_mem = self.memory.get_preinit_memory(Segment::AccountsLinkedList); - let accounts_linked_list = - LinkedList::::from_mem_and_segment( - &accounts_mem, - Segment::AccountsLinkedList, - )?; - - if let Some(([.., pred_ptr], [_, ..], _)) = - accounts_linked_list - .tuple_windows() - .find(|&(_, [prev_addr, ..], [next_addr, ..])| { - (prev_addr <= addr || prev_addr == U256::MAX) && addr < next_addr - }) - { - Ok(pred_ptr / U256::from(ACCOUNTS_LINKED_LIST_NODE_SIZE)) - } else { - Ok((Segment::AccountsLinkedList as usize).into()) + + let (&pred_addr, &pred_ptr) = self + .accounts_pointers + .range(..=addr) + .next_back() + .unwrap_or((&U256::MAX, &(Segment::AccountsLinkedList as usize))); + + if pred_addr != addr && input_fn.0[1].as_str() == "insert_account" { + self.accounts_pointers.insert( + addr, + u256_to_usize( + self.memory + .read_global_metadata(GlobalMetadata::AccountsLinkedListNextAvailable), + )?, + ); } + + Ok(U256::from(pred_ptr / ACCOUNTS_LINKED_LIST_NODE_SIZE)) } - /// Returns an unscaled pointer to an element in the list such that + /// Returns an unscaled pointer to a node in the list such that /// `node[0] <= addr < next_node[0]`, or node[0] == addr and `node[1] <= /// key < next_node[1]`, where `addr` and `key` are the elements at the top /// of the stack. 
- fn run_next_insert_slot(&self) -> Result { + fn run_next_insert_slot(&mut self, input_fn: &ProverInputFn) -> Result { let addr = stack_peek(self, 0)?; let key = stack_peek(self, 1)?; - let storage_mem = self.memory.get_preinit_memory(Segment::StorageLinkedList); - let storage_linked_list = - LinkedList::::from_mem_and_segment( - &storage_mem, - Segment::StorageLinkedList, - )?; - - if let Some(([.., pred_ptr], _, _)) = storage_linked_list.tuple_windows().find( - |&(_, [prev_addr, prev_key, ..], [next_addr, next_key, ..])| { - let prev_is_less_or_equal = (prev_addr < addr || prev_addr == U256::MAX) - || (prev_addr == addr && prev_key <= key); - let next_is_strictly_larger = - next_addr > addr || (next_addr == addr && next_key > key); - prev_is_less_or_equal && next_is_strictly_larger - }, - ) { - Ok((pred_ptr - U256::from(Segment::StorageLinkedList as usize)) - / U256::from(STORAGE_LINKED_LIST_NODE_SIZE)) - } else { - Ok(U256::zero()) + + let (&(pred_addr, pred_slot_key), &pred_ptr) = self + .storage_pointers + .range(..=(addr, key)) + .next_back() + .unwrap_or(( + &(U256::MAX, U256::zero()), + &(Segment::StorageLinkedList as usize), + )); + if (pred_addr != addr || pred_slot_key != key) && input_fn.0[1] == "insert_slot" { + self.storage_pointers.insert( + (addr, key), + u256_to_usize( + self.memory + .read_global_metadata(GlobalMetadata::StorageLinkedListNextAvailable), + )?, + ); } + Ok(U256::from( + (pred_ptr - Segment::StorageLinkedList as usize) / STORAGE_LINKED_LIST_NODE_SIZE, + )) } - /// Returns a pointer `ptr` to a node of the form [next_addr, ..] in the + /// Returns a pointer `ptr` to a node of the form [..] -> [next_addr, ..] /// list such that `next_addr = addr` and `addr` is the top of the stack. /// If the element is not in the list, loops forever. 
- fn run_next_remove_account(&self) -> Result { + fn run_next_remove_account(&mut self) -> Result { let addr = stack_peek(self, 0)?; - let accounts_mem = self.memory.get_preinit_memory(Segment::AccountsLinkedList); - let accounts_linked_list = - LinkedList::::from_mem_and_segment( - &accounts_mem, - Segment::AccountsLinkedList, - )?; - - if let Some(([.., ptr], _, _)) = accounts_linked_list - .tuple_windows() - .find(|&(_, _, [next_node_addr, ..])| next_node_addr == addr) - { - Ok(ptr / ACCOUNTS_LINKED_LIST_NODE_SIZE) - } else { - Ok((Segment::AccountsLinkedList as usize).into()) - } + + let (_, &ptr) = self + .accounts_pointers + .range(..addr) + .next_back() + .unwrap_or((&U256::MAX, &(Segment::AccountsLinkedList as usize))); + self.accounts_pointers + .remove(&addr) + .ok_or(ProgramError::ProverInputError(InvalidInput))?; + + Ok(U256::from(ptr / ACCOUNTS_LINKED_LIST_NODE_SIZE)) } /// Returns a pointer `ptr` to a node = `[next_addr, next_key]` in the list /// such that `next_addr == addr` and `next_key == key`, /// and `addr, key` are the elements at the top of the stack. /// If the element is not in the list, loops forever. 
- fn run_next_remove_slot(&self) -> Result { + fn run_next_remove_slot(&mut self) -> Result { let addr = stack_peek(self, 0)?; let key = stack_peek(self, 1)?; - let storage_mem = self.memory.get_preinit_memory(Segment::StorageLinkedList); - let storage_linked_list = - LinkedList::::from_mem_and_segment( - &storage_mem, - Segment::StorageLinkedList, - )?; - - if let Some(([.., ptr], _, _)) = storage_linked_list - .tuple_windows() - .find(|&(_, _, [next_addr, next_key, ..])| next_addr == addr && next_key == key) - { - Ok((ptr - U256::from(Segment::StorageLinkedList as usize)) - / U256::from(STORAGE_LINKED_LIST_NODE_SIZE)) - } else { - Ok((Segment::StorageLinkedList as usize).into()) - } + + let (_, &ptr) = self + .storage_pointers + .range(..(addr, key)) + .next_back() + .unwrap_or(( + &(U256::MAX, U256::zero()), + &(Segment::StorageLinkedList as usize), + )); + self.storage_pointers + .remove(&(addr, key)) + .ok_or(ProgramError::ProverInputError(InvalidInput))?; + + Ok(U256::from(ptr - Segment::StorageLinkedList as usize) / STORAGE_LINKED_LIST_NODE_SIZE) } /// Returns a pointer `ptr` to a storage node in the storage linked list. @@ -618,27 +584,21 @@ impl GenerationState { /// `next_addr = @U256_MAX`. This is used to determine the first storage /// node for the account at `addr`. `addr` is the element at the top of the /// stack. 
- fn run_next_remove_address_slots(&self) -> Result { + fn run_next_remove_address_slots(&mut self) -> Result { let addr = stack_peek(self, 0)?; - let storage_mem = self.memory.get_preinit_memory(Segment::StorageLinkedList); - let storage_linked_list = - LinkedList::::from_mem_and_segment( - &storage_mem, - Segment::StorageLinkedList, - )?; - - if let Some(([.., pred_ptr], _, _)) = storage_linked_list.tuple_windows().find( - |&(_, [prev_addr, _, ..], [next_addr, _, ..])| { - let prev_is_less = prev_addr < addr || prev_addr == U256::MAX; - let next_is_larger_or_equal = next_addr >= addr; - prev_is_less && next_is_larger_or_equal - }, - ) { - Ok((pred_ptr - U256::from(Segment::StorageLinkedList as usize)) - / U256::from(STORAGE_LINKED_LIST_NODE_SIZE)) - } else { - Ok((Segment::StorageLinkedList as usize).into()) - } + + let (_, &pred_ptr) = self + .storage_pointers + .range(..(addr, U256::zero())) + .next_back() + .unwrap_or(( + &(U256::MAX, U256::zero()), + &(Segment::StorageLinkedList as usize), + )); + + Ok(U256::from( + (pred_ptr - Segment::StorageLinkedList as usize) / STORAGE_LINKED_LIST_NODE_SIZE, + )) } /// Returns the first part of the KZG precompile output. 
diff --git a/evm_arithmetization/src/generation/segments.rs b/evm_arithmetization/src/generation/segments.rs index f25f2e8bc..7f123aa10 100644 --- a/evm_arithmetization/src/generation/segments.rs +++ b/evm_arithmetization/src/generation/segments.rs @@ -74,6 +74,8 @@ fn build_segment_data( trie_root_ptrs: interpreter.generation_state.trie_root_ptrs.clone(), jumpdest_table: interpreter.generation_state.jumpdest_table.clone(), next_txn_index: interpreter.generation_state.next_txn_index, + accounts: interpreter.generation_state.accounts_pointers.clone(), + storage: interpreter.generation_state.storage_pointers.clone(), }, } } diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index 2630ba544..b094c15f9 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use std::mem::size_of; use anyhow::{anyhow, bail}; @@ -8,6 +8,7 @@ use keccak_hash::keccak; use log::Level; use plonky2::hash::hash_types::RichField; +use super::linked_list::{AccountsLinkedList, StorageLinkedList}; use super::mpt::TrieRootPtrs; use super::segments::GenerationSegmentData; use super::{TrieInputs, TrimmedGenerationInputs, NUM_EXTRA_CYCLES_AFTER}; @@ -190,6 +191,7 @@ pub(crate) trait State { let mut final_registers = RegistersState::default(); let mut running = true; let mut final_clock = 0; + loop { let registers = self.get_registers(); let pc = registers.program_counter; @@ -373,6 +375,16 @@ pub struct GenerationState { /// the code (not necessarily pointing to an opcode) such that for every /// j in [i, i+32] it holds that code[j] < 0x7f - j + i. pub(crate) jumpdest_table: Option>>, + + /// Each entry contains the pair (key, ptr) where key is the (hashed) key + /// of an account in the accounts linked list, and ptr is the respective + /// node address in memory. 
+ pub(crate) accounts_pointers: BTreeMap, + + /// Each entry contains the pair ((account_key, slot_key), ptr) where + /// account_key is the (hashed) key of an account, slot_key is the slot + /// key, and ptr is the respective node address in memory. + pub(crate) storage_pointers: BTreeMap<(U256, U256), usize>, } impl GenerationState { @@ -380,9 +392,14 @@ impl GenerationState { &mut self, trie_inputs: &TrieInputs, ) -> TrieRootPtrs { + let generation_state = self.get_mut_generation_state(); let (trie_roots_ptrs, state_leaves, storage_leaves, trie_data) = - load_linked_lists_and_txn_and_receipt_mpts(trie_inputs) - .expect("Invalid MPT data for preinitialization"); + load_linked_lists_and_txn_and_receipt_mpts( + &mut generation_state.accounts_pointers, + &mut generation_state.storage_pointers, + trie_inputs, + ) + .expect("Invalid MPT data for preinitialization"); self.memory.insert_preinitialized_segment( Segment::AccountsLinkedList, @@ -430,11 +447,16 @@ impl GenerationState { receipt_root_ptr: 0, }, jumpdest_table: None, + accounts_pointers: BTreeMap::new(), + storage_pointers: BTreeMap::new(), ger_prover_inputs, }; let trie_root_ptrs = state.preinitialize_linked_lists_and_txn_and_receipt_mpts(&inputs.tries); + state.insert_all_accounts_in_memory(); + state.insert_all_slots_in_memory(); + state.trie_root_ptrs = trie_root_ptrs; Ok(state) } @@ -542,6 +564,8 @@ impl GenerationState { receipt_root_ptr: 0, }, jumpdest_table: None, + accounts_pointers: self.accounts_pointers.clone(), + storage_pointers: self.storage_pointers.clone(), } } @@ -558,6 +582,10 @@ impl GenerationState { .clone_from(&segment_data.extra_data.trie_root_ptrs); self.jumpdest_table .clone_from(&segment_data.extra_data.jumpdest_table); + self.accounts_pointers + .clone_from(&segment_data.extra_data.accounts); + self.storage_pointers + .clone_from(&segment_data.extra_data.storage); self.next_txn_index = segment_data.extra_data.next_txn_index; self.registers = RegistersState { program_counter: 
self.registers.program_counter, @@ -567,6 +595,50 @@ impl GenerationState { ..segment_data.registers_before }; } + + /// Insert all the slots stored in the `StorageLinkedList` segment into + /// the `storage_pointers` `BTreeMap`. + pub(crate) fn insert_all_slots_in_memory(&mut self) { + let storage_mem = self.memory.get_preinit_memory(Segment::StorageLinkedList); + self.storage_pointers.extend( + StorageLinkedList::from_mem_and_segment(&storage_mem, Segment::StorageLinkedList) + .expect("There must be at least an empty storage linked list") + .tuple_windows() + .enumerate() + .map_while( + |(i, ([prev_account_key, .., ptr], [account_key, slot_key, ..]))| { /* `ptr` is the last field of the preceding node; presumably its next-node pointer, i.e. the address of the node holding (account_key, slot_key) -- confirm */ + if i != 0 && prev_account_key == U256::MAX { /* stop at the first window past index 0 whose predecessor key is U256::MAX */ + None + } else { + Some(( + (account_key, slot_key), + u256_to_usize(ptr).expect("Node pointer must fit in a usize"), + )) + } + }, + ), + ); + } + + pub(crate) fn insert_all_accounts_in_memory(&mut self) { /* Accounts counterpart of `insert_all_slots_in_memory`: reads the preinitialized `AccountsLinkedList` segment and extends `accounts_pointers` with one (account_key, pointer) entry per node pair, with the same stop condition. */ + let accounts_mem = self.memory.get_preinit_memory(Segment::AccountsLinkedList); + self.accounts_pointers.extend( + AccountsLinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList) + .expect("There must be at least an empty accounts linked list") + .tuple_windows() + .enumerate() + .map_while(|(i, ([prev_account_key, .., ptr], [account_key, ..]))| { + if i != 0 && prev_account_key == U256::MAX { + None + } else { + Some(( + account_key, + u256_to_usize(ptr).expect("Node pointer must fit in a usize"), + )) + } + }), + ); + } } impl State for GenerationState {