Skip to content

Commit

Permalink
fix: handling of dict relationships in args gen (#534)
Browse files Browse the repository at this point in the history
Closes #533
  • Loading branch information
enitrat authored Jan 24, 2025
1 parent 857c004 commit c82e7d7
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 13 deletions.
1 change: 1 addition & 0 deletions cairo/ethereum/cancun/state.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ func close_transaction{
let transient_storage_tries_start = transient_storage_tries.value._data.value.dict_ptr_start;
let transient_storage_tries_end = transient_storage_tries.value._data.value.dict_ptr;
let parent_transient_storage_tries = transient_storage_tries.value._data.value.parent_dict;

with_attr error_message("IndexError") {
tempvar parent_transient_storage_tries_ptr = cast(parent_transient_storage_tries, felt);
if (cast(parent_transient_storage_tries_ptr, felt) == 0) {
Expand Down
20 changes: 9 additions & 11 deletions cairo/tests/ethereum/cancun/test_state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from typing import Optional

import pytest
Expand Down Expand Up @@ -36,7 +37,7 @@
set_transient_storage,
touch_account,
)
from ethereum.cancun.trie import Trie
from ethereum.cancun.trie import Trie, copy_trie
from tests.utils.args_gen import State, TransientStorage, Withdrawal
from tests.utils.errors import strict_raises
from tests.utils.strategies import (
Expand Down Expand Up @@ -91,7 +92,7 @@ def state_with_snapshots(draw):

# Start with base state's tries
current_main_trie = base_state._main_trie
current_storage_tries = base_state._storage_tries.copy()
current_storage_tries = copy.deepcopy(base_state._storage_tries)
snapshots = []

for _ in range(num_snapshots):
Expand All @@ -100,11 +101,8 @@ def state_with_snapshots(draw):
new_accounts = draw(
st.dictionaries(keys=address, values=st.from_type(Account), max_size=5)
)
main_trie_data = current_main_trie._data.copy()
main_trie_data.update(new_accounts)
main_trie = Trie[Address, Optional[Account]](
secured=True, default=None, _data=main_trie_data
)
main_trie_copy = copy_trie(current_main_trie)
main_trie_copy._data.update(new_accounts)

# Add up to 5 new storage tries or update existing ones
new_storage_tries = draw(
Expand All @@ -114,11 +112,11 @@ def state_with_snapshots(draw):
max_size=5,
)
)
storage_tries = current_storage_tries.copy()
storage_tries = copy.deepcopy(current_storage_tries)
storage_tries.update(new_storage_tries)

# Update current state for next iteration
current_main_trie = main_trie
current_main_trie = main_trie_copy
current_storage_tries = storage_tries

return State(
Expand All @@ -139,7 +137,7 @@ def transient_storage_with_snapshots(draw):
num_snapshots = draw(st.integers(min_value=0, max_value=5))

# Start with base transient storage tries
current_tries = base_transient_storage._tries.copy()
current_tries = copy.deepcopy(base_transient_storage._tries)
snapshots = []

for _ in range(num_snapshots):
Expand All @@ -152,7 +150,7 @@ def transient_storage_with_snapshots(draw):
max_size=5,
)
)
tries = current_tries.copy()
tries = copy.deepcopy(current_tries)
tries.update(new_tries)

# Update current tries for next iteration
Expand Down
12 changes: 11 additions & 1 deletion cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ def generate_trie_arg(
# In case of a Trie, we need the dict to be a defaultdict with the trie.default as the default value.
dict_ptr = segments.memory.get(data)
current_ptr = segments.memory.get(data + 1)

if isinstance(dict_manager, DictManager):
dict_manager.trackers[dict_ptr.segment_index].data = defaultdict(
lambda: default, dict_manager.trackers[dict_ptr.segment_index].data
Expand Down Expand Up @@ -947,16 +948,25 @@ def generate_dict_arg(
# This is required for tests where we read data from DictAccess segments while no dict method has been used.
# Equivalent to doing an initial dict_read of all keys.
# We only hash keys if they're in tuples.

# In case of a dict update, we need to get the prev_value from the dict_tracker of the parent_ptr.
# For consistency purposes when we drop the dict and put its prev values back in the parent_ptr.
parent_dict_end_ptr = segments.memory.get(parent_ptr + 1) if parent_ptr else None
initial_data = flatten(
[
(
(poseidon_hash_many(k) if get_args(arg_type)[0] in HASHED_TYPES else k),
v,
(
dict_manager.get_tracker(parent_dict_end_ptr).data.get(k, v)
if parent_dict_end_ptr
else v
),
v,
)
for k, v in data.items()
]
)

segments.load_data(dict_ptr, initial_data)
current_ptr = dict_ptr + len(initial_data)

Expand Down
15 changes: 15 additions & 0 deletions crates/cairo-addons/src/vm/dict_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,21 @@ impl PyDictManager {
Ok(PyTrackerMapping { inner: self.inner.clone() })
}

fn get_tracker(&self, ptr: PyRelocatable) -> PyResult<PyDictTracker> {
self.inner
.borrow()
.trackers
.get(&ptr.inner.segment_index)
.cloned()
.map(|tracker| PyDictTracker { inner: tracker })
.ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!(
"segment_index {} not found",
ptr.inner.segment_index
))
})
}

fn insert(&mut self, segment_index: isize, value: &PyDictTracker) -> PyResult<()> {
if self.inner.borrow().trackers.contains_key(&segment_index) {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
Expand Down
2 changes: 1 addition & 1 deletion python/cairo-addons/src/cairo_addons/hints/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,4 @@ def update_dict_tracker(
ap: RelocatableValue,
):
dict_tracker = dict_manager.get_tracker(ids.current_tracker_ptr)
dict_tracker.current_ptr = ids.new_tracker_ptr
dict_tracker.current_ptr = ids.new_tracker_ptr.address_

0 comments on commit c82e7d7

Please sign in to comment.