Skip to content

Commit

Permalink
prover tested
Browse files Browse the repository at this point in the history
  • Loading branch information
rebenkoy committed Nov 26, 2023
1 parent e51e9a8 commit 42d1c18
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 81 deletions.
47 changes: 14 additions & 33 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,18 @@ impl<'circuit, F: PrimeField, G: Gate<'circuit, F> + From<PolyOp<'circuit, F>>>
fn deallocate<'constructed>(&'constructed self, idx: RunIndex) {
self.run_allocator.borrow_mut().deallocate(idx);
}

pub fn perepare_protostar_chellanges(&self, mut beta: F) -> Vec<F> {
let m = self.circuit.cs.constr_spec().num_nonlinear_constraints;
let mut p = 1;
let mut protostar_challenges = vec![];
while p < m {
protostar_challenges.push(beta);
beta = beta * beta;
p = p * 2;
}
protostar_challenges
}
}

pub struct CircuitRun<'constructed, 'circuit, F: PrimeField, G: Gate<'circuit, F> + From<PolyOp<'circuit, F>>>{
Expand All @@ -368,15 +380,8 @@ where
}
}

pub fn end(&self, mut beta: F) -> ProtostarWtns<F> {
let m = self.constructed.circuit.cs.constr_spec().num_nonlinear_constraints;
let mut p = 1;
let mut protostar_challenges = vec![];
while p < m {
protostar_challenges.push(beta);
beta = beta * beta;
p = p * 2;
}
pub fn end(&self, beta: F) -> ProtostarWtns<F> {
let protostar_challenges = self.constructed.perepare_protostar_chellanges(beta);

let mut pubs = vec![];
let mut round_wtns = vec![];
Expand Down Expand Up @@ -410,30 +415,6 @@ where
}
}

pub fn error_term(&self, beta: F) -> F {
let mut resulst = vec![];
for constr in self.constructed.circuit.cs.iter_constraints() {
let input_values: Vec<_> = constr.inputs.iter().map(|&x| self.cs.getvar(x)).collect();
resulst.extend(constr.gate.exec(&input_values));
}
let mut b = 1;
let mut betas = vec![beta];
while b < resulst.len() {
b *= 2;
let last = *betas.last().unwrap();
betas.push(last * last);
}

for beta_pow in betas.iter().rev() {
for i in b..((b * 2).min(resulst.len())) {
resulst[i - b] = resulst[i - b] + resulst[i] * beta_pow;
}
b /= 2;
}

resulst[0]
}

