Skip to content

Commit

Permalink
Add GKR implementation of Grand Product lookups (#620)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson authored Jul 15, 2024
1 parent 7a0bdde commit 437ea59
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 27 deletions.
104 changes: 96 additions & 8 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,84 @@
use num_traits::Zero;

use crate::core::backend::CpuBackend;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer};
use crate::core::fields::Field;
use crate::core::lookups::gkr_prover::{
correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer,
};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::UnivariatePoly;

impl GkrOps for CpuBackend {
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
Mle::new(gen_eq_evals(y, v))
}

fn next_layer(_layer: &Layer<Self>) -> Layer<Self> {
todo!()
fn next_layer(layer: &Layer<Self>) -> Layer<Self> {
match layer {
Layer::GrandProduct(layer) => next_grand_product_layer(layer),
Layer::_LogUp(_) => todo!(),
}
}

fn sum_as_poly_in_first_variable(
_h: &GkrMultivariatePolyOracle<'_, Self>,
_claim: SecureField,
) -> crate::core::lookups::utils::UnivariatePoly<SecureField> {
todo!()
h: &GkrMultivariatePolyOracle<'_, Self>,
claim: SecureField,
) -> UnivariatePoly<SecureField> {
let n_variables = h.n_variables();
assert!(!n_variables.is_zero());
let n_terms = 1 << (n_variables - 1);
let eq_evals = h.eq_evals;
// Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube.
let y = eq_evals.y();
let input_layer = &h.input_layer;

let (mut eval_at_0, mut eval_at_2) = match input_layer {
Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_terms),
Layer::_LogUp(_) => todo!(),
};

eval_at_0 *= h.eq_fixed_var_correction;
eval_at_2 *= h.eq_fixed_var_correction;
correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables)
}
}

/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`.
///
/// Output of the form: `(eval_at_0, eval_at_2)`.
fn eval_grand_product_sum(
eq_evals: &EqEvals<CpuBackend>,
input_layer: &Mle<CpuBackend, SecureField>,
n_terms: usize,
) -> (SecureField, SecureField) {
let mut eval_at_0 = SecureField::zero();
let mut eval_at_2 = SecureField::zero();

for i in 0..n_terms {
// Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`.
let inp_at_r0i0 = input_layer[i * 2];
let inp_at_r0i1 = input_layer[i * 2 + 1];
let inp_at_r1i0 = input_layer[(n_terms + i) * 2];
let inp_at_r1i1 = input_layer[(n_terms + i) * 2 + 1];
// Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)`
// => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)`
let inp_at_r2i0 = inp_at_r1i0.double() - inp_at_r0i0;
let inp_at_r2i1 = inp_at_r1i1.double() - inp_at_r0i1;

// Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i))`.
let prod_at_r2i = inp_at_r2i0 * inp_at_r2i1;
let prod_at_r0i = inp_at_r0i0 * inp_at_r0i1;

let eq_eval_at_0i = eq_evals[i];
eval_at_0 += eq_eval_at_0i * prod_at_r0i;
eval_at_2 += eq_eval_at_0i * prod_at_r2i;
}

(eval_at_0, eval_at_2)
}

