Skip to content

Commit

Permalink
Search in linked lists using a BTree (#603)
Browse files Browse the repository at this point in the history
* Search in linked lists using a Btree

* Minor

* Remove unused

* Minor

* Apply suggestions from code review

Co-authored-by: Robin Salen <[email protected]>

* Get accounts and storage segment lengths from metadata

* Remove linked lists len prover input

* Rustfmt

---------

Co-authored-by: Robin Salen <[email protected]>
Co-authored-by: Robin Salen <[email protected]>
  • Loading branch information
3 people authored Sep 13, 2024
1 parent fec2365 commit 3ef67fd
Show file tree
Hide file tree
Showing 9 changed files with 380 additions and 211 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ insert_new_account:
/// Returns 0 if the account was not found or `original_ptr` if it was already present.
global search_account:
// stack: addr_key, retdest
PROVER_INPUT(linked_list::insert_account)
PROVER_INPUT(linked_list::search_account)
// stack: pred_ptr/4, addr_key, retdest
%get_valid_account_ptr
// stack: pred_ptr, addr_key, retdest
Expand Down Expand Up @@ -685,7 +685,7 @@ next_node_ok:
/// Returns `value` if the storage key was inserted, `old_value` if it was already present.
global search_slot:
// stack: addr_key, key, value, retdest
PROVER_INPUT(linked_list::insert_slot)
PROVER_INPUT(linked_list::search_slot)
// stack: pred_ptr/5, addr_key, key, value, retdest
%get_valid_slot_ptr

Expand Down
16 changes: 13 additions & 3 deletions evm_arithmetization/src/cpu/kernel/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//! the future execution and generate nondeterministically the corresponding
//! jumpdest table, before the actual CPU carries on with contract execution.

use std::collections::{BTreeSet, HashMap};
use std::collections::{BTreeMap, BTreeSet, HashMap};

use anyhow::anyhow;
use ethereum_types::{BigEndianHash, U256};
Expand Down Expand Up @@ -115,6 +115,8 @@ pub(crate) struct ExtraSegmentData {
pub(crate) ger_prover_inputs: Vec<U256>,
pub(crate) trie_root_ptrs: TrieRootPtrs,
pub(crate) jumpdest_table: Option<HashMap<usize, Vec<usize>>>,
pub(crate) accounts: BTreeMap<U256, usize>,
pub(crate) storage: BTreeMap<(U256, U256), usize>,
pub(crate) next_txn_index: usize,
}

Expand Down Expand Up @@ -232,8 +234,12 @@ impl<F: RichField> Interpreter<F> {

// Initialize the MPT's pointers.
let (trie_root_ptrs, state_leaves, storage_leaves, trie_data) =
load_linked_lists_and_txn_and_receipt_mpts(&inputs.tries)
.expect("Invalid MPT data for preinitialization");
load_linked_lists_and_txn_and_receipt_mpts(
&mut self.generation_state.accounts_pointers,
&mut self.generation_state.storage_pointers,
&inputs.tries,
)
.expect("Invalid MPT data for preinitialization");

let trie_roots_after = &inputs.trie_roots_after;
self.generation_state.trie_root_ptrs = trie_root_ptrs;
Expand All @@ -253,6 +259,10 @@ impl<F: RichField> Interpreter<F> {
);
self.insert_preinitialized_segment(Segment::StorageLinkedList, preinit_storage_ll_segment);

// Initialize the accounts and storage BTrees.
self.generation_state.insert_all_slots_in_memory();
self.generation_state.insert_all_accounts_in_memory();

// Update the RLP and withdrawal prover inputs.
let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns);
let withdrawal_prover_inputs = all_withdrawals_prover_inputs_reversed(&inputs.withdrawals);
Expand Down
11 changes: 9 additions & 2 deletions evm_arithmetization/src/cpu/kernel/tests/account_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ pub(crate) fn initialize_mpts<F: RichField>(
) {
// Load all MPTs.
let (mut trie_root_ptrs, state_leaves, storage_leaves, trie_data) =
load_linked_lists_and_txn_and_receipt_mpts(trie_inputs)
.expect("Invalid MPT data for preinitialization");
load_linked_lists_and_txn_and_receipt_mpts(
&mut interpreter.generation_state.accounts_pointers,
&mut interpreter.generation_state.storage_pointers,
trie_inputs,
)
.expect("Invalid MPT data for preinitialization");

interpreter.generation_state.memory.contexts[0].segments
[Segment::AccountsLinkedList.unscale()]
Expand All @@ -44,6 +48,9 @@ pub(crate) fn initialize_mpts<F: RichField>(
trie_data.clone();
interpreter.generation_state.trie_root_ptrs = trie_root_ptrs.clone();

interpreter.generation_state.insert_all_slots_in_memory();
interpreter.generation_state.insert_all_accounts_in_memory();

if trie_root_ptrs.state_root_ptr.is_none() {
trie_root_ptrs.state_root_ptr = Some(
load_state_mpt(
Expand Down
76 changes: 56 additions & 20 deletions evm_arithmetization/src/cpu/kernel/tests/mpt/linked_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ use rand::{thread_rng, Rng};
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::cpu::kernel::interpreter::Interpreter;
use crate::generation::linked_list::LinkedList;
use crate::generation::linked_list::AccountsLinkedList;
use crate::generation::linked_list::StorageLinkedList;
use crate::memory::segments::Segment;
use crate::witness::memory::MemoryAddress;
use crate::witness::memory::MemorySegmentState;

fn init_logger() {
let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug"));
Expand Down Expand Up @@ -80,7 +82,8 @@ fn test_list_iterator() -> Result<()> {
.memory
.get_preinit_memory(Segment::AccountsLinkedList);
let mut accounts_list =
LinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList).unwrap();
AccountsLinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList)
.unwrap();

let Some([addr, ptr, ptr_cpy, scaled_pos_1]) = accounts_list.next() else {
return Err(anyhow::Error::msg("Couldn't get value"));
Expand All @@ -102,7 +105,7 @@ fn test_list_iterator() -> Result<()> {
.memory
.get_preinit_memory(Segment::StorageLinkedList);
let mut storage_list =
LinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap();
StorageLinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap();
let Some([addr, key, ptr, ptr_cpy, scaled_pos_1]) = storage_list.next() else {
return Err(anyhow::Error::msg("Couldn't get value"));
};
Expand Down Expand Up @@ -171,8 +174,16 @@ fn test_insert_account() -> Result<()> {
.memory
.get_preinit_memory(Segment::AccountsLinkedList);
let mut list =
LinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList).unwrap();
AccountsLinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList)
.unwrap();

let Some([addr, ptr, ptr_cpy, _]) = list.next() else {
return Err(anyhow::Error::msg("Couldn't get value"));
};
// This is the dummy node
assert_eq!(addr, U256::MAX);
assert_eq!(ptr, U256::zero());
assert_eq!(ptr_cpy, U256::zero());
let Some([addr, ptr, ptr_cpy, scaled_next_pos]) = list.next() else {
return Err(anyhow::Error::msg("Couldn't get value"));
};
Expand Down Expand Up @@ -251,7 +262,16 @@ fn test_insert_storage() -> Result<()> {
.memory
.get_preinit_memory(Segment::StorageLinkedList);
let mut list =
LinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap();
StorageLinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap();

let Some([inserted_addr, inserted_key, ptr, ptr_cpy, _]) = list.next() else {
return Err(anyhow::Error::msg("Couldn't get value"));
};
// This is the dummy node.
assert_eq!(inserted_addr, U256::MAX);
assert_eq!(inserted_key, U256::zero());
assert_eq!(ptr, U256::zero());
assert_eq!(ptr_cpy, U256::zero());

let Some([inserted_addr, inserted_key, ptr, ptr_cpy, scaled_next_pos]) = list.next() else {
return Err(anyhow::Error::msg("Couldn't get value"));
Expand Down Expand Up @@ -292,9 +312,17 @@ fn test_insert_and_delete_accounts() -> Result<()> {
Some((Segment::AccountsLinkedList as usize).into()),
];
let init_len = init_accounts_ll.len();
interpreter.generation_state.memory.contexts[0].segments
[Segment::AccountsLinkedList.unscale()]
.content = init_accounts_ll;

interpreter
.generation_state
.memory
.insert_preinitialized_segment(
Segment::AccountsLinkedList,
MemorySegmentState {
content: init_accounts_ll,
},
);

interpreter.set_global_metadata_field(
GlobalMetadata::AccountsLinkedListNextAvailable,
(Segment::AccountsLinkedList as usize + init_len).into(),
Expand Down Expand Up @@ -433,19 +461,22 @@ fn test_insert_and_delete_accounts() -> Result<()> {
.generation_state
.memory
.get_preinit_memory(Segment::AccountsLinkedList);
let list =
LinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList).unwrap();
let list = AccountsLinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList)
.unwrap();

for (i, [addr, ptr, ptr_cpy, _]) in list.enumerate() {
if addr == U256::MAX {
assert_eq!(addr, U256::MAX);
assert_eq!(ptr, U256::zero());
assert_eq!(ptr_cpy, U256::zero());
break;
if i > 0 {
break;
}
} else {
let addr_in_list = U256::from(new_addresses[i - 1].0.as_slice());
assert_eq!(addr, addr_in_list);
assert_eq!(ptr, addr + delta_ptr);
}
let addr_in_list = U256::from(new_addresses[i].0.as_slice());
assert_eq!(addr, addr_in_list);
assert_eq!(ptr, addr + delta_ptr);
}

Ok(())
Expand Down Expand Up @@ -640,20 +671,25 @@ fn test_insert_and_delete_storage() -> Result<()> {
.generation_state
.memory
.get_preinit_memory(Segment::StorageLinkedList);
let list = LinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap();
let list =
StorageLinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap();

for (i, [addr, key, ptr, ptr_cpy, _]) in list.enumerate() {
if addr == U256::MAX {
assert_eq!(addr, U256::MAX);
assert_eq!(key, U256::zero());
assert_eq!(ptr, U256::zero());
assert_eq!(ptr_cpy, U256::zero());
break;
if i > 0 {
break;
}
} else {
let [addr_in_list, key_in_list] =
new_addresses[i - 1].map(|x| U256::from(x.0.as_slice()));
assert_eq!(addr, addr_in_list);
assert_eq!(key, key_in_list);
assert_eq!(ptr, addr + delta_ptr);
}
let [addr_in_list, key_in_list] = new_addresses[i].map(|x| U256::from(x.0.as_slice()));
assert_eq!(addr, addr_in_list);
assert_eq!(key, key_in_list);
assert_eq!(ptr, addr + delta_ptr);
}

Ok(())
Expand Down
75 changes: 64 additions & 11 deletions evm_arithmetization/src/generation/linked_list.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::fmt;
use std::marker::PhantomData;

use anyhow::Result;
use ethereum_types::U256;
Expand All @@ -8,15 +9,37 @@ use crate::util::u256_to_usize;
use crate::witness::errors::ProgramError;
use crate::witness::errors::ProverInputError::InvalidInput;

/// Number of `U256` cells that make up one node of the accounts linked list
/// (used as the `N` parameter of `AccountsLinkedList`).
pub const ACCOUNTS_LINKED_LIST_NODE_SIZE: usize = 4;
/// Number of `U256` cells that make up one node of the storage linked list
/// (used as the `N` parameter of `StorageLinkedList`).
pub const STORAGE_LINKED_LIST_NODE_SIZE: usize = 5;

/// Marker trait for the type-level tag carried by `LinkedList` (via
/// `PhantomData`) that selects which `Iterator` implementation applies.
pub(crate) trait LinkedListType {}
// NOTE(review): the two descriptions below should be checked against the
// `Iterator` impls: the default (`Cyclic`) iterator appears to yield the node
// at position 0 first, while the `Bounded` iterator stops as soon as it
// reaches a node whose first cell is `U256::MAX` -- confirm the wording.
#[derive(Clone)]
/// A linked list that starts from the first node after the special node and
/// iterates forever.
pub(crate) struct Cyclic;
#[derive(Clone)]
/// A linked list that starts from the special node and iterates until the last
/// node.
pub(crate) struct Bounded;
// Both tags are zero-sized markers; the trait has no methods to implement.
impl LinkedListType for Cyclic {}
impl LinkedListType for Bounded {}

/// Linked list view whose nodes span `ACCOUNTS_LINKED_LIST_NODE_SIZE` cells,
/// using the default list type of `LinkedList`.
pub(crate) type AccountsLinkedList<'a> = LinkedList<'a, ACCOUNTS_LINKED_LIST_NODE_SIZE>;
/// Linked list view whose nodes span `STORAGE_LINKED_LIST_NODE_SIZE` cells,
/// using the default list type of `LinkedList`.
pub(crate) type StorageLinkedList<'a> = LinkedList<'a, STORAGE_LINKED_LIST_NODE_SIZE>;

// A linked list laid out in a flat memory slice `mem`.
// Each node occupies `N` consecutive cells starting at an index `i` that is a
// multiple of `N` (i = N * j): the cells `mem[i..i + N - 1]` hold the node's
// values, and `mem[i + N - 1]` holds the address of the next node.
#[derive(Clone)]
pub(crate) struct LinkedList<'a, const N: usize> {
pub(crate) struct LinkedList<'a, const N: usize, T = Cyclic>
where
T: LinkedListType,
{
mem: &'a [Option<U256>],
offset: usize,
pos: usize,
_marker: PhantomData<T>,
}

pub(crate) fn empty_list_mem<const N: usize>(segment: Segment) -> [Option<U256>; N] {
Expand All @@ -31,15 +54,15 @@ pub(crate) fn empty_list_mem<const N: usize>(segment: Segment) -> [Option<U256>;
})
}

impl<'a, const N: usize> LinkedList<'a, N> {
pub const fn from_mem_and_segment(
impl<'a, const N: usize, T: LinkedListType> LinkedList<'a, N, T> {
pub fn from_mem_and_segment(
mem: &'a [Option<U256>],
segment: Segment,
) -> Result<Self, ProgramError> {
Self::from_mem_len_and_segment(mem, segment)
}

pub const fn from_mem_len_and_segment(
pub fn from_mem_len_and_segment(
mem: &'a [Option<U256>],
segment: Segment,
) -> Result<Self, ProgramError> {
Expand All @@ -50,6 +73,7 @@ impl<'a, const N: usize> LinkedList<'a, N> {
mem,
offset: segment as usize,
pos: 0,
_marker: PhantomData,
})
}
}
Expand All @@ -58,9 +82,8 @@ impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "Linked List {{")?;
let cloned_list = self.clone();
for node in cloned_list {
if node[0] == U256::MAX {
writeln!(f, "{:?}", node)?;
for (i, node) in cloned_list.enumerate() {
if i > 0 && node[0] == U256::MAX {
break;
}
writeln!(f, "{:?} ->", node)?;
Expand All @@ -69,17 +92,47 @@ impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N> {
}
}

/// Debug rendering of a bounded list: an opening `Linked List {` line, then
/// one `<node> ->` line per node yielded by the iterator, then a closing `}`.
impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N, Bounded> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Linked List {{")?;
        // Iterating advances the list's position, so walk a copy and leave
        // `self` untouched; the first write error aborts the traversal.
        self.clone()
            .try_for_each(|entry| writeln!(f, "{:?} ->", entry))?;
        write!(f, "}}")
    }
}

impl<'a, const N: usize> Iterator for LinkedList<'a, N> {
type Item = [U256; N];

fn next(&mut self) -> Option<Self::Item> {
// The first node is always the special node, so we skip it in the first
// iteration.
let node = Some(std::array::from_fn(|i| {
self.mem[self.pos + i].unwrap_or_default()
}));
if let Ok(new_pos) = u256_to_usize(self.mem[self.pos + N - 1].unwrap_or_default()) {
self.pos = new_pos - self.offset;
Some(std::array::from_fn(|i| {
node
} else {
None
}
}
}

impl<'a, const N: usize> Iterator for LinkedList<'a, N, Bounded> {
type Item = [U256; N];

fn next(&mut self) -> Option<Self::Item> {
if self.mem[self.pos] != Some(U256::MAX) {
let node = Some(std::array::from_fn(|i| {
self.mem[self.pos + i].unwrap_or_default()
}))
}));
if let Ok(new_pos) = u256_to_usize(self.mem[self.pos + N - 1].unwrap_or_default()) {
self.pos = new_pos - self.offset;
node
} else {
None
}
} else {
None
}
Expand Down
Loading

0 comments on commit 3ef67fd

Please sign in to comment.