Skip to content

Commit

Permalink
static lookup and consts fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rebenkoy committed Nov 19, 2023
1 parent 96d7d70 commit 2fc2072
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 66 deletions.
14 changes: 7 additions & 7 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ where
output
}

fn apply_internal(&mut self, visibility: Visibility, round : usize, polyop: PolyOp<'circuit, F>, input: Vec<Variable>, constants: &'circuit [F]) -> Vec<Variable> {
fn apply_internal(&mut self, visibility: Visibility, round : usize, polyop: PolyOp<'circuit, F>, input: Vec<Variable>, constants: &[F]) -> Vec<Variable> {
assert!(round < self.ops.len(), "The round is too large.");

let op_index = self.ops[round].len();
Expand All @@ -275,21 +275,21 @@ where
output
}

pub fn apply(&mut self, round: usize, polyop: PolyOp<'circuit, F>, input: Vec<Variable>, constants: &'circuit[F]) -> Vec<Variable> {
pub fn apply(&mut self, round: usize, polyop: PolyOp<'circuit, F>, input: Vec<Variable>, constants: &[F]) -> Vec<Variable> {
self.apply_internal(Visibility::Private, round, polyop, input, constants)
}

pub fn apply_pub(&mut self, round : usize, polyop: PolyOp<'circuit, F>, input: Vec<Variable>, constants: &'circuit[F]) -> Vec<Variable> {
pub fn apply_pub(&mut self, round : usize, polyop: PolyOp<'circuit, F>, input: Vec<Variable>, constants: &[F]) -> Vec<Variable> {
self.apply_internal(Visibility::Public, round, polyop, input, constants)
}

// TODO: pass input by value since we clone it down the stack either way
pub fn constrain(&mut self, input: &[Variable], constants: &'circuit[F], gate: G) {
pub fn constrain(&mut self, input: &[Variable], constants: &[F], gate: G) {
println!("Using legacy unnamed constrains");
self._constrain(&input, &constants, gate)
}

fn _constrain(&mut self, input: &[Variable], constants: &'circuit[F], gate: G) {
fn _constrain(&mut self, input: &[Variable], constants: &[F], gate: G) {
assert!(gate.d() > 0, "Trying to constrain with gate of degree 0.");

let kind = if gate.d() == 1 { CommitKind::Zero } else { CommitKind::Group };
Expand All @@ -299,7 +299,7 @@ where
pub fn constrain_with(
&mut self,
input: &[Variable],
constants: &'circuit[F],
constants: &[F],
gate_fetcher: &dyn Fn(&FrozenMap<String, Box<G>>) -> G
) {
let gate = gate_fetcher(&self.gate_registry);
Expand Down Expand Up @@ -363,7 +363,7 @@ where
pub fn valid_witness(&self) -> () {
for constr in self.circuit.cs.iter_constraints() {
let input_values: Vec<_> = constr.inputs.iter().map(|&x| self.cs.getvar(x)).collect();
let result = constr.gate.exec(&input_values, &[]);
let result = constr.gate.exec(&input_values, &constr.constants);

assert!(result.iter().all(|&output| output == F::ZERO), "Constraint {:?} is not satisfied", constr);
}
Expand Down
100 changes: 45 additions & 55 deletions src/gadgets/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,11 @@ pub fn sum_of_fractions<'a, F:PrimeField+FieldUtils> (args: &[F], k: usize) -> F
}

/// Parses input as `a, c, vals[0], ... vals[k-1], nums[0], ... nums[k-1]` and returns F::ZERO if a == \sum nums[i]/(vals[i]-c) or one of denominators is F::ZERO itself
pub fn sum_of_fractions_with_nums<'a, F:PrimeField+FieldUtils> (args: &[F], k: usize) -> F {
let (tmp1, tmp2) = args.split_at(2);
assert!(tmp2.len() == 2 * k);
let (vals, nums) = tmp2.split_at(k);
pub fn sum_of_fractions_with_nums<'a, F:PrimeField+FieldUtils> (args: &[F], dens: &[F], k: usize) -> F {
let (tmp1, nums) = args.split_at(2);
assert!(nums.len() == k);
let (res, c) = (tmp1[0], tmp1[1]);
let (prod, skips) = montgomery(& vals.iter().map(|t| *t - c).collect_vec());
let (prod, skips) = montgomery(& dens.iter().map(|t| *t - c).collect_vec());
res * prod - skips.iter().zip_eq(nums.iter()).fold(F::ZERO, |acc, (skip, num)| acc + *skip * num)
}

Expand All @@ -87,15 +86,15 @@ pub fn invsum_flat_constrain<'a, F: PrimeField+FieldUtils>(
pub fn fracsum_flat_constrain<'a, F: PrimeField+FieldUtils>(
circuit: &mut Circuit<'a, F, Gatebb<'a, F>>,
nums: &[Variable],
dens: &[Variable],
dens: &'a[F],
res: Variable,
challenge: Variable,
) -> () {
assert!(dens.len()==nums.len());
let args = [res, challenge].iter().chain(dens.iter()).chain(nums.iter()).map(|x|*x).collect_vec();
let args = [res, challenge].iter().chain(nums.iter()).map(|x|*x).collect_vec();
let k = dens.len();
let gate = Gatebb::new(dens.len()+1, args.len(), 1, Rc::new(move |args, _|vec![sum_of_fractions_with_nums(args, k)]));
circuit.constrain(&args, &[], gate);
let gate = Gatebb::new(dens.len()+1, args.len(), 1, Rc::new(move |args, _|vec![sum_of_fractions_with_nums(args, &dens, k)]));
circuit.constrain(&args, &dens, gate);
}

/// Gadget which returns the sum of inverses of an array, shifted by a challenge.
Expand Down Expand Up @@ -149,7 +148,7 @@ pub fn invsum_gadget<'a, F: PrimeField+FieldUtils>(
pub fn fracsum_gadget<'a, F: PrimeField+FieldUtils>(
circuit: &mut Circuit<'a, F, Gatebb<'a, F>>,
nums: &[Variable],
dens: &[Variable],
dens: &'a [F],
challenge: Variable,
rate: usize,
round: usize,
Expand All @@ -162,10 +161,9 @@ pub fn fracsum_gadget<'a, F: PrimeField+FieldUtils>(
let mut dens = dens;
let mut num_chunk;
let mut den_chunk;
let advice = Advice::new(2*l+1, 0, l/rate, move |args: &[F], _|{
let (tmp, c) = args.split_at(2*l);
let advice = Advice::new(l+1, 0, l/rate, move |args: &[F], _|{
let (nums, c) = args.split_at(l);
let c = c[0];
let (nums, dens) = tmp.split_at(l);
let mut inv = dens.iter().map(|x|*x-c).collect_vec();
inv.batch_invert();
let mut ret = vec![];
Expand All @@ -181,7 +179,7 @@ pub fn fracsum_gadget<'a, F: PrimeField+FieldUtils>(
ret
});

let args = nums.iter().chain(dens.iter()).map(|x|*x).chain(once(challenge)).collect();
let args = nums.iter().map(|x|*x).chain(once(challenge)).collect();

let batches = circuit.advice(round, advice, args, vec![]);
for i in 0..l/rate {
Expand All @@ -193,13 +191,13 @@ pub fn fracsum_gadget<'a, F: PrimeField+FieldUtils>(
sum_gadget(circuit, &batches, round)
}
///
pub trait Lookup<F: PrimeField+FieldUtils> {
pub trait Lookup<'a, F: PrimeField+FieldUtils> {
/// Adds the variable to the list of variables to look up.
fn check<'a>(&mut self, circuit: &mut Circuit<'a, F, Gatebb<'a,F>>, var: Variable) -> ();
fn check(&mut self, circuit: &mut Circuit<'a, F, Gatebb<'a,F>>, var: Variable) -> ();
/// Seals the lookup and applies the constraints. Returns the challenge.
/// Round parameter is the round of a challenge - so it must be strictly larger than rounds of any
/// variable participating in a lookup.
fn finalize<'a>(
fn finalize(
self,
circuit: &mut Circuit<'a, F, Gatebb<'a,F>>,
table_round: usize,
Expand All @@ -209,57 +207,52 @@ pub trait Lookup<F: PrimeField+FieldUtils> {
) -> ();
}

pub struct RangeLookup<F: PrimeField+FieldUtils> {
pub struct StaticLookup<'c, F: PrimeField+FieldUtils> {
vars: Vec<Variable>,
round: usize,
challenge: ExternalValue<F>,
range: usize,
table: &'c [F],
}

impl<F: PrimeField+FieldUtils> RangeLookup<F> {
pub fn new(challenge_src: ExternalValue<F>, range: usize) -> Self {
impl<'c, F: PrimeField+FieldUtils> StaticLookup<'c, F> {
pub fn new(challenge_src: ExternalValue<F>, table: &'c [F]) -> Self {

Self{
vars: vec![],
round: 0,
challenge: challenge_src,
range,
table,
}
}
}

impl<F: PrimeField+FieldUtils> Lookup<F> for RangeLookup<F> {
fn check<'a>(&mut self, _circuit: &mut Circuit<'a, F, Gatebb<'a,F>>, var: Variable) -> () {
impl<'a, 'c: 'a, F: PrimeField+FieldUtils> Lookup<'a, F> for StaticLookup<'c, F> {
fn check(&mut self, _circuit: &mut Circuit<'a, F, Gatebb<'a,F>>, var: Variable) -> () {
if self.round < var.round {
self.round = var.round
}
self.vars.push(var);
}
fn finalize<'a>(
fn finalize(
self,
circuit: &mut Circuit<'a, F, Gatebb<'a,F>>,
table_round: usize,
access_round: usize,
challenge_round: usize,
rate: usize,
) -> () {
let Self{vars, round, challenge, range} = self;
let Self{vars, round, challenge, table} = self;

assert!(table_round <= access_round);
assert!(access_round >= round);
assert!(challenge_round > access_round);

// Table of values 0, 1, ..., range-1
let read_table = Advice::new(0, 0, range, move |_:&[F], _| {
(0..range).map(|i|F::from(i as u64)).collect()
});
let table = circuit.advice(table_round, read_table, vec![], vec![]);
// Access counts.
let compute_accesses = Advice::new(vars.len(), 0, range, move |vars: &[F], _|{
let mut ret = vec![0; range];
let compute_accesses = Advice::new(vars.len(), 0, table.len(), move |vars: &[F], _|{
let mut ret = vec![0; table.len()];
for var in vars{
let var = BigUint::from_bytes_le(var.to_repr().as_ref());
assert!(var < range.into(), "Error: lookup value out of range.");
assert!(var < table.len().into(), "Error: lookup value out of range.");
let u64_digits = var.to_u64_digits();
let mut i = 0 as usize;
if u64_digits.len() > 0 {
Expand All @@ -274,7 +267,7 @@ impl<F: PrimeField+FieldUtils> Lookup<F> for RangeLookup<F> {
let challenge = input(circuit, challenge, challenge_round);

let lhs = invsum_gadget(circuit, &vars, challenge, rate, challenge_round);
let rhs = fracsum_gadget(circuit, &access_counts, &table, challenge, rate, challenge_round);
let rhs = fracsum_gadget(circuit, &access_counts, table, challenge, rate, challenge_round);

eq_gadget(circuit, lhs, rhs);
}
Expand Down Expand Up @@ -394,21 +387,21 @@ mod test {
#[should_panic]
fn no_args() {
type F = bn256::Fr;
sum_of_fractions_with_nums::<F>(&[], 0);
sum_of_fractions_with_nums::<F>(&[], &[], 0);
}

#[test]
#[should_panic]
fn small_k() {
type F = bn256::Fr;
sum_of_fractions_with_nums::<F>(&[F::ONE, F::ONE, F::ONE, F::ONE, F::ONE], 1);
sum_of_fractions_with_nums::<F>(&[F::ONE, F::ONE], &[F::ONE, F::ONE, F::ONE], 1);
}

#[test]
#[should_panic]
fn big_k() {
type F = bn256::Fr;
sum_of_fractions_with_nums::<F>(&[F::ONE, F::ONE, F::ONE, F::ONE, F::ONE, F::ONE, F::ONE], 3);
sum_of_fractions_with_nums::<F>(&[F::ONE, F::ONE], &[F::ONE, F::ONE, F::ONE, F::ONE, F::ONE], 3);
}
}

Expand All @@ -417,7 +410,7 @@ mod test {
type F = bn256::Fr;
let c = F::random(OsRng);

assert_eq!(sum_of_fractions_with_nums::<F>(&[F::ZERO, c], 0), F::ZERO);
assert_eq!(sum_of_fractions_with_nums::<F>(&[F::ZERO, c], &[], 0), F::ZERO);
}

#[test]
Expand All @@ -431,9 +424,8 @@ mod test {
let sum = points.iter().zip_eq(&numerators).map(|(p, n)| (p - c).invert().unwrap() * n).fold(F::ZERO, |acc, n| acc + n);

let mut inputs = vec![sum, c];
inputs.extend(points.iter());
inputs.extend(numerators.iter());
assert_eq!(sum_of_fractions_with_nums::<F>(&inputs, indexes.len()), F::ZERO);
assert_eq!(sum_of_fractions_with_nums::<F>(&inputs, &points, indexes.len()), F::ZERO);
}

#[test]
Expand All @@ -445,9 +437,8 @@ mod test {
let points = indexes.clone().map(|_| F::random(OsRng)).collect_vec();
let numerators = indexes.clone().map(|_| F::random(OsRng)).collect_vec();
let mut inputs = vec![fake_sum, c];
inputs.extend(points.iter());
inputs.extend(numerators.iter());
let res = sum_of_fractions_with_nums::<F>(&inputs, indexes.len());
let res = sum_of_fractions_with_nums::<F>(&inputs, &points, indexes.len());
let padded = points.iter().map(|p| (p - c)).collect_vec();
let real_sum = padded.iter().zip_eq(&numerators).map(|(p, n)| p.invert().unwrap() * n).fold(F::ZERO, |acc, n| acc + n);
let real_denominator = padded.iter().fold(F::ONE, |acc, n| acc * n);
Expand Down Expand Up @@ -506,20 +497,17 @@ mod test {

let mut circuit = Circuit::new(TEST_LEN + 1, 1);
let challenge_value = circuit.ext_val(1)[0];
let points_values = circuit.ext_val(TEST_LEN);
let numerators_values = circuit.ext_val(TEST_LEN);
let result_value = circuit.ext_val(1)[0];

let challenge_variable = input(&mut circuit, challenge_value, 0);
let points_variables = points_values.clone().into_iter().map(|val| input(&mut circuit, val, 0)).collect_vec();
let numerator_variables = numerators_values.clone().into_iter().map(|val| input(&mut circuit, val, 0)).collect_vec();
let reslut_variable = input(&mut circuit, result_value, 0);

fracsum_flat_constrain(&mut circuit, &numerator_variables, &points_variables, reslut_variable, challenge_variable);
fracsum_flat_constrain(&mut circuit, &numerator_variables, &points, reslut_variable, challenge_variable);
let mut instance = circuit.finalize();

instance.set_ext(challenge_value, challenge);
points_values.into_iter().zip_eq(points).map(|(val, point)| instance.set_ext(val, point)).last();
numerators_values.into_iter().zip_eq(numerators).map(|(val, point)| instance.set_ext(val, point)).last();
instance.set_ext(result_value, result);

Expand Down Expand Up @@ -574,18 +562,15 @@ mod test {

let mut circuit = Circuit::new(TEST_LEN + 1, 1);
let challenge_value = circuit.ext_val(1)[0];
let points_values = circuit.ext_val(TEST_LEN);
let numerators_values = circuit.ext_val(TEST_LEN);

let challenge_variable = input(&mut circuit, challenge_value, 0);
let points_variables = points_values.clone().into_iter().map(|val| input(&mut circuit, val, 0)).collect_vec();
let numerator_variables = numerators_values.clone().into_iter().map(|val| input(&mut circuit, val, 0)).collect_vec();

let result_variable = fracsum_gadget(&mut circuit, &numerator_variables, &points_variables, challenge_variable, 3, 0);
let result_variable = fracsum_gadget(&mut circuit, &numerator_variables, &points, challenge_variable, 3, 0);
let mut circuit = circuit.finalize();

circuit.set_ext(challenge_value, challenge);
points_values.into_iter().zip_eq(points).map(|(val, point)| circuit.set_ext(val, point)).last();
numerators_values.into_iter().zip_eq(numerators).map(|(val, point)| circuit.set_ext(val, point)).last();

circuit.execute(0);
Expand All @@ -605,11 +590,12 @@ mod test {
let indexes = 0..TEST_LEN;
let range = 16;

let table = (0..range).map(|x| F::from(x as u64)).collect_vec();
let mut circuit = Circuit::new(range + 1, TEST_LEN + 1);

let challenge_value = circuit.ext_val(1)[0];
let test_values = circuit.ext_val(TEST_LEN);
let mut range_lookup = RangeLookup::new(challenge_value, range);
let mut range_lookup = StaticLookup::new(challenge_value, &table);

let test_variables = test_values.clone().into_iter().enumerate().map(|(i, v)| input(&mut circuit, v, i)).collect_vec();
test_variables.into_iter().map(|variable| range_lookup.check(&mut circuit, variable)).last();
Expand All @@ -636,11 +622,12 @@ mod test {
let range = 16;
let rounds = 3;

let table = (0..range).map(|x| F::from(x as u64)).collect_vec();
let mut circuit = Circuit::new(range + 1, rounds);

let challenge_value: ExternalValue<F> = circuit.ext_val(1)[0];
let test_value = circuit.ext_val(1)[0];
let mut range_lookup = RangeLookup::new(challenge_value, range);
let mut range_lookup = StaticLookup::new(challenge_value, &table);

let test_variable = input(&mut circuit, test_value, 0);
range_lookup.check(&mut circuit, test_variable);
Expand All @@ -654,11 +641,12 @@ mod test {
let range = 16;
let rounds = 3;

let table = (0..range).map(|x| F::from(x as u64)).collect_vec();
let mut circuit = Circuit::new(range + 1, rounds);

let challenge_value: ExternalValue<F> = circuit.ext_val(1)[0];
let test_value = circuit.ext_val(1)[0];
let mut range_lookup = RangeLookup::new(challenge_value, range);
let mut range_lookup = StaticLookup::new(challenge_value, &table);

let test_variable = input(&mut circuit, test_value, 2);
range_lookup.check(&mut circuit, test_variable);
Expand All @@ -672,11 +660,13 @@ mod test {
let range = 16;
let rounds = 3;

let table = (0..range).map(|x| F::from(x as u64)).collect_vec();
let mut circuit = Circuit::new(range + 1, rounds);


let challenge_value: ExternalValue<F> = circuit.ext_val(1)[0];
let test_value = circuit.ext_val(1)[0];
let mut range_lookup = RangeLookup::new(challenge_value, range);
let mut range_lookup = StaticLookup::new(challenge_value, &table);

let test_variable = input(&mut circuit, test_value, 2);
range_lookup.check(&mut circuit, test_variable);
Expand Down
Loading

0 comments on commit 2fc2072

Please sign in to comment.