Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor access list search #637

Merged
merged 14 commits into from
Sep 26, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ global init_access_lists:
// Store the segment scaled length
%increment
%mstore_global_metadata(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN)

// Reset the access lists pointers in the `GenerationState`
PROVER_INPUT(access_lists::reset)
POP // reset pushed a 0

JUMP

%macro init_access_lists
Expand Down
11 changes: 6 additions & 5 deletions evm_arithmetization/src/cpu/kernel/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//! the future execution and generate nondeterministically the corresponding
//! jumpdest table, before the actual CPU carries on with contract execution.

use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::collections::{BTreeSet, HashMap};

use anyhow::anyhow;
use ethereum_types::{BigEndianHash, U256};
Expand All @@ -19,6 +19,7 @@ use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::generation::debug_inputs;
use crate::generation::linked_list::LinkedListsPtrs;
use crate::generation::mpt::{load_linked_lists_and_txn_and_receipt_mpts, TrieRootPtrs};
use crate::generation::rlp::all_rlp_prover_inputs_reversed;
use crate::generation::state::{
Expand Down Expand Up @@ -115,8 +116,8 @@ pub(crate) struct ExtraSegmentData {
pub(crate) ger_prover_inputs: Vec<U256>,
pub(crate) trie_root_ptrs: TrieRootPtrs,
pub(crate) jumpdest_table: Option<HashMap<usize, Vec<usize>>>,
pub(crate) accounts: BTreeMap<U256, usize>,
pub(crate) storage: BTreeMap<(U256, U256), usize>,
pub(crate) access_lists_ptrs: LinkedListsPtrs,
pub(crate) state_ptrs: LinkedListsPtrs,
pub(crate) next_txn_index: usize,
}

Expand Down Expand Up @@ -235,8 +236,8 @@ impl<F: RichField> Interpreter<F> {
// Initialize the MPT's pointers.
let (trie_root_ptrs, state_leaves, storage_leaves, trie_data) =
load_linked_lists_and_txn_and_receipt_mpts(
&mut self.generation_state.accounts_pointers,
&mut self.generation_state.storage_pointers,
&mut self.generation_state.state_ptrs.accounts,
&mut self.generation_state.state_ptrs.storage,
&inputs.tries,
)
.expect("Invalid MPT data for preinitialization");
Expand Down
4 changes: 2 additions & 2 deletions evm_arithmetization/src/cpu/kernel/tests/account_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ pub(crate) fn initialize_mpts<F: RichField>(
// Load all MPTs.
let (mut trie_root_ptrs, state_leaves, storage_leaves, trie_data) =
load_linked_lists_and_txn_and_receipt_mpts(
&mut interpreter.generation_state.accounts_pointers,
&mut interpreter.generation_state.storage_pointers,
&mut interpreter.generation_state.state_ptrs.accounts,
&mut interpreter.generation_state.state_ptrs.storage,
trie_inputs,
)
.expect("Invalid MPT data for preinitialization");
Expand Down
18 changes: 17 additions & 1 deletion evm_arithmetization/src/generation/linked_list.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::collections::BTreeMap;
use std::fmt;
use std::marker::PhantomData;

use anyhow::Result;
use ethereum_types::U256;
use serde::{Deserialize, Serialize};

use crate::memory::segments::Segment;
use crate::util::u256_to_usize;
Expand Down Expand Up @@ -39,6 +41,20 @@ where
_marker: PhantomData<T>,
}

// Provides quick access to pointers that reference the memory location
// of a storage or accounts linked list node, containing a specific key.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub(crate) struct LinkedListsPtrs {
/// Each entry contains the pair (key, ptr) where key is the (hashed) key
/// of an account in the accounts linked list, and ptr is the respective
/// node address in memory.
pub(crate) accounts: BTreeMap<U256, usize>,
/// Each entry contains the pair ((account_key, slot_key), ptr) where
/// account_key is the (hashed) key of an account, slot_key is the slot
/// key, and ptr is the respective node address in memory.
pub(crate) storage: BTreeMap<(U256, U256), usize>,
}

pub(crate) fn empty_list_mem<const N: usize>(segment: Segment) -> [Option<U256>; N] {
std::array::from_fn(|i| {
if i == 0 {
Expand All @@ -63,7 +79,7 @@ impl<'a, const N: usize, T: LinkedListType> LinkedList<'a, N, T> {
mem: &'a [Option<U256>],
segment: Segment,
) -> Result<Self, ProgramError> {
if mem.is_empty() {
if mem.is_empty() || mem.len() % N != 0 {
4l0n50 marked this conversation as resolved.
Show resolved Hide resolved
return Err(ProgramError::ProverInputError(InvalidInput));
}
Ok(Self {
Expand Down
156 changes: 96 additions & 60 deletions evm_arithmetization/src/generation/prover_input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use plonky2::hash::hash_types::RichField;
use serde::{Deserialize, Serialize};

use super::linked_list::{
LinkedList, ACCOUNTS_LINKED_LIST_NODE_SIZE, STORAGE_LINKED_LIST_NODE_SIZE,
LinkedList, LinkedListsPtrs, ACCOUNTS_LINKED_LIST_NODE_SIZE, STORAGE_LINKED_LIST_NODE_SIZE,
};
use super::mpt::load_state_mpt;
use crate::cpu::kernel::cancun_constants::KZG_VERSIONED_HASH;
Expand Down Expand Up @@ -44,7 +44,9 @@ use crate::witness::util::{current_context_peek, stack_peek};
#[derive(PartialEq, Eq, Debug, Clone, Serialize, Deserialize)]
pub struct ProverInputFn(Vec<String>);

#[allow(dead_code)]
pub const ADDRESSES_ACCESS_LIST_LEN: usize = 2;
#[allow(dead_code)]
pub const STORAGE_KEYS_ACCESS_LIST_LEN: usize = 4;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we probably don't wanna keep all these dead code lines

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was using those when debugging. How can I keep it for future use?

Copy link
Collaborator

@Nashtare Nashtare Sep 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm ideally if you want to keep it then we could group all of these unused into some

#[allow(unused)]
pub(crate) mod debugging {

}

but is it really going to be useful in the future? Or was it helpful only for the transition of access list searching?

Copy link
Contributor Author

@4l0n50 4l0n50 Sep 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the end there was a test using ADDRESSES_ACCESS_LIST_LEN


impl From<Vec<String>> for ProverInputFn {
Expand Down Expand Up @@ -327,6 +329,7 @@ impl<F: RichField> GenerationState<F> {
"storage_insert" => self.run_next_storage_insert(),
"address_remove" => self.run_next_addresses_remove(),
"storage_remove" => self.run_next_storage_remove(),
"reset" => self.run_reset(),
_ => Err(ProgramError::ProverInputError(InvalidInput)),
}
}
Expand Down Expand Up @@ -410,77 +413,101 @@ impl<F: RichField> GenerationState<F> {

/// Returns a pointer to an element in the list whose value is such that
/// `value <= addr < next_value` and `addr` is the top of the stack.
fn run_next_addresses_insert(&self) -> Result<U256, ProgramError> {
fn run_next_addresses_insert(&mut self) -> Result<U256, ProgramError> {
let addr = stack_peek(self, 0)?;
if let Some((([_, ptr], _), _)) = self
.get_addresses_access_list()?
.zip(self.get_addresses_access_list()?.skip(1))
.zip(self.get_addresses_access_list()?.skip(2))
.find(|&((_, [prev_addr, _]), [next_addr, _])| {
(prev_addr <= addr || prev_addr == U256::MAX) && addr < next_addr
})
{
Ok(ptr / U256::from(2))
} else {
Ok((Segment::AccessedAddresses as usize).into())

let (&pred_addr, &ptr) = self
.access_lists_ptrs
.accounts
.range(..=addr)
.next_back()
.unwrap_or((&U256::MAX, &(Segment::AccessedAddresses as usize)));

if pred_addr != addr {
self.access_lists_ptrs.accounts.insert(
addr,
u256_to_usize(
self.memory
.read_global_metadata(GlobalMetadata::AccessedAddressesLen),
)?,
);
}
Ok(U256::from(ptr / 2))
einar-polygon marked this conversation as resolved.
Show resolved Hide resolved
}

/// Returns a pointer to an element in the list whose value is such that
/// `value < addr == next_value` and addr is the top of the stack.
/// If the element is not in the list, it loops forever
fn run_next_addresses_remove(&self) -> Result<U256, ProgramError> {
fn run_next_addresses_remove(&mut self) -> Result<U256, ProgramError> {
let addr = stack_peek(self, 0)?;
if let Some(([_, ptr], _)) = self
.get_addresses_access_list()?
.zip(self.get_addresses_access_list()?.skip(2))
.find(|&(_, [next_addr, _])| next_addr == addr)
{
Ok(ptr / U256::from(2))
} else {
Ok((Segment::AccessedAddresses as usize).into())
}

let (_, &ptr) = self
.access_lists_ptrs
.accounts
.range(..addr)
.next_back()
.unwrap_or((&U256::MAX, &(Segment::AccessedAddresses as usize)));
self.access_lists_ptrs
.accounts
.remove(&addr)
.ok_or(ProgramError::ProverInputError(InvalidInput))?;

Ok(U256::from(ptr / 2))
}

/// Returns a pointer to the predecessor of the top of the stack in the
/// accessed storage keys list.
fn run_next_storage_insert(&self) -> Result<U256, ProgramError> {
fn run_next_storage_insert(&mut self) -> Result<U256, ProgramError> {
let addr = stack_peek(self, 0)?;
let key = stack_peek(self, 1)?;
if let Some((([.., ptr], _), _)) = self
.get_storage_keys_access_list()?
.zip(self.get_storage_keys_access_list()?.skip(1))
.zip(self.get_storage_keys_access_list()?.skip(2))
.find(
|&((_, [prev_addr, prev_key, ..]), [next_addr, next_key, ..])| {
let prev_is_less_or_equal = (prev_addr < addr || prev_addr == U256::MAX)
|| (prev_addr == addr && prev_key <= key);
let next_is_strictly_larger =
next_addr > addr || (next_addr == addr && next_key > key);
prev_is_less_or_equal && next_is_strictly_larger
},
)
{
Ok(ptr / U256::from(4))
} else {
Ok((Segment::AccessedStorageKeys as usize).into())

let (&(pred_addr, pred_slot_key), &ptr) = self
.access_lists_ptrs
.storage
.range(..=(addr, key))
.next_back()
.unwrap_or((
&(U256::MAX, U256::zero()),
&(Segment::AccessedStorageKeys as usize),
));
if pred_addr != addr || pred_slot_key != key {
self.access_lists_ptrs.storage.insert(
(addr, key),
u256_to_usize(
self.memory
.read_global_metadata(GlobalMetadata::AccessedStorageKeysLen),
)?,
);
}
Ok(U256::from(ptr / 4))
}

/// Returns a pointer to the predecessor of the top of the stack in the
/// accessed storage keys list.
fn run_next_storage_remove(&self) -> Result<U256, ProgramError> {
fn run_next_storage_remove(&mut self) -> Result<U256, ProgramError> {
let addr = stack_peek(self, 0)?;
let key = stack_peek(self, 1)?;
if let Some(([.., ptr], _)) = self
.get_storage_keys_access_list()?
.zip(self.get_storage_keys_access_list()?.skip(2))
.find(|&(_, [next_addr, next_key, ..])| (next_addr == addr && next_key == key))
{
Ok(ptr / U256::from(4))
} else {
Ok((Segment::AccessedStorageKeys as usize).into())
}

let (_, &ptr) = self
.access_lists_ptrs
.storage
.range(..(addr, key))
.next_back()
.unwrap_or((
&(U256::MAX, U256::zero()),
4l0n50 marked this conversation as resolved.
Show resolved Hide resolved
&(Segment::AccessedStorageKeys as usize),
));
self.access_lists_ptrs
.storage
.remove(&(addr, key))
.ok_or(ProgramError::ProverInputError(InvalidInput))?;

Ok(U256::from(ptr / 4))
}

fn run_reset(&mut self) -> Result<U256, ProgramError> {
self.access_lists_ptrs = LinkedListsPtrs::default();
Ok(U256::zero())
}

/// Returns a pointer to a node in the list such that
Expand All @@ -489,13 +516,14 @@ impl<F: RichField> GenerationState<F> {
let addr = stack_peek(self, 0)?;

let (&pred_addr, &pred_ptr) = self
.accounts_pointers
.state_ptrs
.accounts
.range(..=addr)
.next_back()
.unwrap_or((&U256::MAX, &(Segment::AccountsLinkedList as usize)));

if pred_addr != addr && input_fn.0[1].as_str() == "insert_account" {
self.accounts_pointers.insert(
self.state_ptrs.accounts.insert(
addr,
u256_to_usize(
self.memory
Expand All @@ -516,15 +544,16 @@ impl<F: RichField> GenerationState<F> {
let key = stack_peek(self, 1)?;

let (&(pred_addr, pred_slot_key), &pred_ptr) = self
.storage_pointers
.state_ptrs
.storage
.range(..=(addr, key))
.next_back()
.unwrap_or((
&(U256::MAX, U256::zero()),
&(Segment::StorageLinkedList as usize),
));
if (pred_addr != addr || pred_slot_key != key) && input_fn.0[1] == "insert_slot" {
self.storage_pointers.insert(
self.state_ptrs.storage.insert(
(addr, key),
u256_to_usize(
self.memory
Expand All @@ -544,11 +573,13 @@ impl<F: RichField> GenerationState<F> {
let addr = stack_peek(self, 0)?;

let (_, &ptr) = self
.accounts_pointers
.state_ptrs
.accounts
.range(..addr)
.next_back()
.unwrap_or((&U256::MAX, &(Segment::AccountsLinkedList as usize)));
self.accounts_pointers
self.state_ptrs
.accounts
.remove(&addr)
.ok_or(ProgramError::ProverInputError(InvalidInput))?;

Expand All @@ -564,14 +595,16 @@ impl<F: RichField> GenerationState<F> {
let key = stack_peek(self, 1)?;

let (_, &ptr) = self
.storage_pointers
.state_ptrs
.storage
.range(..(addr, key))
.next_back()
.unwrap_or((
&(U256::MAX, U256::zero()),
&(Segment::StorageLinkedList as usize),
));
self.storage_pointers
self.state_ptrs
.storage
.remove(&(addr, key))
.ok_or(ProgramError::ProverInputError(InvalidInput))?;

Expand All @@ -588,7 +621,8 @@ impl<F: RichField> GenerationState<F> {
let addr = stack_peek(self, 0)?;

let (_, &pred_ptr) = self
.storage_pointers
.state_ptrs
.storage
.range(..(addr, U256::zero()))
.next_back()
.unwrap_or((
Expand Down Expand Up @@ -820,6 +854,7 @@ impl<F: RichField> GenerationState<F> {
}
}

#[allow(dead_code)]
pub(crate) fn get_addresses_access_list(
&self,
) -> Result<LinkedList<ADDRESSES_ACCESS_LIST_LEN>, ProgramError> {
Expand All @@ -832,6 +867,7 @@ impl<F: RichField> GenerationState<F> {
)
}

#[allow(dead_code)]
pub(crate) fn get_storage_keys_access_list(
&self,
) -> Result<LinkedList<STORAGE_KEYS_ACCESS_LIST_LEN>, ProgramError> {
Expand Down
4 changes: 2 additions & 2 deletions evm_arithmetization/src/generation/segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ fn build_segment_data<F: RichField>(
trie_root_ptrs: interpreter.generation_state.trie_root_ptrs.clone(),
jumpdest_table: interpreter.generation_state.jumpdest_table.clone(),
next_txn_index: interpreter.generation_state.next_txn_index,
accounts: interpreter.generation_state.accounts_pointers.clone(),
storage: interpreter.generation_state.storage_pointers.clone(),
access_lists_ptrs: interpreter.generation_state.access_lists_ptrs.clone(),
state_ptrs: interpreter.generation_state.state_ptrs.clone(),
},
}
}
Expand Down
Loading
Loading