Skip to content

Commit

Permalink
Avoid deeply nested LinearCombinations in `EvaluationsVar::interpol…
Browse files Browse the repository at this point in the history
…ate_and_evaluate` (#145)

Co-authored-by: Pratyush Mishra <[email protected]>
  • Loading branch information
winderica and Pratyush authored Sep 13, 2024
1 parent 4fb0adf commit 381abcc
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 100 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@

### Bug Fixes

- [\#145](https://github.com/arkworks-rs/r1cs-std/pull/145)
- Avoid deeply nested `LinearCombinations` in `EvaluationsVar::interpolate_and_evaluate` to fix the stack overflow issue when calling `.value()` on the evaluation result.

## 0.4.0

- [\#117](https://github.com/arkworks-rs/r1cs-std/pull/117) Fix result of `precomputed_base_scalar_mul_le` to not discard previous value.
Expand Down
204 changes: 104 additions & 100 deletions src/poly/evaluations/univariate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,13 @@ impl<F: PrimeField> EvaluationsVar<F> {
.as_ref()
.expect("lagrange interpolator has not been initialized. ");
let lagrange_coeffs = self.compute_lagrange_coefficients(interpolation_point)?;
let mut interpolation: FpVar<F> = FpVar::zero();
for i in 0..lagrange_interpolator.domain_order {
let intermediate = &lagrange_coeffs[i] * &self.evals[i];
interpolation += &intermediate
}

let interpolation = lagrange_coeffs
.iter()
.zip(&self.evals)
.take(lagrange_interpolator.domain_order)
.map(|(coeff, eval)| coeff * eval)
.sum::<FpVar<F>>();

Ok(interpolation)
}
Expand Down Expand Up @@ -208,31 +210,33 @@ impl<F: PrimeField> EvaluationsVar<F> {
let alpha_coset_offset_inv =
interpolation_point.mul_by_inverse_unchecked(&self.domain.offset())?;

// `res` stores the sum of all lagrange polynomials evaluated at alpha
let mut res = FpVar::<F>::zero();

let domain_size = self.domain.size() as usize;
for i in 0..domain_size {
// a'^{-1} where a is the base coset element
let subgroup_point_inv = subgroup_points[(domain_size - i) % domain_size];
debug_assert_eq!(subgroup_points[i] * subgroup_point_inv, F::one());
// alpha * offset^{-1} * a'^{-1} - 1
let lag_denom = &alpha_coset_offset_inv * subgroup_point_inv - F::one();
// lag_denom cannot be zero, so we use `unchecked`.
//
// Proof: lag_denom is zero if and only if alpha * (coset_offset *
// subgroup_point)^{-1} == 1. This can happen only if `alpha` is
// itself in the coset.
//
// Earlier we asserted that `lhs_numerator` is not zero.
// Since `lhs_numerator` is just the vanishing polynomial for the coset
// evaluated at `alpha`, and since this is non-zero, `alpha` is not
// in the coset.
let lag_coeff = lhs.mul_by_inverse_unchecked(&lag_denom)?;

let lag_interpoland = &self.evals[i] * lag_coeff;
res += lag_interpoland
}

// `evals` stores all lagrange polynomials evaluated at alpha
let evals = (0..domain_size)
.map(|i| {
// a'^{-1} where a is the base coset element
let subgroup_point_inv = subgroup_points[(domain_size - i) % domain_size];
debug_assert_eq!(subgroup_points[i] * subgroup_point_inv, F::one());
// alpha * offset^{-1} * a'^{-1} - 1
let lag_denom = &alpha_coset_offset_inv * subgroup_point_inv - F::one();
// lag_denom cannot be zero, so we use `unchecked`.
//
// Proof: lag_denom is zero if and only if alpha * (coset_offset *
// subgroup_point)^{-1} == 1. This can happen only if `alpha` is
// itself in the coset.
//
// Earlier we asserted that `lhs_numerator` is not zero.
// Since `lhs_numerator` is just the vanishing polynomial for the coset
// evaluated at `alpha`, and since this is non-zero, `alpha` is not
// in the coset.
let lag_coeff = lhs.mul_by_inverse_unchecked(&lag_denom)?;

Ok(&self.evals[i] * lag_coeff)
})
.collect::<Result<Vec<_>, _>>()?;

let res = evals.iter().sum();

Ok(res)
}
Expand Down Expand Up @@ -378,87 +382,87 @@ mod tests {

#[test]
fn test_interpolate_constant_offset() {
let mut rng = test_rng();
let poly = DensePolynomial::rand(15, &mut rng);
let gen = Fr::get_root_of_unity(1 << 4).unwrap();
assert_eq!(gen.pow(&[1 << 4]), Fr::one());
let domain = Radix2DomainVar::new(
gen,
4, // 2^4 = 16
FpVar::constant(Fr::rand(&mut rng)),
)
.unwrap();
let mut coset_point = domain.offset().value().unwrap();
let mut oracle_evals = Vec::new();
for _ in 0..(1 << 4) {
oracle_evals.push(poly.evaluate(&coset_point));
coset_point *= gen;
}
let cs = ConstraintSystem::new_ref();
let evaluations_fp: Vec<_> = oracle_evals
.iter()
.map(|x| FpVar::new_input(ns!(cs, "evaluations"), || Ok(x)).unwrap())
.collect();
let evaluations_var = EvaluationsVar::from_vec_and_domain(evaluations_fp, domain, true);

let interpolate_point = Fr::rand(&mut rng);
let interpolate_point_fp =
FpVar::new_input(ns!(cs, "interpolate point"), || Ok(interpolate_point)).unwrap();

let expected = poly.evaluate(&interpolate_point);

let actual = evaluations_var
.interpolate_and_evaluate(&interpolate_point_fp)
.unwrap()
.value()
.unwrap();
for n in [11, 12, 13, 14] {
let mut rng = test_rng();

let poly = DensePolynomial::rand((1 << n) - 1, &mut rng);
let gen = Fr::get_root_of_unity(1 << n).unwrap();
assert_eq!(gen.pow(&[1 << n]), Fr::one());
let domain = Radix2DomainVar::new(gen, n, FpVar::constant(Fr::rand(&mut rng))).unwrap();
let mut coset_point = domain.offset().value().unwrap();
let mut oracle_evals = Vec::new();
for _ in 0..(1 << n) {
oracle_evals.push(poly.evaluate(&coset_point));
coset_point *= gen;
}
let cs = ConstraintSystem::new_ref();
let evaluations_fp: Vec<_> = oracle_evals
.iter()
.map(|x| FpVar::new_input(ns!(cs, "evaluations"), || Ok(x)).unwrap())
.collect();
let evaluations_var = EvaluationsVar::from_vec_and_domain(evaluations_fp, domain, true);

let interpolate_point = Fr::rand(&mut rng);
let interpolate_point_fp =
FpVar::new_input(ns!(cs, "interpolate point"), || Ok(interpolate_point)).unwrap();

let expected = poly.evaluate(&interpolate_point);

let actual = evaluations_var
.interpolate_and_evaluate(&interpolate_point_fp)
.unwrap()
.value()
.unwrap();

assert_eq!(actual, expected);
assert!(cs.is_satisfied().unwrap());
println!("number of constraints: {}", cs.num_constraints())
assert_eq!(actual, expected);
assert!(cs.is_satisfied().unwrap());
println!("number of constraints: {}", cs.num_constraints());
}
}

#[test]
fn test_interpolate_non_constant_offset() {
let mut rng = test_rng();
let poly = DensePolynomial::rand(15, &mut rng);
let gen = Fr::get_root_of_unity(1 << 4).unwrap();
assert_eq!(gen.pow(&[1 << 4]), Fr::one());
let cs = ConstraintSystem::new_ref();
let domain = Radix2DomainVar::new(
gen,
4, // 2^4 = 16
FpVar::new_witness(ns!(cs, "offset"), || Ok(Fr::rand(&mut rng))).unwrap(),
)
.unwrap();
let mut coset_point = domain.offset().value().unwrap();
let mut oracle_evals = Vec::new();
for _ in 0..(1 << 4) {
oracle_evals.push(poly.evaluate(&coset_point));
coset_point *= gen;
}
for n in [11, 12, 13, 14] {
let mut rng = test_rng();
let poly = DensePolynomial::rand((1 << n) - 1, &mut rng);
let gen = Fr::get_root_of_unity(1 << n).unwrap();
assert_eq!(gen.pow(&[1 << n]), Fr::one());
let cs = ConstraintSystem::new_ref();
let domain = Radix2DomainVar::new(
gen,
n,
FpVar::new_witness(ns!(cs, "offset"), || Ok(Fr::rand(&mut rng))).unwrap(),
)
.unwrap();
let mut coset_point = domain.offset().value().unwrap();
let mut oracle_evals = Vec::new();
for _ in 0..(1 << n) {
oracle_evals.push(poly.evaluate(&coset_point));
coset_point *= gen;
}

let evaluations_fp: Vec<_> = oracle_evals
.iter()
.map(|x| FpVar::new_input(ns!(cs, "evaluations"), || Ok(x)).unwrap())
.collect();
let evaluations_var = EvaluationsVar::from_vec_and_domain(evaluations_fp, domain, true);
let evaluations_fp: Vec<_> = oracle_evals
.iter()
.map(|x| FpVar::new_input(ns!(cs, "evaluations"), || Ok(x)).unwrap())
.collect();
let evaluations_var = EvaluationsVar::from_vec_and_domain(evaluations_fp, domain, true);

let interpolate_point = Fr::rand(&mut rng);
let interpolate_point_fp =
FpVar::new_input(ns!(cs, "interpolate point"), || Ok(interpolate_point)).unwrap();
let interpolate_point = Fr::rand(&mut rng);
let interpolate_point_fp =
FpVar::new_input(ns!(cs, "interpolate point"), || Ok(interpolate_point)).unwrap();

let expected = poly.evaluate(&interpolate_point);
let expected = poly.evaluate(&interpolate_point);

let actual = evaluations_var
.interpolate_and_evaluate(&interpolate_point_fp)
.unwrap()
.value()
.unwrap();
let actual = evaluations_var
.interpolate_and_evaluate(&interpolate_point_fp)
.unwrap()
.value()
.unwrap();

assert_eq!(actual, expected);
assert!(cs.is_satisfied().unwrap());
println!("number of constraints: {}", cs.num_constraints())
assert_eq!(actual, expected);
assert!(cs.is_satisfied().unwrap());
println!("number of constraints: {}", cs.num_constraints());
}
}

#[test]
Expand Down

0 comments on commit 381abcc

Please sign in to comment.