Skip to content

Commit

Permalink
fix unwraps
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancoGiachetta committed Nov 15, 2024
1 parent b7bfc6e commit 13e00ff
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 44 deletions.
113 changes: 69 additions & 44 deletions vm/src/hint_processor/cairo_1_hint_processor/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Most of the `EvalCircuit` implementation is derived from the `cairo-lang-runner` crate.
// https://github.com/starkware-libs/cairo/blob/main/crates/cairo-lang-runner/src/casm_run/circuit.rs

use core::{array, ops::Deref};
use core::ops::Deref;

use ark_ff::{One, Zero};
use num_bigint::{BigInt, BigUint, ToBigInt};
Expand All @@ -12,7 +12,10 @@ use starknet_types_core::felt::Felt;
use crate::{
stdlib::boxed::Box,
types::relocatable::{MaybeRelocatable, Relocatable},
vm::{errors::hint_errors::HintError, vm_core::VirtualMachine},
vm::{
errors::{hint_errors::HintError, memory_errors::MemoryError},
vm_core::VirtualMachine,
},
};

// A gate is defined by 3 offsets, the first two are the inputs and the third is the output.
Expand All @@ -31,79 +34,97 @@ struct Circuit<'a> {
}

impl Circuit<'_> {
fn read_add_mod_value(&mut self, offset: usize) -> Option<BigUint> {
self.read_circuit_value((self.add_mod_offsets + offset).unwrap())
fn read_add_mod_value(&mut self, offset: usize) -> Result<Option<BigUint>, MemoryError> {
self.read_circuit_value((self.add_mod_offsets + offset)?)
}

fn read_mul_mod_value(&mut self, offset: usize) -> Option<BigUint> {
self.read_circuit_value((self.mul_mod_offsets + offset).unwrap())
fn read_mul_mod_value(&mut self, offset: usize) -> Result<Option<BigUint>, MemoryError> {
self.read_circuit_value((self.mul_mod_offsets + offset)?)
}

fn read_circuit_value(&mut self, offset: Relocatable) -> Option<BigUint> {
let value_ptr = self.get_value_ptr(offset);
read_circuit_value(self.vm, value_ptr)
fn read_circuit_value(&mut self, offset: Relocatable) -> Result<Option<BigUint>, MemoryError> {
let value_ptr = self.get_value_ptr(offset)?;
Ok(read_circuit_value(self.vm, value_ptr)?)
}

fn write_add_mod_value(&mut self, offset: usize, value: BigUint) {
self.write_circuit_value((self.add_mod_offsets + offset).unwrap(), value);
fn write_add_mod_value(&mut self, offset: usize, value: BigUint) -> Result<(), MemoryError> {
self.write_circuit_value((self.add_mod_offsets + offset)?, value)?;

Ok(())
}

fn write_mul_mod_value(&mut self, offset: usize, value: BigUint) {
self.write_circuit_value((self.mul_mod_offsets + offset).unwrap(), value);
fn write_mul_mod_value(&mut self, offset: usize, value: BigUint) -> Result<(), MemoryError> {
self.write_circuit_value((self.mul_mod_offsets + offset)?, value)?;

Ok(())
}

fn write_circuit_value(&mut self, offset: Relocatable, value: BigUint) {
let value_ptr = self.get_value_ptr(offset);
write_circuit_value(self.vm, value_ptr, value);
fn write_circuit_value(
&mut self,
offset: Relocatable,
value: BigUint,
) -> Result<(), MemoryError> {
let value_ptr = self.get_value_ptr(offset)?;
write_circuit_value(self.vm, value_ptr, value)?;

Ok(())
}

fn get_value_ptr(&self, address: Relocatable) -> Relocatable {
(self.values_ptr + self.vm.get_integer(address).unwrap().as_ref()).unwrap()
fn get_value_ptr(&self, address: Relocatable) -> Result<Relocatable, MemoryError> {
(self.values_ptr + self.vm.get_integer(address)?.as_ref()).map_err(|e| MemoryError::Math(e))
}
}

fn read_circuit_value(vm: &mut VirtualMachine, add: Relocatable) -> Option<BigUint> {
fn read_circuit_value(
vm: &mut VirtualMachine,
add: Relocatable,
) -> Result<Option<BigUint>, MemoryError> {
let mut res = BigUint::zero();

for l in (0..LIMBS_COUNT).rev() {
let add_l = (add + l).unwrap();
let add_l = (add + l)?;
match vm.get_maybe(&add_l) {
Some(MaybeRelocatable::Int(limb)) => res = (res << 96) + limb.to_biguint(),
_ => return None,
_ => return Ok(None),
}
}

Some(res)
Ok(Some(res))
}

fn write_circuit_value(vm: &mut VirtualMachine, add: Relocatable, mut value: BigUint) {
fn write_circuit_value(
vm: &mut VirtualMachine,
add: Relocatable,
mut value: BigUint,
) -> Result<(), MemoryError> {
for l in 0..LIMBS_COUNT {
// get the nth limb from a circuit value
let (new_value, rem) = value.div_rem(&(BigUint::one() << 96u8));
vm.insert_value((add + l).unwrap(), Felt::from(rem))
.unwrap();
vm.insert_value((add + l)?, Felt::from(rem))?;
value = new_value;
}

Ok(())
}

// Finds the inverse of a value.
//
// If the value has no inverse, find a nullifier so that:
// value * nullifier = 0 (mod modulus)
fn find_inverse(value: BigUint, modulus: &BigUint) -> (bool, BigUint) {
fn find_inverse(value: BigUint, modulus: &BigUint) -> Result<(bool, BigUint), HintError> {
let ex_gcd = value
.to_bigint()
.unwrap()
.extended_gcd(&modulus.to_bigint().unwrap());
.ok_or(HintError::BigUintToBigIntFail)?
.extended_gcd(&modulus.to_bigint().ok_or(HintError::BigUintToBigIntFail)?);

let gcd = ex_gcd.gcd.to_biguint().unwrap();
if gcd.is_one() {
return (true, get_modulus(&ex_gcd.x, modulus));
return Ok((true, get_modulus(&ex_gcd.x, modulus)));
}

let nullifier = modulus / gcd;

(false, nullifier)
Ok((false, nullifier))
}

fn get_modulus(value: &BigInt, modulus: &BigUint) -> BigUint {
Expand All @@ -124,7 +145,7 @@ fn compute_gates(
n_mul_mods: usize,
modulus_ptr: Relocatable,
) -> Result<usize, HintError> {
let modulus = read_circuit_value(vm, modulus_ptr).unwrap();
let modulus = read_circuit_value(vm, modulus_ptr)?.unwrap();
let mut circuit = Circuit {
vm,
values_ptr,
Expand All @@ -141,21 +162,21 @@ fn compute_gates(

loop {
while addmod_idx < n_add_mods {
let lhs = circuit.read_add_mod_value(3 * addmod_idx);
let rhs = circuit.read_add_mod_value(3 * addmod_idx + 1);
let lhs = circuit.read_add_mod_value(3 * addmod_idx)?;
let rhs = circuit.read_add_mod_value(3 * addmod_idx + 1)?;

match (lhs, rhs) {
(Some(l), Some(r)) => {
let res = (l + r) % &circuit.modulus;
circuit.write_add_mod_value(3 * addmod_idx + 2, res);
circuit.write_add_mod_value(3 * addmod_idx + 2, res)?;
}
// sub gate: lhs = res - rhs
(None, Some(r)) => {
let Some(res) = circuit.read_add_mod_value(3 * addmod_idx + 2) else {
let Some(res) = circuit.read_add_mod_value(3 * addmod_idx + 2)? else {
break;
};
let value = (res + &circuit.modulus - r) % &circuit.modulus;
circuit.write_add_mod_value(3 * addmod_idx, value);
circuit.write_add_mod_value(3 * addmod_idx, value)?;
}
_ => break,
}
Expand All @@ -167,18 +188,18 @@ fn compute_gates(
break;
}

let lhs = circuit.read_mul_mod_value(3 * mulmod_idx);
let rhs = circuit.read_mul_mod_value(3 * mulmod_idx + 1);
let lhs = circuit.read_mul_mod_value(3 * mulmod_idx)?;
let rhs = circuit.read_mul_mod_value(3 * mulmod_idx + 1)?;

match (lhs, rhs) {
(Some(l), Some(r)) => {
let res = (l * r) % &circuit.modulus;
circuit.write_mul_mod_value(3 * mulmod_idx + 2, res);
circuit.write_mul_mod_value(3 * mulmod_idx + 2, res)?;
}
// inverse gate: lhs = 1 / rhs
(None, Some(r)) => {
let (success, res) = find_inverse(r, &circuit.modulus);
circuit.write_mul_mod_value(3 * mulmod_idx, res);
let (success, res) = find_inverse(r, &circuit.modulus)?;
circuit.write_mul_mod_value(3 * mulmod_idx, res)?;

if !success {
first_failure_idx = mulmod_idx;
Expand Down Expand Up @@ -209,7 +230,7 @@ fn fill_instances(
mut offsets_ptr: Relocatable,
) -> Result<(), HintError> {
for i in 0..n_instances {
let instance_ptr = (built_ptr + i * MOD_BUILTIN_INSTACE_SIZE).unwrap();
let instance_ptr = (built_ptr + i * MOD_BUILTIN_INSTACE_SIZE)?;

for (idx, value) in modulus.iter().enumerate() {
vm.insert_value((instance_ptr + idx)?, *value)?;
Expand Down Expand Up @@ -256,8 +277,12 @@ pub fn eval_circuit(
modulus_ptr,
)?;

let modulus: [Felt; 4] =
array::from_fn(|l| *vm.get_integer((modulus_ptr + l).unwrap()).unwrap().deref());
let modulus: [Felt; 4] = [
*vm.get_integer(modulus_ptr)?.deref(),
*vm.get_integer((modulus_ptr + 1)?)?.deref(),
*vm.get_integer((modulus_ptr + 2)?)?.deref(),
*vm.get_integer((modulus_ptr + 3)?)?.deref(),
];

fill_instances(
vm,
Expand Down
2 changes: 2 additions & 0 deletions vm/src/vm/errors/hint_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ pub enum HintError {
BigintToU32Fail,
#[error("BigInt to BigUint failed, BigInt is negative")]
BigIntToBigUintFail,
#[error("BigUint to BigInt failed")]
BigUintToBigIntFail,
#[error("Assertion failed, 0 <= ids.a % PRIME < range_check_builtin.bound \n a = {0} is out of range")]
ValueOutOfRange(Box<Felt252>),
#[error("Assertion failed, 0 <= ids.a % PRIME < range_check_builtin.bound \n a = {0} is out of range")]
Expand Down

0 comments on commit 13e00ff

Please sign in to comment.