Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GKR implementation of Grand Product lookups #620

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 96 additions & 8 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,84 @@
use num_traits::Zero;

use crate::core::backend::CpuBackend;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer};
use crate::core::fields::Field;
use crate::core::lookups::gkr_prover::{
correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer,
};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::UnivariatePoly;

impl GkrOps for CpuBackend {
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
Mle::new(gen_eq_evals(y, v))
}

fn next_layer(_layer: &Layer<Self>) -> Layer<Self> {
todo!()
fn next_layer(layer: &Layer<Self>) -> Layer<Self> {
match layer {
Layer::GrandProduct(layer) => next_grand_product_layer(layer),
Layer::_LogUp(_) => todo!(),
}
}

fn sum_as_poly_in_first_variable(
_h: &GkrMultivariatePolyOracle<'_, Self>,
_claim: SecureField,
) -> crate::core::lookups::utils::UnivariatePoly<SecureField> {
todo!()
h: &GkrMultivariatePolyOracle<'_, Self>,
claim: SecureField,
) -> UnivariatePoly<SecureField> {
let n_variables = h.n_variables();
assert!(!n_variables.is_zero());
let n_terms = 1 << (n_variables - 1);
let eq_evals = h.eq_evals;
// Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube.
let y = eq_evals.y();
let input_layer = &h.input_layer;

let (mut eval_at_0, mut eval_at_2) = match input_layer {
Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_terms),
Layer::_LogUp(_) => todo!(),
};

eval_at_0 *= h.eq_fixed_var_correction;
eval_at_2 *= h.eq_fixed_var_correction;
correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables)
}
}

/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`.
///
/// Output of the form: `(eval_at_0, eval_at_2)`.
fn eval_grand_product_sum(
eq_evals: &EqEvals<CpuBackend>,
input_layer: &Mle<CpuBackend, SecureField>,
n_terms: usize,
) -> (SecureField, SecureField) {
let mut eval_at_0 = SecureField::zero();
let mut eval_at_2 = SecureField::zero();

for i in 0..n_terms {
// Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`.
let inp_at_r0i0 = input_layer[i * 2];
let inp_at_r0i1 = input_layer[i * 2 + 1];
let inp_at_r1i0 = input_layer[(n_terms + i) * 2];
let inp_at_r1i1 = input_layer[(n_terms + i) * 2 + 1];
// Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)`
// => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)`
let inp_at_r2i0 = inp_at_r1i0.double() - inp_at_r0i0;
let inp_at_r2i1 = inp_at_r1i1.double() - inp_at_r0i1;

// Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i))`.
let prod_at_r2i = inp_at_r2i0 * inp_at_r2i1;
let prod_at_r0i = inp_at_r0i0 * inp_at_r0i1;

let eq_eval_at_0i = eq_evals[i];
eval_at_0 += eq_eval_at_0i * prod_at_r0i;
eval_at_2 += eq_eval_at_0i * prod_at_r2i;
}

(eval_at_0, eval_at_2)
}