/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`.
///
/// Evaluations are returned in bit-reversed order.
Expand All @@ -40,15 +99,24 @@ fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec<SecureField> {
evals
}

fn next_grand_product_layer(layer: &Mle<CpuBackend, SecureField>) -> Layer<CpuBackend> {
let res = layer.array_chunks().map(|&[a, b]| a * b).collect();
Layer::GrandProduct(Mle::new(res))
}

#[cfg(test)]
mod tests {
use num_traits::{One, Zero};

use crate::core::backend::CpuBackend;
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::utils::eq;
use crate::core::test_utils::test_channel;

#[test]
fn gen_eq_evals() {
Expand All @@ -69,4 +137,24 @@ mod tests {
]
);
}

#[test]
fn grand_product_works() -> Result<(), GkrError> {
const N: usize = 1 << 5;
let values = test_channel().draw_felts(N);
let product = values.iter().product::<SecureField>();
let col = Mle::<CpuBackend, SecureField>::new(values);
let input_layer = Layer::GrandProduct(col.clone());
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);

let GkrArtifact {
ood_point: r,
claims_to_verify_by_instance,
..
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

assert_eq!(proof.output_claims_by_instance, [vec![product]]);
assert_eq!(claims_to_verify_by_instance, [vec![col.eval_at_point(&r)]]);
Ok(())
}
}
139 changes: 126 additions & 13 deletions crates/prover/src/core/lookups/gkr_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ use super::sumcheck::MultivariatePolyOracle;
use super::utils::{eq, random_linear_combination, UnivariatePoly};
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{Field, FieldExpOps};
use crate::core::lookups::sumcheck;

pub trait GkrOps: MleOps<SecureField> {
pub trait GkrOps: MleOps<BaseField> + MleOps<SecureField> {
/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`.
///
/// Note [`Mle`] stores values in bit-reversed order.
Expand Down Expand Up @@ -85,13 +87,16 @@ impl<B: ColumnOps<SecureField>> Deref for EqEvals<B> {
/// [LogUp]: https://eprint.iacr.org/2023/1284.pdf
pub enum Layer<B: GkrOps> {
_LogUp(B),
_GrandProduct(B),
GrandProduct(Mle<B, SecureField>),
}

impl<B: GkrOps> Layer<B> {
/// Returns the number of variables used to interpolate the layer's gate values.
fn n_variables(&self) -> usize {
todo!()
match self {
Self::_LogUp(_) => todo!(),
Self::GrandProduct(mle) => mle.n_variables(),
}
}

/// Produces the next layer from the current layer.
Expand All @@ -112,7 +117,28 @@ impl<B: GkrOps> Layer<B> {

/// Returns each column output if the layer is an output layer, otherwise returns an `Err`.
fn try_into_output_layer_values(self) -> Result<Vec<SecureField>, NotOutputLayerError> {
todo!()
if !self.is_output_layer() {
return Err(NotOutputLayerError);
}

Ok(match self {
Self::GrandProduct(col) => {
vec![col.at(0)]
}
Self::_LogUp(_) => todo!(),
})
}

/// Returns a transformed layer with the first variable of each column fixed to `assignment`.
fn fix_first_variable(self, x0: SecureField) -> Self {
if self.n_variables() == 0 {
return self;
}

match self {
Self::_LogUp(_) => todo!(),
Self::GrandProduct(mle) => Self::GrandProduct(mle.fix_first_variable(x0)),
}
}

/// Represents the next GKR layer evaluation as a multivariate polynomial which uses this GKR
Expand Down Expand Up @@ -145,16 +171,37 @@ impl<B: GkrOps> Layer<B> {
fn into_multivariate_poly(
self,
_lambda: SecureField,
_eq_evals: &EqEvals<B>,
eq_evals: &EqEvals<B>,
) -> GkrMultivariatePolyOracle<'_, B> {
todo!()
GkrMultivariatePolyOracle {
eq_evals,
input_layer: self,
eq_fixed_var_correction: SecureField::one(),
}
}
}

#[derive(Debug)]
struct NotOutputLayerError;

/// A multivariate polynomial that expresses the relation between two consecutive GKR layers.
/// Multivariate polynomial `P` that expresses the relation between two consecutive GKR layers.
///
/// When the input layer is [`Layer::GrandProduct`] (represented by multilinear column `inp`)
/// the polynomial represents:
///
/// ```text
/// P(x) = eq(x, y) * inp(x, 0) * inp(x, 1)
/// ```
///
/// When the input layer is LogUp (represented by multilinear columns `inp_numer` and
/// `inp_denom`) the polynomial represents:
///
/// ```text
/// numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)
/// denom(x) = inp_denom(x, 0) * inp_denom(x, 1)
///
/// P(x) = eq(x, y) * (numer(x) + lambda * denom(x))
/// ```
pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> {
/// `eq_evals` passed by `Layer::into_multivariate_poly()`.
pub eq_evals: &'a EqEvals<B>,
Expand All @@ -164,15 +211,26 @@ pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> {

impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> {
fn n_variables(&self) -> usize {
todo!()
self.input_layer.n_variables() - 1
}

fn sum_as_poly_in_first_variable(&self, _claim: SecureField) -> UnivariatePoly<SecureField> {
todo!()
fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly<SecureField> {
B::sum_as_poly_in_first_variable(self, claim)
}

fn fix_first_variable(self, _challenge: SecureField) -> Self {
todo!()
fn fix_first_variable(self, challenge: SecureField) -> Self {
if self.n_variables() == 0 {
return self;
}

let z0 = self.eq_evals.y()[self.eq_evals.y().len() - self.n_variables()];
let eq_fixed_var_correction = self.eq_fixed_var_correction * eq(&[challenge], &[z0]);

Self {
eq_evals: self.eq_evals,
eq_fixed_var_correction,
input_layer: self.input_layer.fix_first_variable(challenge),
}
}
}

Expand All @@ -188,7 +246,14 @@ impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> {
///
/// For more context see <https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf> page 64.
fn try_into_mask(self) -> Result<GkrMask, NotConstantPolyError> {
todo!()
if self.n_variables() != 0 {
return Err(NotConstantPolyError);
}

match self.input_layer {
Layer::_LogUp(_) => todo!(),
Layer::GrandProduct(mle) => Ok(GkrMask::new(vec![mle.to_cpu().try_into().unwrap()])),
}
}
}

Expand Down Expand Up @@ -319,3 +384,51 @@ fn gen_layers<B: GkrOps>(input_layer: Layer<B>) -> Vec<Layer<B>> {
assert_eq!(layers.len(), n_variables + 1);
layers
}

/// Computes `r(t) = sum_x eq((t, x), y[-k:]) * p(t, x)` from evaluations of
/// `f(t) = sum_x eq(({0}^(n - k), 0, x), y) * p(t, x)`.
///
/// Note `claim` must equal `r(0) + r(1)` and `r` must have degree <= 3.
///
/// For more context see `Layer::into_multivariate_poly()` docs.
/// See also <https://ia.cr/2024/108> (section 3.2).
pub fn correct_sum_as_poly_in_first_variable(
f_at_0: SecureField,
f_at_2: SecureField,
claim: SecureField,
y: &[SecureField],
k: usize,
) -> UnivariatePoly<SecureField> {
assert_ne!(k, 0);
let n = y.len();
assert!(k <= n);

// We evaluated `f(0)` and `f(2)` - the inputs.
// We want to compute `r(t) = f(t) * eq(t, y[n - k]) / eq(0, y[:n - k + 1])`.
let a_const = eq(&vec![SecureField::zero(); n - k + 1], &y[..n - k + 1]).inverse();

// Find the additional root of `r(t)`, by finding the root of `eq(t, y[n - k])`:
// 0 = eq(t, y[n - k])
// = t * y[n - k] + (1 - t)(1 - y[n - k])
// = 1 - y[n - k] - t(1 - 2 * y[n - k])
// => t = (1 - y[n - k]) / (1 - 2 * y[n - k])
// = b
let b_const = (SecureField::one() - y[n - k]) / (SecureField::one() - y[n - k].double());

// We get that `r(t) = f(t) * eq(t, y[n - k]) * a`.
let r_at_0 = f_at_0 * eq(&[SecureField::zero()], &[y[n - k]]) * a_const;
let r_at_1 = claim - r_at_0;
let r_at_2 = f_at_2 * eq(&[BaseField::from(2).into()], &[y[n - k]]) * a_const;
let r_at_b = SecureField::zero();

// Interpolate.
UnivariatePoly::interpolate_lagrange(
&[
SecureField::zero(),
SecureField::one(),
SecureField::from(BaseField::from(2)),
b_const,
],
&[r_at_0, r_at_1, r_at_2, r_at_b],
)
}
Loading

0 comments on commit 437ea59

Please sign in to comment.