Skip to content

Commit

Permalink
Merge pull request #2715 from o1-labs/dw/improve-from-variable-implem…
Browse files Browse the repository at this point in the history
…entation

MVPoly: improve from_variable and support "next row" in from_expr
  • Loading branch information
dannywillems authored Oct 17, 2024
2 parents 1f8cf43 + 594f689 commit 0dab255
Show file tree
Hide file tree
Showing 6 changed files with 324 additions and 105 deletions.
50 changes: 30 additions & 20 deletions mvpoly/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
//! "Expressions", as defined in the [kimchi] crate, can be converted into a
//! multi-variate polynomial using the `from_expr` method.

use std::collections::HashMap;

use ark_ff::PrimeField;
use kimchi::circuits::{
expr::{ConstantExpr, ConstantExprInner, ConstantTerm, Expr, ExprInner, Operations, Variable},
gate::CurrOrNext,
use kimchi::circuits::expr::{
ConstantExpr, ConstantExprInner, ConstantTerm, Expr, ExprInner, Operations, Variable,
};
use rand::RngCore;
use std::collections::HashMap;

pub mod monomials;
pub mod pbt;
Expand Down Expand Up @@ -86,7 +84,6 @@ pub trait MVPoly<F: PrimeField, const N: usize, const D: usize>:
/// speed up the computation.
fn eval(&self, x: &[F; N]) -> F;


/// Build the univariate polynomial `x_i` from the variable `i`.
/// The conversion into the type `usize` is unspecified by this trait. It
/// is left to the trait implementation.
Expand All @@ -95,7 +92,12 @@ pub trait MVPoly<F: PrimeField, const N: usize, const D: usize>:
/// used.
/// For [crate::monomials], the output must be the index of the variable,
/// starting from `0`.
fn from_variable<Column: Into<usize>>(var: Column) -> Self;
///
/// The parameter `offset_next_row` is an optional argument that is used to
/// support the case where the "next row" is used. In this case, the type
/// parameter `N` must include this offset (i.e. if 4 variables are in ued,
/// N should be at least `8 = 2 * 4`).
fn from_variable<Column: Into<usize>>(var: Variable<Column>, offset_next_row: Option<usize>) -> Self;

fn from_constant<ChallengeTerm: Clone>(op: Operations<ConstantExprInner<F, ChallengeTerm>>) -> Self {
use kimchi::circuits::expr::Operations::*;
Expand Down Expand Up @@ -147,7 +149,16 @@ pub trait MVPoly<F: PrimeField, const N: usize, const D: usize>:
/// "the expression framework".
/// In the near future, the "expression framework" should be moved also into
/// this library.
fn from_expr<Column: Into<usize>, ChallengeTerm: Clone>(expr: Expr<ConstantExpr<F, ChallengeTerm>, Column>) -> Self {
///
/// The mapping from variable to the user is left unspecified by this trait
/// and is left to the implementation. The conversion of a variable into an
/// index is done by the trait requirement `Into<usize>` on the column type.
///
/// The parameter `offset_next_row` is an optional argument that is used to
/// support the case where the "next row" is used. In this case, the type
/// parameter `N` must include this offset (i.e. if 4 variables are in ued,
/// N should be at least `8 = 2 * 4`).
fn from_expr<Column: Into<usize>, ChallengeTerm: Clone>(expr: Expr<ConstantExpr<F, ChallengeTerm>, Column>, offset_next_row: Option<usize>) -> Self {
use kimchi::circuits::expr::Operations::*;

match expr {
Expand All @@ -160,38 +171,37 @@ pub trait MVPoly<F: PrimeField, const N: usize, const D: usize>:
unimplemented!("Not used in this context")
}
ExprInner::Constant(c) => Self::from_constant(c),
ExprInner::Cell(Variable { col, row }) => {
assert_eq!(row, CurrOrNext::Curr, "Only current row is supported for now. You cannot reference the next row");
Self::from_variable(col)
ExprInner::Cell(var) => {
Self::from_variable::<Column>(var, offset_next_row)
}
}
}
Add(e1, e2) => {
let p1 = Self::from_expr(*e1);
let p2 = Self::from_expr(*e2);
let p1 = Self::from_expr::<Column, ChallengeTerm>(*e1, offset_next_row);
let p2 = Self::from_expr::<Column, ChallengeTerm>(*e2, offset_next_row);
p1 + p2
}
Sub(e1, e2) => {
let p1 = Self::from_expr(*e1);
let p2 = Self::from_expr(*e2);
let p1 = Self::from_expr::<Column, ChallengeTerm>(*e1, offset_next_row);
let p2 = Self::from_expr::<Column, ChallengeTerm>(*e2, offset_next_row);
p1 - p2
}
Mul(e1, e2) => {
let p1 = Self::from_expr(*e1);
let p2 = Self::from_expr(*e2);
let p1 = Self::from_expr::<Column, ChallengeTerm>(*e1, offset_next_row);
let p2 = Self::from_expr::<Column, ChallengeTerm>(*e2, offset_next_row);
p1 * p2
}
Double(p) => {
let p = Self::from_expr(*p);
let p = Self::from_expr::<Column, ChallengeTerm>(*p, offset_next_row);
p.double()
}
Square(p) => {
let p = Self::from_expr(*p);
let p = Self::from_expr::<Column, ChallengeTerm>(*p, offset_next_row);
p.clone() * p.clone()
}
Pow(c, e) => {
// FIXME: dummy implementation
let p = Self::from_expr(*c);
let p = Self::from_expr::<Column, ChallengeTerm>(*c, offset_next_row);
let mut result = p.clone();
for _ in 0..e {
result = result.clone() * p.clone();
Expand Down
27 changes: 24 additions & 3 deletions mvpoly/src/monomials.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use ark_ff::{One, PrimeField, Zero};
use kimchi::circuits::{expr::Variable, gate::CurrOrNext};
use num_integer::binomial;
use rand::RngCore;
use std::{
Expand Down Expand Up @@ -295,10 +296,30 @@ impl<const N: usize, const D: usize, F: PrimeField> MVPoly<F, N, D> for Sparse<F
prime::Dense::random(rng, max_degree).into()
}

fn from_variable<Column: Into<usize>>(var: Column) -> Self {
let var_usize: usize = var.into();
fn from_variable<Column: Into<usize>>(
var: Variable<Column>,
offset_next_row: Option<usize>,
) -> Self {
let Variable { col, row } = var;
// Manage offset
if row == CurrOrNext::Next {
assert!(
offset_next_row.is_some(),
"The offset must be provided for the next row"
);
}
let offset = if row == CurrOrNext::Curr {
0
} else {
offset_next_row.unwrap()
};

// Build the corresponding monomial
let var_usize: usize = col.into();
let idx = offset + var_usize;

let mut monomials = HashMap::new();
let exponents: [usize; N] = std::array::from_fn(|i| if i == var_usize { 1 } else { 0 });
let exponents: [usize; N] = std::array::from_fn(|i| if i == idx { 1 } else { 0 });
monomials.insert(exponents, F::one());
Self { monomials }
}
Expand Down
30 changes: 30 additions & 0 deletions mvpoly/src/pbt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,33 @@ pub fn test_is_multilinear<F: PrimeField, const N: usize, const D: usize, T: MVP
assert!(!p.is_multilinear());
}
}

pub fn test_is_constant<F: PrimeField, const N: usize, const D: usize, T: MVPoly<F, N, D>>() {
let mut rng = o1_utils::tests::make_test_rng(None);
let c = F::rand(&mut rng);
let p = T::from(c);
assert!(p.is_constant());

let p = T::zero();
assert!(p.is_constant());

let p = {
let mut res = T::zero();
let monomial: [usize; N] = std::array::from_fn(|i| if i == 0 { 1 } else { 0 });
res.add_monomial(monomial, F::one());
res
};
assert!(!p.is_constant());

let p = {
let mut res = T::zero();
let monomial: [usize; N] = std::array::from_fn(|i| if i == 1 { 1 } else { 0 });
res.add_monomial(monomial, F::one());
res
};
assert!(!p.is_constant());

// This might be flaky
let p = unsafe { T::random(&mut rng, None) };
assert!(!p.is_constant());
}
65 changes: 49 additions & 16 deletions mvpoly/src/prime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ use std::{
};

use ark_ff::{One, PrimeField, Zero};
use kimchi::circuits::{expr::Variable, gate::CurrOrNext};
use num_integer::binomial;
use o1_utils::FieldHelpers;
use rand::{Rng, RngCore};
Expand Down Expand Up @@ -337,18 +338,38 @@ impl<F: PrimeField, const N: usize, const D: usize> MVPoly<F, N, D> for Dense<F,
})
}

fn from_variable<Column: Into<usize>>(var: Column) -> Self {
let mut res = Self::zero();
fn from_variable<Column: Into<usize>>(
var: Variable<Column>,
offset_next_row: Option<usize>,
) -> Self {
let Variable { col, row } = var;
if row == CurrOrNext::Next {
assert!(
offset_next_row.is_some(),
"The offset for the next row must be provided"
);
}
let offset = if row == CurrOrNext::Curr {
0
} else {
offset_next_row.unwrap()
};
let var_usize: usize = col.into();

let mut prime_gen = PrimeNumberGenerator::new();
let primes = prime_gen.get_first_nth_primes(N);
let var_usize: usize = var.into();
assert!(primes.contains(&var_usize), "The usize representation of the variable must be a prime number, and unique for each variable");
let inv_var = res

let prime_idx = primes.iter().position(|&x| x == var_usize).unwrap();
let idx = prime_gen.get_nth_prime(prime_idx + offset + 1);

let mut res = Self::zero();
let inv_idx = res
.normalized_indices
.iter()
.position(|&x| x == var_usize)
.position(|&x| x == idx)
.unwrap();
res[inv_var] = F::one();
res[inv_idx] = F::one();
res
}

Expand Down Expand Up @@ -674,28 +695,40 @@ impl<F: PrimeField, const N: usize, const D: usize> Eq for Dense<F, N, D> {}
impl<F: PrimeField, const N: usize, const D: usize> Debug for Dense<F, N, D> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
let mut prime_gen = PrimeNumberGenerator::new();
self.coeff.iter().enumerate().for_each(|(i, c)| {
let primes = prime_gen.get_first_nth_primes(N);
let coeff: Vec<_> = self
.coeff
.iter()
.enumerate()
.filter(|(_i, c)| *c != &F::zero())
.collect();
// Print 0 if the polynomial is zero
if coeff.is_empty() {
write!(f, "0").unwrap();
return Ok(());
}
let l = coeff.len();
coeff.into_iter().for_each(|(i, c)| {
let normalized_idx = self.normalized_indices[i];
if normalized_idx == 1 {
if normalized_idx == 1 && *c != F::one() {
write!(f, "{}", c.to_biguint()).unwrap();
} else {
let prime_decomposition = naive_prime_factors(normalized_idx, &mut prime_gen);
write!(f, "{}", c.to_biguint()).unwrap();
if *c != F::one() {
write!(f, "{}", c.to_biguint()).unwrap();
}
prime_decomposition.iter().for_each(|(p, d)| {
// FIXME: not correct
let inv_p = self
.normalized_indices
.iter()
.position(|&x| x == *p)
.unwrap();
let inv_p = primes.iter().position(|&x| x == *p).unwrap();
if *d > 1 {
write!(f, "x_{}^{}", inv_p, d).unwrap();
} else {
write!(f, "x_{}", inv_p).unwrap();
}
});
}
if i != self.coeff.len() - 1 {
// Avoid printing the last `+` or if the polynomial is a single
// monomial
if i != l - 1 && l != 1 {
write!(f, " + ").unwrap();
}
});
Expand Down
Loading

0 comments on commit 0dab255

Please sign in to comment.