Skip to content

Commit

Permalink
Merge pull request #2520 from o1-labs/dw/mvpoly-monomials-cross-terms
Browse files Browse the repository at this point in the history
MVPoly/monomials: compute cross-terms
  • Loading branch information
dannywillems authored Sep 8, 2024
2 parents daeff99 + bbaba20 commit 590ec1b
Show file tree
Hide file tree
Showing 4 changed files with 325 additions and 2 deletions.
23 changes: 23 additions & 0 deletions mvpoly/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;

use ark_ff::PrimeField;
use kimchi::circuits::expr::{ConstantExpr, Expr};
use rand::RngCore;
Expand Down Expand Up @@ -83,4 +85,25 @@ pub trait MVPoly<F: PrimeField, const N: usize, const D: usize>:
/// For instance, to add the monomial `3 * x_1^2 * x_2^3` to the polynomial,
/// one would call `add_monomial([2, 3], 3)`.
fn add_monomial(&mut self, exponents: [usize; N], coeff: F);

/// Compute the cross-terms as described in [Behind Nova: cross-terms
/// computation for high degree
/// gates](https://hackmd.io/@dannywillems/Syo5MBq90)
///
/// The polynomial must not necessarily be homogeneous. For this reason, the
/// values `u1` and `u2` represents the extra variable that is used to make
/// the polynomial homogeneous.
///
/// The homogeneous degree is supposed to be the one defined by the type of
/// the polynomial, i.e. `D`.
///
/// The output is a map of `D - 1` values that represents the cross-terms
/// for each power of `r`.
fn compute_cross_terms(
&self,
eval1: &[F; N],
eval2: &[F; N],
u1: F,
u2: F,
) -> HashMap<usize, F>;
}
98 changes: 97 additions & 1 deletion mvpoly/src/monomials.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use ark_ff::{One, PrimeField, Zero};
use kimchi::circuits::expr::{ConstantExpr, Expr};
use num_integer::binomial;
use rand::RngCore;
use std::{
collections::HashMap,
Expand All @@ -9,7 +10,7 @@ use std::{

use crate::{
prime,
utils::{naive_prime_factors, PrimeNumberGenerator},
utils::{compute_indices_nested_loop, naive_prime_factors, PrimeNumberGenerator},
MVPoly,
};

Expand Down Expand Up @@ -468,6 +469,101 @@ impl<const N: usize, const D: usize, F: PrimeField> MVPoly<F, N, D> for Sparse<F
.and_modify(|c| *c += coeff)
.or_insert(coeff);
}

fn compute_cross_terms(
&self,
eval1: &[F; N],
eval2: &[F; N],
u1: F,
u2: F,
) -> HashMap<usize, F> {
assert!(
D >= 2,
"The degree of the polynomial must be greater than 2"
);
let mut cross_terms_by_powers_of_r: HashMap<usize, F> = HashMap::new();
// We iterate over each monomial with their respective coefficient
// i.e. we do have something like coeff * x_1^d_1 * x_2^d_2 * ... * x_N^d_N
self.monomials.iter().for_each(|(exponents, coeff)| {
// "Exponents" contains all powers, even the ones that are 0. We must
// get rid of them and keep the index to fetch the correct
// evaluation later
let non_zero_exponents_with_index: Vec<(usize, &usize)> = exponents
.iter()
.enumerate()
.filter(|(_, &d)| d != 0)
.collect();
// coeff = 0 should not happen as we suppose we have a sparse polynomial
// Therefore, skipping a check
let non_zero_exponents: Vec<usize> = non_zero_exponents_with_index
.iter()
.map(|(_, d)| *d)
.copied()
.collect::<Vec<usize>>();
let monomial_degree = non_zero_exponents.iter().sum::<usize>();
let u_degree: usize = D - monomial_degree;
// Will be used to compute the nested sums
// It returns all the indices i_1, ..., i_k for the sums:
// Σ_{i_1 = 0}^{n_1} Σ_{i_2 = 0}^{n_2} ... Σ_{i_k = 0}^{n_k}
let indices =
compute_indices_nested_loop(non_zero_exponents.iter().map(|d| *d + 1).collect());
for i in 0..=u_degree {
// Add the binomial from the homogeneisation
// i.e (u_degree choose i)
let u_binomial_term = binomial(u_degree, i);
// Now, we iterate over all the indices i_1, ..., i_k, i.e. we
// do over the whole sum, and we populate the map depending on
// the power of r
indices.iter().for_each(|indices| {
let sum_indices = indices.iter().sum::<usize>() + i;
// power of r is Σ (n_k - i_k)
let power_r: usize = D - sum_indices;

// If the sum of the indices is 0 or D, we skip the
// computation as the contribution would go in the
// evaluation of the polynomial at each evaluation
// vectors eval1 and eval2
if sum_indices == 0 || sum_indices == D {
return;
}
// Compute
// (n_1 choose i_1) * (n_2 choose i_2) * ... * (n_k choose i_k)
let binomial_term = indices
.iter()
.zip(non_zero_exponents.iter())
.fold(u_binomial_term, |acc, (i, &d)| acc * binomial(d, *i));
let binomial_term = F::from(binomial_term as u64);
// Compute the product x_k^i_k
// We ignore the power as it comes into account for the
// right evaluation.
// NB: we could merge both loops, but we keep them separate
// for readability
let eval_left = indices
.iter()
.zip(non_zero_exponents_with_index.iter())
.fold(F::one(), |acc, (i, (idx, _d))| {
acc * eval1[*idx].pow([*i as u64])
});
// Compute the product x'_k^(n_k - i_k)
let eval_right = indices
.iter()
.zip(non_zero_exponents_with_index.iter())
.fold(F::one(), |acc, (i, (idx, d))| {
acc * eval2[*idx].pow([(*d - *i) as u64])
});
// u1^i * u2^(u_degree - i)
let u = u1.pow([i as u64]) * u2.pow([(u_degree - i) as u64]);
let res = binomial_term * eval_left * eval_right * u;
let res = *coeff * res;
cross_terms_by_powers_of_r
.entry(power_r)
.and_modify(|e| *e += res)
.or_insert(res);
})
}
});
cross_terms_by_powers_of_r
}
}

