Skip to content

Commit

Permalink
Implement LogupOps for SIMD backend (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson authored Aug 21, 2024
1 parent a027a44 commit 8885e12
Show file tree
Hide file tree
Showing 3 changed files with 513 additions and 67 deletions.
43 changes: 7 additions & 36 deletions crates/prover/src/core/backend/cpu/lookups/gkr.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::{Add, Index};
use std::ops::Index;

use num_traits::{One, Zero};

Expand All @@ -11,7 +11,7 @@ use crate::core::lookups::gkr_prover::{
};
use crate::core::lookups::mle::{Mle, MleOps};
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::{Fraction, UnivariatePoly};
use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly};

impl GkrOps for CpuBackend {
fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
Expand Down Expand Up @@ -173,23 +173,6 @@ fn eval_logup_singles_sum(
n_terms: usize,
lambda: SecureField,
) -> (SecureField, SecureField) {
/// Represents the fraction `1 / x`
struct Reciprocal {
x: SecureField,
}

impl Add for Reciprocal {
type Output = Fraction<SecureField>;

fn add(self, rhs: Self) -> Fraction<SecureField> {
// `1/a + 1/b = (a + b)/(a * b)`
Fraction {
numerator: self.x + rhs.x,
denominator: self.x * rhs.x,
}
}
}

let mut eval_at_0 = SecureField::zero();
let mut eval_at_2 = SecureField::zero();

Expand All @@ -211,19 +194,11 @@ fn eval_logup_singles_sum(
let Fraction {
numerator: numer_at_r0i,
denominator: denom_at_r0i,
} = Reciprocal {
x: inp_denom_at_r0i0,
} + Reciprocal {
x: inp_denom_at_r0i1,
};
} = Reciprocal::new(inp_denom_at_r0i0) + Reciprocal::new(inp_denom_at_r0i1);
let Fraction {
numerator: numer_at_r2i,
denominator: denom_at_r2i,
} = Reciprocal {
x: inp_denom_at_r2i0,
} + Reciprocal {
x: inp_denom_at_r2i1,
};
} = Reciprocal::new(inp_denom_at_r2i0) + Reciprocal::new(inp_denom_at_r2i1);

let eq_eval_at_0i = eq_evals[i];
eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i);
Expand Down Expand Up @@ -313,7 +288,6 @@ mod tests {
use crate::core::channel::Channel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError};
use crate::core::lookups::mle::Mle;
Expand Down Expand Up @@ -368,7 +342,7 @@ mod tests {
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = zip(&numerator_values, &denominator_values)
.map(|(&n, &d)| Fraction::new(n, d))
.sum::<Fraction<SecureField>>();
.sum::<Fraction<SecureField, SecureField>>();
let numerators = Mle::<CpuBackend, SecureField>::new(numerator_values);
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpGeneric {
Expand Down Expand Up @@ -402,15 +376,12 @@ mod tests {
#[test]
fn logup_with_singles_trace_works() -> Result<(), GkrError> {
const N: usize = 1 << 5;
println!("{}", BaseField::from(2).inverse());
println!("{}", BaseField::from(1) - BaseField::from(2).inverse());

let mut rng = SmallRng::seed_from_u64(0);
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = denominator_values
.iter()
.map(|&d| Fraction::new(SecureField::one(), d))
.sum::<Fraction<SecureField>>();
.sum::<Fraction<SecureField, SecureField>>();
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpSingles {
denominators: denominators.clone(),
Expand Down Expand Up @@ -444,7 +415,7 @@ mod tests {
let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
let sum = zip(&numerator_values, &denominator_values)
.map(|(&n, &d)| Fraction::new(n.into(), d))
.sum::<Fraction<SecureField>>();
.sum::<Fraction<SecureField, SecureField>>();
let numerators = Mle::<CpuBackend, BaseField>::new(numerator_values);
let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
let top_layer = Layer::LogUpMultiplicities {
Expand Down
Loading

0 comments on commit 8885e12

Please sign in to comment.