/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`.
///
/// Evaluations are returned in bit-reversed order.
Expand All @@ -40,15 +99,24 @@ fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec<SecureField> {
evals
}

fn next_grand_product_layer(layer: &Mle<CpuBackend, SecureField>) -> Layer<CpuBackend> {
let res = layer.array_chunks().map(|&[a, b]| a * b).collect();
Layer::GrandProduct(Mle::new(res))
}

#[cfg(test)]
mod tests {
use num_traits::{One, Zero};

use crate::core::backend::CpuBackend;
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::utils::eq;
use crate::core::test_utils::test_channel;

#[test]
fn gen_eq_evals() {
Expand All @@ -69,4 +137,24 @@ mod tests {
]
);
}

#[test]
fn grand_product_works() -> Result<(), GkrError> {
const N: usize = 1 << 5;
let values = test_channel().draw_felts(N);
let product = values.iter().product::<SecureField>();
let col = Mle::<CpuBackend, SecureField>::new(values);
let input_layer = Layer::GrandProduct(col.clone());
let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);

let GkrArtifact {
ood_point: r,
claims_to_verify_by_instance,
..
} = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

assert_eq!(proof.output_claims_by_instance, [vec![product]]);
assert_eq!(claims_to_verify_by_instance, [vec![col.eval_at_point(&r)]]);
Ok(())
}
}
139 changes: 126 additions & 13 deletions crates/prover/src/core/lookups/gkr_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ use super::sumcheck::MultivariatePolyOracle;
use super::utils::{eq, random_linear_combination, UnivariatePoly};
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{Field, FieldExpOps};
use crate::core::lookups::sumcheck;

pub trait GkrOps: MleOps<SecureField> {
pub trait GkrOps: MleOps<BaseField> + MleOps<SecureField> {
/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`.
///
/// Note [`Mle`] stores values in bit-reversed order.
Expand Down Expand Up @@ -85,13 +87,16 @@ impl<B: ColumnOps<SecureField>> Deref for EqEvals<B> {
/// [LogUp]: https://eprint.iacr.org/2023/1284.pdf
pub enum Layer<B: GkrOps> {
_LogUp(B),
_GrandProduct(B),
GrandProduct(Mle<B, SecureField>),
}

impl<B: GkrOps> Layer<B> {
/// Returns the number of variables used to interpolate the layer's gate values.
fn n_variables(&self) -> usize {
todo!()
match self {
Self::_LogUp(_) => todo!(),
Self::GrandProduct(mle) => mle.n_variables(),
}
}

/// Produces the next layer from the current layer.
Expand All @@ -112,7 +117,28 @@ impl<B: GkrOps> Layer<B> {

/// Returns each column output if the layer is an output layer, otherwise returns an `Err`.
fn try_into_output_layer_values(self) -> Result<Vec<SecureField>, NotOutputLayerError> {
todo!()
if !self.is_output_layer() {
return Err(NotOutputLayerError);
}

Ok(match self {
Self::GrandProduct(col) => {
vec![col.at(0)]
}
Self::_LogUp(_) => todo!(),
})
}

/// Returns a transformed layer with the first variable of each column fixed to `assignment`.
fn fix_first_variable(self, x0: SecureField) -> Self {
if self.n_variables() == 0 {
return self;
}

match self {
Self::_LogUp(_) => todo!(),
Self::GrandProduct(mle) => Self::GrandProduct(mle.fix_first_variable(x0)),
}
}

/// Represents the next GKR layer evaluation as a multivariate polynomial which uses this GKR
Expand Down Expand Up @@ -145,16 +171,37 @@ impl<B: GkrOps> Layer<B> {
fn into_multivariate_poly(
self,
_lambda: SecureField,
_eq_evals: &EqEvals<B>,
eq_evals: &EqEvals<B>,
) -> GkrMultivariatePolyOracle<'_, B> {
todo!()
GkrMultivariatePolyOracle {
eq_evals,
input_layer: self,
eq_fixed_var_correction: SecureField::one(),
}
}
}

#[derive(Debug)]
struct NotOutputLayerError;

/// A multivariate polynomial that expresses the relation between two consecutive GKR layers.
/// Multivariate polynomial `P` that expresses the relation between two consecutive GKR layers.
///
/// When the input layer is [`Layer::GrandProduct`] (represented by multilinear column `inp`)
/// the polynomial represents:
///
/// ```text
/// P(x) = eq(x, y) * inp(x, 0) * inp(x, 1)
/// ```
///
/// When the input layer is LogUp (represented by multilinear columns `inp_numer` and
/// `inp_denom`) the polynomial represents:
///
/// ```text
/// numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)
/// denom(x) = inp_denom(x, 0) * inp_denom(x, 1)
///
/// P(x) = eq(x, y) * (numer(x) + lambda * denom(x))
/// ```
pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> {
/// `eq_evals` passed by `Layer::into_multivariate_poly()`.
pub eq_evals: &'a EqEvals<B>,
Expand All @@ -164,15 +211,26 @@ pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> {

impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> {
fn n_variables(&self) -> usize {
todo!()
self.input_layer.n_variables() - 1
}

fn sum_as_poly_in_first_variable(&self, _claim: SecureField) -> UnivariatePoly<SecureField> {
todo!()
fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly<SecureField> {
B::sum_as_poly_in_first_variable(self, claim)
}

fn fix_first_variable(self, _challenge: SecureField) -> Self {
todo!()
fn fix_first_variable(self, challenge: SecureField) -> Self {
if self.n_variables() == 0 {
return self;
}

let z0 = self.eq_evals.y()[self.eq_evals.y().len() - self.n_variables()];
let eq_fixed_var_correction = self.eq_fixed_var_correction * eq(&[challenge], &[z0]);

Self {
eq_evals: self.eq_evals,
eq_fixed_var_correction,
input_layer: self.input_layer.fix_first_variable(challenge),
}
}
}

Expand All @@ -188,7 +246,14 @@ impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> {
///
/// For more context see <https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf> page 64.
fn try_into_mask(self) -> Result<GkrMask, NotConstantPolyError> {
todo!()
if self.n_variables() != 0 {
return Err(NotConstantPolyError);
}

match self.input_layer {
Layer::_LogUp(_) => todo!(),
Layer::GrandProduct(mle) => Ok(GkrMask::new(vec![mle.to_cpu().try_into().unwrap()])),
}
}
}

Expand Down Expand Up @@ -319,3 +384,51 @@ fn gen_layers<B: GkrOps>(input_layer: Layer<B>) -> Vec<Layer<B>> {
assert_eq!(layers.len(), n_variables + 1);
layers
}

/// Computes `r(t) = sum_x eq((t, x), y[-k:]) * p(t, x)` from evaluations of
/// `f(t) = sum_x eq(({0}^(n - k), 0, x), y) * p(t, x)`.
///
/// Note `claim` must equal `r(0) + r(1)` and `r` must have degree <= 3.
///
/// For more context see `Layer::into_multivariate_poly()` docs.
/// See also <https://ia.cr/2024/108> (section 3.2).
pub fn correct_sum_as_poly_in_first_variable(
f_at_0: SecureField,
f_at_2: SecureField,
claim: SecureField,
y: &[SecureField],
k: usize,
) -> UnivariatePoly<SecureField> {
assert_ne!(k, 0);
let n = y.len();
assert!(k <= n);

// We evaluated `f(0)` and `f(2)` - the inputs.
// We want to compute `r(t) = f(t) * eq(t, y[n - k]) / eq(0, y[:n - k + 1])`.
let a_const = eq(&vec![SecureField::zero(); n - k + 1], &y[..n - k + 1]).inverse();

// Find the additional root of `r(t)`, by finding the root of `eq(t, y[n - k])`:
// 0 = eq(t, y[n - k])
// = t * y[n - k] + (1 - t)(1 - y[n - k])
// = 1 - y[n - k] - t(1 - 2 * y[n - k])
// => t = (1 - y[n - k]) / (1 - 2 * y[n - k])
// = b
let b_const = (SecureField::one() - y[n - k]) / (SecureField::one() - y[n - k].double());

// We get that `r(t) = f(t) * eq(t, y[n - k]) * a`.
let r_at_0 = f_at_0 * eq(&[SecureField::zero()], &[y[n - k]]) * a_const;
let r_at_1 = claim - r_at_0;
let r_at_2 = f_at_2 * eq(&[BaseField::from(2).into()], &[y[n - k]]) * a_const;
let r_at_b = SecureField::zero();

// Interpolate.
UnivariatePoly::interpolate_lagrange(
&[
SecureField::zero(),
SecureField::one(),
SecureField::from(BaseField::from(2)),
b_const,
],
&[r_at_0, r_at_1, r_at_2, r_at_b],
)
}
Loading
Loading