impl<const N: usize, const D: usize, F: PrimeField> From<prime::Dense<F, N, D>>
Expand Down
10 changes: 10 additions & 0 deletions mvpoly/src/prime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,16 @@ impl<F: PrimeField, const N: usize, const D: usize> MVPoly<F, N, D> for Dense<F,
.unwrap();
self.coeff[inv_idx] += coeff;
}

fn compute_cross_terms(
&self,
_eval1: &[F; N],
_eval2: &[F; N],
_u1: F,
_u2: F,
) -> HashMap<usize, F> {
unimplemented!()
}
}

impl<F: PrimeField, const N: usize, const D: usize> Dense<F, N, D> {
Expand Down
196 changes: 195 additions & 1 deletion mvpoly/tests/monomials.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use ark_ff::{One, UniformRand, Zero};
use ark_ff::{Field, One, UniformRand, Zero};
use mina_curves::pasta::Fp;
use mvpoly::{monomials::Sparse, MVPoly};
use rand::Rng;
Expand Down Expand Up @@ -469,3 +469,197 @@ fn test_add_monomial() {
random_c1 * random_eval[0] * random_eval[0] + random_c2 * random_eval[1] * random_eval[1];
assert_eq!(eval_p4, exp_eval_p4);
}

#[test]
fn test_mvpoly_compute_cross_terms_degree_two_unit_test() {
let mut rng = o1_utils::tests::make_test_rng(None);

{
// Homogeneous form is Y^2
let p1 = Sparse::<Fp, 4, 2>::from(Fp::from(1));

let random_eval1: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng));
let random_eval2: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng));
let u1 = Fp::rand(&mut rng);
let u2 = Fp::rand(&mut rng);
let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2);

