Skip to content

Commit

Permalink
Improve DescentMap typing
Browse files Browse the repository at this point in the history
  • Loading branch information
notlesh committed Apr 13, 2024
1 parent 05d943e commit 0387988
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 19 deletions.
8 changes: 6 additions & 2 deletions src/hints/patricia.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::cairo_types::trie::NodeEdge;
use crate::hints::types::{skip_verification_if_configured, Preimage};
use crate::hints::vars;
use crate::starknet::starknet_storage::StorageLeaf;
use crate::starkware_utils::commitment_tree::base_types::{DescentMap, Height};
use crate::starkware_utils::commitment_tree::base_types::{DescentMap, DescentStart, Height, NodePath};
use crate::starkware_utils::commitment_tree::patricia_tree::patricia_guess_descents::patricia_guess_descents;
use crate::starkware_utils::commitment_tree::update_tree::{
build_update_tree, decode_node, DecodeNodeCase, DecodedNode, TreeUpdate,
Expand Down Expand Up @@ -114,7 +114,11 @@ pub fn set_ap_to_descend(
let height = get_integer_from_var_name(vars::ids::HEIGHT, vm, ids_data, ap_tracking)?;
let path = get_integer_from_var_name(vars::ids::PATH, vm, ids_data, ap_tracking)?;

let ap = match descent_map.get(&(height, path)) {
let height = height.try_into()?;
let path = NodePath(path.to_biguint());

let descent_start = DescentStart(height, path);
let ap = match descent_map.get(&descent_start) {
None => Felt252::ZERO,
Some(value) => {
exec_scopes.insert_value(vars::ids::DESCEND, value.clone());
Expand Down
12 changes: 8 additions & 4 deletions src/starkware_utils/commitment_tree/base_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::storage::storage::HASH_BYTES;

pub type TreeIndex = BigUint;

#[derive(Debug, Clone, PartialEq, Default)]
#[derive(Debug, Clone, PartialEq, Default, Eq, Hash)]
pub struct NodePath(pub BigUint);

impl Display for NodePath {
Expand All @@ -34,7 +34,7 @@ impl Serializable for NodePath {
}
}

#[derive(Debug, Copy, Clone, PartialEq, Default)]
#[derive(Debug, Copy, Clone, PartialEq, Default, Eq)]
pub struct Length(pub u64);

impl Sub<u64> for Length {
Expand Down Expand Up @@ -64,7 +64,7 @@ impl Serializable for Length {
}
}

#[derive(Debug, Copy, Clone, PartialEq, Default)]
#[derive(Debug, Copy, Clone, PartialEq, Default, Eq, Hash)]
pub struct Height(pub u64);

impl TryFrom<Felt252> for Height {
Expand All @@ -90,4 +90,8 @@ impl Display for Height {
}
}

pub type DescentMap = HashMap<(Felt252, Felt252), Vec<Felt252>>;
#[derive(Debug, Clone, PartialEq, Default, Eq, Hash)]
pub struct DescentStart(pub Height, pub NodePath);
#[derive(Debug, Clone, PartialEq, Default, Eq)]
pub struct DescentPath(pub Length, pub NodePath);
pub type DescentMap = HashMap<DescentStart, DescentPath>;
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use cairo_vm::vm::errors::hint_errors::HintError;
use cairo_vm::Felt252;
use num_bigint::BigUint;
use num_traits::ToPrimitive;
use std::ops::{Add, Mul};

use crate::starkware_utils::commitment_tree::base_types::{DescentMap, Height, NodePath};
use crate::starkware_utils::commitment_tree::base_types::{DescentMap, DescentPath, DescentStart, Height, Length, NodePath};
use crate::starkware_utils::commitment_tree::update_tree::{TreeUpdate, UpdateTree};

type Preimage = HashMap<Felt252, Vec<Felt252>>;
Expand Down Expand Up @@ -246,15 +247,15 @@ where
// length <= 1 is not a descent.
if length > 1 {
descent_map.insert(
(Felt252::from(orig_height.0), Felt252::from(orig_path.0)),
vec![Felt252::from(length), Felt252::from(&path.0 % (BigUint::from(1u64) << length))],
DescentStart(orig_height, orig_path),
DescentPath(Length(length), NodePath(path.0.clone() % (BigUint::from(1u64) << length))),
);
}

if height.0 > 0 {
let next_height = Height(height.0 - 1);
descent_map.extend(get_descents(next_height, NodePath(&path.0 * 2u64), lefts.0, lefts.1, lefts.2)?);
descent_map.extend(get_descents(next_height, NodePath(path.0 * 2u64 + 1u64), rights.0, rights.1, rights.2)?);
descent_map.extend(get_descents(next_height, NodePath(path.0.clone().mul(2u64)), lefts.0, lefts.1, lefts.2)?);
descent_map.extend(get_descents(next_height, NodePath(path.0.mul(2u64).add(1u64)), rights.0, rights.1, rights.2)?);
}

Ok(descent_map)
Expand Down Expand Up @@ -329,10 +330,11 @@ mod tests {
fn print_descent_map(descent_map: &DescentMap) {
for (key, value) in descent_map {
println!(
"{}-{}: {:?}",
key.0.to_biguint(),
key.1.to_biguint(),
value.iter().map(|x| x.to_biguint()).collect::<Vec<_>>()
"{}-{}: {}-{}",
key.0,
key.1,
value.0,
value.1,
)
}
}
Expand Down Expand Up @@ -390,7 +392,7 @@ mod tests {
print_descent_map(&descent_map);
assert_eq!(
descent_map,
DescentMap::from([((Felt252::from(3), Felt252::from(0)), vec![Felt252::from(3), Felt252::from(1)])]),
DescentMap::from([(DescentStart(Height(3), NodePath(0usize.into())), DescentPath(Length(3), NodePath(1usize.into())))]),
);
}

Expand Down Expand Up @@ -428,7 +430,7 @@ mod tests {
print_descent_map(&descent_map);
assert_eq!(
descent_map,
DescentMap::from([((Felt252::from(3), Felt252::from(0)), vec![Felt252::from(2), Felt252::from(0)])]),
DescentMap::from([(DescentStart(Height(3), NodePath(0usize.into())), DescentPath(Length(2), NodePath(0usize.into())))]),
);
}

Expand Down Expand Up @@ -467,8 +469,8 @@ mod tests {
assert_eq!(
descent_map,
DescentMap::from([
((Felt252::from(2), Felt252::from(0)), vec![Felt252::from(2), Felt252::from(1)]),
((Felt252::from(2), Felt252::from(1)), vec![Felt252::from(2), Felt252::from(0)]),
(DescentStart(Height(2), NodePath(0usize.into())), DescentPath(Length(2), NodePath(1usize.into()))),
(DescentStart(Height(2), NodePath(1usize.into())), DescentPath(Length(2), NodePath(0usize.into()))),
]),
);
}
Expand Down

0 comments on commit 0387988

Please sign in to comment.