pub fn iter_constraints(&self) -> impl Iterator<Item = &Constraint<'circuit, F, G>> {
self.constructed.circuit.cs.iter_constraints()
}
Expand Down
4 changes: 2 additions & 2 deletions src/constraint_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ pub trait CS<'c, F: PrimeField, G: Gate<'c, F>> {
pub struct ProtoGalaxyConstraintSystem<'c, F: PrimeField, G: Gate<'c, F>> {
pub spec: WitnessSpec,
pub max_degree: usize,
pub linear_constraints: ConstraintGroup<'c, F, G>,
pub non_linear_constraints: BTreeMap<usize, ConstraintGroup<'c, F, G>>,
linear_constraints: ConstraintGroup<'c, F, G>,
non_linear_constraints: BTreeMap<usize, ConstraintGroup<'c, F, G>>,
}

impl<'c, F: PrimeField, G: Gate<'c, F>> ProtoGalaxyConstraintSystem<'c, F, G> {
Expand Down
72 changes: 51 additions & 21 deletions src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ff::{PrimeField, BatchInvert};
use halo2::arithmetic::lagrange_interpolate;
use itertools::Itertools;

use crate::{witness::ProtostarWtns, gate::Gate, circuit::PolyOp, constraint_system::{ProtoGalaxyConstraintSystem, Visibility}, utils::{cross_terms_combination::{combine_cross_terms, self, EvalLayout}, field_precomp::FieldUtils, inv_lagrange_prod}, gadgets::range::lagrange_choice};
use crate::{witness::{ProtostarWtns, ProtostarLhsWtns}, gate::Gate, circuit::PolyOp, constraint_system::{ProtoGalaxyConstraintSystem, Visibility}, utils::{cross_terms_combination::{combine_cross_terms, self, EvalLayout}, field_precomp::FieldUtils, inv_lagrange_prod}, gadgets::range::lagrange_choice};

pub struct ProtoGalaxyProver {

Expand All @@ -23,10 +23,10 @@ impl ProtoGalaxyProver
> (
&self,
cs: &ProtoGalaxyConstraintSystem<'circuit, F, G>,
template: &ProtostarWtns<F>,
template: &ProtostarLhsWtns<F>,
) -> (Vec<Vec<usize>>, Vec<Vec<usize>>) {
let mut pubs_degrees: Vec<Vec<usize>> = template.lhs.pubs.iter().map(|v| v.iter().map(|_| 0).collect_vec()).collect_vec();
let mut privs_degrees: Vec<Vec<usize>> = template.lhs.round_wtns.iter().map(|v| v.iter().map(|_| 0).collect_vec()).collect_vec();
let mut pubs_degrees: Vec<Vec<usize>> = template.pubs.iter().map(|v| v.iter().map(|_| 0).collect_vec()).collect_vec();
let mut privs_degrees: Vec<Vec<usize>> = template.round_wtns.iter().map(|v| v.iter().map(|_| 0).collect_vec()).collect_vec();

for constraint in cs.iter_non_linear_constraints() {
for variable in &constraint.inputs {
Expand Down Expand Up @@ -130,8 +130,8 @@ impl ProtoGalaxyProver

pub fn prove<'circuit, F: PrimeField + FieldUtils, G: Gate<'circuit, F> + From<PolyOp<'circuit, F>>>(
&self,
a: &ProtostarWtns<F>,
b: &ProtostarWtns<F>,
a: &ProtostarLhsWtns<F>,
b: &ProtostarLhsWtns<F>,
cs: &ProtoGalaxyConstraintSystem<'circuit, F, G>,
) -> Vec<F> {
let (pubs_degrees, privs_degrees) = self.calculate_powers(cs, a);
Expand All @@ -142,15 +142,15 @@ impl ProtoGalaxyProver

// ^ that might be moved to 'new'

self.fill_variable_combinations(&mut privs_combinations, &privs_degrees, &a.lhs.round_wtns, &b.lhs.round_wtns);
self.fill_variable_combinations(&mut pubs_combinations, &pubs_degrees, &a.lhs.pubs, &b.lhs.pubs);
self.fill_variable_combinations(&mut privs_combinations, &privs_degrees, &a.round_wtns, &b.round_wtns);
self.fill_variable_combinations(&mut pubs_combinations, &pubs_degrees, &a.pubs, &b.pubs);

let evals = self.evaluate(cs, &pubs_combinations, &privs_combinations);
let pg_challenges = self.combine_challenges(&a.lhs.protostar_challenges, &b.lhs.protostar_challenges);
let pg_challenges = self.combine_challenges(&a.protostar_challenges, &b.protostar_challenges);
let mut cross_terms = combine_cross_terms(evals, layout, pg_challenges);
let cross_terms = self.leave_quotient(&mut cross_terms);

let points = self.prepare_interpolation_points(cs.max_degree, a.lhs.protostar_challenges.len());
let points = self.prepare_interpolation_points(cs.max_degree, a.protostar_challenges.len());
lagrange_interpolate(&points, cross_terms)
}
}
Expand All @@ -159,8 +159,11 @@ impl ProtoGalaxyProver
mod test {
use std::rc::Rc;

use ff::Field;
use halo2::halo2curves::bn256;
use crate::{gate::Gatebb, circuit::{Circuit, Advice}, gadgets::input::input};
use itertools::{Unfold, unfold};
use rand_core::OsRng;
use crate::{gate::Gatebb, circuit::{Circuit, Advice}, gadgets::input::input, constraint_system::CS, witness::{Module, compute_error_term, ProtostarLhsWtns}};

use super::*;

Expand All @@ -175,7 +178,9 @@ mod test {
1,
|args, _| vec![args[0] * args[1]]
), vec![input_vars[0], input_vars[1]])[0];
circuit.constrain(&[input_vars[0], input_vars[1], mul_a_res], Gatebb::<F>::new(2, 3, 1, Rc::new(|args, _| vec![args[0] * args[1] - args[2]]), vec![]));
circuit.constrain(&[input_vars[0], input_vars[1], mul_a_res], Gatebb::<F>::new(2, 3, 1, Rc::new(|args, _|
{let res = vec![args[0] * args[1] - args[2]]; res}
), vec![]));

let mul_b_res = circuit.advice(1, Advice::new(
2,
Expand All @@ -197,24 +202,49 @@ mod test {
let mut run_a = constructed.spawn();
let mut run_b = constructed.spawn();

for (idx, i) in inputs.iter().enumerate() {
run_a.set_ext(*i, F::from((idx + 3) as u64));
run_b.set_ext(*i, F::from((idx + 10) as u64));
for i in inputs {
run_a.set_ext(i, F::random(OsRng));
run_b.set_ext(i, F::random(OsRng));
}

run_a.execute(1);
run_b.execute(1);

let beta = F::from(2);
let beta_a = F::random(OsRng);
let beta_b = F::random(OsRng);

let pgp = ProtoGalaxyProver::new();

let a_wtns = run_a.end(beta);
let b_wtns = run_b.end(beta);
let a_wtns = run_a.end(beta_a);
let b_wtns = run_b.end(beta_b);

let res = pgp.prove(&a_wtns, &b_wtns, &constructed.circuit.cs);
run_a.valid_witness();
run_b.valid_witness();

// Now we create random witnesses with same shape

let a_wtns = ProtostarLhsWtns::random_like(&mut OsRng, &a_wtns.lhs);
let b_wtns = ProtostarLhsWtns::random_like(&mut OsRng, &b_wtns.lhs);

let q = pgp.prove(&a_wtns, &b_wtns, &constructed.circuit.cs);

let a_err = compute_error_term(&a_wtns, &constructed.circuit.cs);
let b_err = compute_error_term(&b_wtns, &constructed.circuit.cs);

let t = F::random(OsRng);
let mut fold_wtns = a_wtns.clone();
fold_wtns.neg();
fold_wtns.add_assign(b_wtns.clone());
fold_wtns.scale(t);
fold_wtns.add_assign(a_wtns.clone());

let fold_err = compute_error_term(&fold_wtns, &constructed.circuit.cs);

let q_eval: F = unfold(F::ONE, |next| {
let tmp = *next;
*next = *next * t;
Some(tmp)
}).zip(q).map(|(pow, c)| c * pow).sum();

assert_eq!(fold_err, a_err + (b_err - a_err) * t + t * (t - F::ONE) * q_eval)

}
}
106 changes: 81 additions & 25 deletions src/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::marker::PhantomData;
use ff::PrimeField;
use halo2::{halo2curves::CurveAffine, arithmetic::best_multiexp};
use itertools::Itertools;
use rand_core::RngCore;

use crate::{gate::Gate, constraint_system::{ProtoGalaxyConstraintSystem, Variable, CS, Visibility, WitnessSpec}, commitment::{CommitmentKey, CkWtns, CtRound, ErrGroup, CkRelaxed}, circuit::{ExternalValue, ConstructedCircuit, PolyOp}, utils::field_precomp::FieldUtils, folding::shape::{ProtostarLhs, ProtostarInstance}};

Expand Down Expand Up @@ -146,6 +147,7 @@ pub trait Module<F> {
fn scale(&mut self, scale: F) -> ();
}

#[derive(Clone)]
pub struct ProtostarLhsWtns<F: PrimeField> {
pub round_wtns: Vec<Vec<F>>,
pub pubs: Vec<Vec<F>>,
Expand All @@ -160,46 +162,56 @@ impl<F: PrimeField> ProtostarLhsWtns<F> {
protostar_challenges: self.protostar_challenges.clone(),
}
}

pub fn random_like<RNG: RngCore>(mut rng: &mut RNG, other: &Self) -> Self {
Self {
round_wtns: other.round_wtns.iter().map(|r| r.iter().map(|_| F::random(&mut rng)).collect_vec()).collect_vec(),
pubs: other.pubs.iter().map(|r| r.iter().map(|_| F::random(&mut rng)).collect_vec()).collect_vec(),
protostar_challenges: other.protostar_challenges.iter().map(|_| F::random(&mut rng)).collect_vec(),
}
}
}

impl<F: PrimeField> Module<F> for ProtostarLhsWtns<F> {
fn add_assign(&mut self, other: Self) -> () {
self.round_wtns.iter_mut().zip_eq(other.round_wtns.iter()).map(|(s, o)| {
s.iter_mut().zip_eq(o.iter()).map(|(s, o)| *s = *s + o)
}).last();
self.pubs.iter_mut().zip_eq(other.pubs.iter()).map(|(s, o)| {
s.iter_mut().zip_eq(o.iter()).map(|(s, o)| *s = *s + o)
}).last();
self.protostar_challenges.iter_mut().zip_eq(other.protostar_challenges.iter()).map(|(s, o)| {
self.round_wtns.iter_mut().zip_eq(other.round_wtns.iter()).for_each(|(s, o)| {
s.iter_mut().zip_eq(o.iter()).for_each(|(s, o)| *s = *s + o)
});
self.pubs.iter_mut().zip_eq(other.pubs.iter()).for_each(|(s, o)| {
s.iter_mut().zip_eq(o.iter()).for_each(|(s, o)| *s = *s + o)
});
self.protostar_challenges.iter_mut().zip_eq(other.protostar_challenges.iter()).for_each(|(s, o)| {
*s = *s + o
}).last();
});
}

fn neg(&mut self) -> () {
self.round_wtns.iter_mut().map(|s| {
s.iter_mut().map(|s| *s = -*s)
}).last();
self.pubs.iter_mut().map(|s| {
s.iter_mut().map(|s| *s = -*s)
}).last();
self.protostar_challenges.iter_mut().map(|s| {
self.round_wtns.iter_mut().for_each(|s| {
s.iter_mut().for_each(|s| *s = -*s)
});
self.pubs.iter_mut().for_each(|s| {
s.iter_mut().for_each(|s| *s = -*s)
});
self.protostar_challenges.iter_mut().for_each(|s| {
*s = -*s
}).last();
});
}

fn scale(&mut self, scale: F) -> () {
self.round_wtns.iter_mut().map(|s| {
s.iter_mut().map(|s| *s = *s * scale)
}).last();
self.pubs.iter_mut().map(|s| {
s.iter_mut().map(|s| *s = *s * scale)
}).last();
self.protostar_challenges.iter_mut().map(|s| {
self.round_wtns.iter_mut().for_each(|s| {
s.iter_mut().for_each(|s| *s = *s * scale)
});
self.pubs.iter_mut().for_each(|s| {
s.iter_mut().for_each(|s| *s = *s * scale)
});
self.protostar_challenges.iter_mut().for_each(|s| {
*s = *s * scale
}).last();
});
}
}


#[derive(Clone)]
pub struct ProtostarWtns<F: PrimeField> {
pub lhs: ProtostarLhsWtns<F>,
pub error: F
Expand Down Expand Up @@ -229,4 +241,48 @@ impl<F: PrimeField> ProtostarWtns<F> {
error: self.error,
}
}
}
}

pub fn compute_error_term<'circuit, F: PrimeField, G: Gate<'circuit, F>>(wtns: &ProtostarLhsWtns<F>, cs: &ProtoGalaxyConstraintSystem<'circuit, F, G>) -> F {
let betas = &wtns.protostar_challenges;
let mut results = vec![];
for constr in cs.iter_non_linear_constraints() {
let input_values: Vec<_> = constr.inputs.iter().map(|&x| match x.visibility {
Visibility::Public => wtns.pubs[x.round][x.index],
Visibility::Private => wtns.round_wtns[x.round][x.index],
}).collect();
results.extend(constr.gate.exec(&input_values));
}

assert!(betas.len() > 0, "No challenges supplied for error_term");
let mut mid = 1 << (betas.len() - 1);
assert!(mid < results.len());
assert!(mid * 2 >= results.len());

// example
// results : |.|.|.|.|.|.|.|
// idx : 0 1 2 3 4 5 6
// mid : ^
// beta^4 : | | | | |1|1|1|
// beta^2 : | | |1|1| | |1|
// beta^1 : | |1| |1| |1| |
//
// split at mid : |.|.|.|.| |.|.|.|
// old idx : 0 1 2 3 4 5 6
//
// Now we multiply right half by beta^4 and element-wise add it to first half inplace
//
// results : | r0 + r4 * beta^4 | r1 + r5 * beta^4 | r2 + r6 * beta^4 | r3 | junk | junk | junk |
//
// Now we forget about junk und repeat with first half until single element is left.

for beta_pow in betas.iter().rev() {
let (left, right) = results.split_at_mut(mid);
for (l, r) in left.iter_mut().zip(right.iter()) {
*l = *l + *r * beta_pow;
}
mid /= 2;
}

results[0]
}

0 comments on commit 42d1c18

Please sign in to comment.