// We only have one cross-term in this case as degree 2
assert_eq!(cross_terms.len(), 1);
// Cross term of constant is r * (2 u1 u2)
assert_eq!(cross_terms[&1], (u1 * u2).double());
}
}

#[test]
fn test_mvpoly_compute_cross_terms_degree_two() {
let mut rng = o1_utils::tests::make_test_rng(None);
let p1 = unsafe { Sparse::<Fp, 4, 2>::random(&mut rng, None) };
let random_eval1: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng));
let random_eval2: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng));
let u1 = Fp::rand(&mut rng);
let u2 = Fp::rand(&mut rng);
let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2);
// We only have one cross-term in this case
assert_eq!(cross_terms.len(), 1);

let r = Fp::rand(&mut rng);
let random_lincomb: [Fp; 4] = std::array::from_fn(|i| random_eval1[i] + r * random_eval2[i]);

let lhs = p1.homogeneous_eval(&random_lincomb, u1 + r * u2);

let rhs = {
let eval1_hom = p1.homogeneous_eval(&random_eval1, u1);
let eval2_hom = p1.homogeneous_eval(&random_eval2, u2);
let cross_terms_eval = cross_terms.iter().fold(Fp::zero(), |acc, (power, term)| {
acc + r.pow([*power as u64]) * term
});
eval1_hom + r * r * eval2_hom + cross_terms_eval
};
assert_eq!(lhs, rhs);
}

#[test]
fn test_mvpoly_compute_cross_terms_degree_three() {
let mut rng = o1_utils::tests::make_test_rng(None);
let p1 = unsafe { Sparse::<Fp, 4, 3>::random(&mut rng, None) };
let random_eval1: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng));
let random_eval2: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng));
let u1 = Fp::rand(&mut rng);
let u2 = Fp::rand(&mut rng);
let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2);

assert_eq!(cross_terms.len(), 2);

let r = Fp::rand(&mut rng);
let random_lincomb: [Fp; 4] = std::array::from_fn(|i| random_eval1[i] + r * random_eval2[i]);

let lhs = p1.homogeneous_eval(&random_lincomb, u1 + r * u2);

let rhs = {
let eval1_hom = p1.homogeneous_eval(&random_eval1, u1);
let eval2_hom = p1.homogeneous_eval(&random_eval2, u2);
let cross_terms_eval = cross_terms.iter().fold(Fp::zero(), |acc, (power, term)| {
acc + r.pow([*power as u64]) * term
});
let r_cube = r.pow([3]);
eval1_hom + r_cube * eval2_hom + cross_terms_eval
};
assert_eq!(lhs, rhs);
}

#[test]
fn test_mvpoly_compute_cross_terms_degree_four() {
let mut rng = o1_utils::tests::make_test_rng(None);
let p1 = unsafe { Sparse::<Fp, 6, 4>::random(&mut rng, None) };
let random_eval1: [Fp; 6] = std::array::from_fn(|_| Fp::rand(&mut rng));
let random_eval2: [Fp; 6] = std::array::from_fn(|_| Fp::rand(&mut rng));
let u1 = Fp::rand(&mut rng);
let u2 = Fp::rand(&mut rng);
let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2);

assert_eq!(cross_terms.len(), 3);

let r = Fp::rand(&mut rng);
let random_lincomb: [Fp; 6] = std::array::from_fn(|i| random_eval1[i] + r * random_eval2[i]);

let lhs = p1.homogeneous_eval(&random_lincomb, u1 + r * u2);

let rhs = {
let eval1_hom = p1.homogeneous_eval(&random_eval1, u1);
let eval2_hom = p1.homogeneous_eval(&random_eval2, u2);
let cross_terms_eval = cross_terms.iter().fold(Fp::zero(), |acc, (power, term)| {
acc + r.pow([*power as u64]) * term
});
let r_four = r.pow([4]);
eval1_hom + r_four * eval2_hom + cross_terms_eval
};
assert_eq!(lhs, rhs);
}

#[test]
fn test_mvpoly_compute_cross_terms_degree_five() {
let mut rng = o1_utils::tests::make_test_rng(None);
let p1 = unsafe { Sparse::<Fp, 3, 5>::random(&mut rng, None) };
let random_eval1: [Fp; 3] = std::array::from_fn(|_| Fp::rand(&mut rng));
let random_eval2: [Fp; 3] = std::array::from_fn(|_| Fp::rand(&mut rng));
let u1 = Fp::rand(&mut rng);
let u2 = Fp::rand(&mut rng);
let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2);

assert_eq!(cross_terms.len(), 4);

let r = Fp::rand(&mut rng);
let random_lincomb: [Fp; 3] = std::array::from_fn(|i| random_eval1[i] + r * random_eval2[i]);

let lhs = p1.homogeneous_eval(&random_lincomb, u1 + r * u2);

let rhs = {
let eval1_hom = p1.homogeneous_eval(&random_eval1, u1);
let eval2_hom = p1.homogeneous_eval(&random_eval2, u2);
let cross_terms_eval = cross_terms.iter().fold(Fp::zero(), |acc, (power, term)| {
acc + r.pow([*power as u64]) * term
});
let r_five = r.pow([5]);
eval1_hom + r_five * eval2_hom + cross_terms_eval
};
assert_eq!(lhs, rhs);
}

#[test]
fn test_mvpoly_compute_cross_terms_degree_six() {
let mut rng = o1_utils::tests::make_test_rng(None);
let p1 = unsafe { Sparse::<Fp, 4, 6>::random(&mut rng, None) };
let random_eval1: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng));
let random_eval2: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng));
let u1 = Fp::rand(&mut rng);
let u2 = Fp::rand(&mut rng);
let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2);

assert_eq!(cross_terms.len(), 5);

let r = Fp::rand(&mut rng);
let random_lincomb: [Fp; 4] = std::array::from_fn(|i| random_eval1[i] + r * random_eval2[i]);

let lhs = p1.homogeneous_eval(&random_lincomb, u1 + r * u2);

let rhs = {
let eval1_hom = p1.homogeneous_eval(&random_eval1, u1);
let eval2_hom = p1.homogeneous_eval(&random_eval2, u2);
let cross_terms_eval = cross_terms.iter().fold(Fp::zero(), |acc, (power, term)| {
acc + r.pow([*power as u64]) * term
});
let r_six = r.pow([6]);
eval1_hom + r_six * eval2_hom + cross_terms_eval
};
assert_eq!(lhs, rhs);
}

#[test]
fn test_mvpoly_compute_cross_terms_degree_seven() {
let mut rng = o1_utils::tests::make_test_rng(None);
let p1 = unsafe { Sparse::<Fp, 4, 7>::random(&mut rng, None) };
let random_eval1: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng));
let random_eval2: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng));
let u1 = Fp::rand(&mut rng);
let u2 = Fp::rand(&mut rng);
let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2);

assert_eq!(cross_terms.len(), 6);

let r = Fp::rand(&mut rng);
let random_lincomb: [Fp; 4] = std::array::from_fn(|i| random_eval1[i] + r * random_eval2[i]);

let lhs = p1.homogeneous_eval(&random_lincomb, u1 + r * u2);

let rhs = {
let eval1_hom = p1.homogeneous_eval(&random_eval1, u1);
let eval2_hom = p1.homogeneous_eval(&random_eval2, u2);
let cross_terms_eval = cross_terms.iter().fold(Fp::zero(), |acc, (power, term)| {
acc + r.pow([*power as u64]) * term
});
let r_seven = r.pow([7]);
eval1_hom + r_seven * eval2_hom + cross_terms_eval
};
assert_eq!(lhs, rhs);
}

0 comments on commit 590ec1b

Please sign in to comment.