From 8885e12849b23901f4c8335afede5f8ba7d3e893 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Wed, 21 Aug 2024 15:56:09 +0100 Subject: [PATCH] Implement LogupOps for SIMD backend (#641) --- .../src/core/backend/cpu/lookups/gkr.rs | 43 +- .../src/core/backend/simd/lookups/gkr.rs | 462 +++++++++++++++++- crates/prover/src/core/lookups/utils.rs | 75 ++- 3 files changed, 513 insertions(+), 67 deletions(-) diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index 4dbbc2322..ae9ab6b65 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -1,4 +1,4 @@ -use std::ops::{Add, Index}; +use std::ops::Index; use num_traits::{One, Zero}; @@ -11,7 +11,7 @@ use crate::core::lookups::gkr_prover::{ }; use crate::core::lookups::mle::{Mle, MleOps}; use crate::core::lookups::sumcheck::MultivariatePolyOracle; -use crate::core::lookups::utils::{Fraction, UnivariatePoly}; +use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly}; impl GkrOps for CpuBackend { fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle { @@ -173,23 +173,6 @@ fn eval_logup_singles_sum( n_terms: usize, lambda: SecureField, ) -> (SecureField, SecureField) { - /// Represents the fraction `1 / x` - struct Reciprocal { - x: SecureField, - } - - impl Add for Reciprocal { - type Output = Fraction; - - fn add(self, rhs: Self) -> Fraction { - // `1/a + 1/b = (a + b)/(a * b)` - Fraction { - numerator: self.x + rhs.x, - denominator: self.x * rhs.x, - } - } - } - let mut eval_at_0 = SecureField::zero(); let mut eval_at_2 = SecureField::zero(); @@ -211,19 +194,11 @@ fn eval_logup_singles_sum( let Fraction { numerator: numer_at_r0i, denominator: denom_at_r0i, - } = Reciprocal { - x: inp_denom_at_r0i0, - } + Reciprocal { - x: inp_denom_at_r0i1, - }; + } = Reciprocal::new(inp_denom_at_r0i0) + Reciprocal::new(inp_denom_at_r0i1); let Fraction { numerator: numer_at_r2i, denominator: denom_at_r2i, - } = Reciprocal { - x: inp_denom_at_r2i0, - } + Reciprocal { - x: inp_denom_at_r2i1, - }; + } = Reciprocal::new(inp_denom_at_r2i0) + Reciprocal::new(inp_denom_at_r2i1); let eq_eval_at_0i = eq_evals[i]; eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i); @@ -313,7 +288,6 @@ mod tests { use crate::core::channel::Channel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; - use crate::core::fields::FieldExpOps; use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; use crate::core::lookups::mle::Mle; @@ -368,7 +342,7 @@ mod tests { let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); let sum = zip(&numerator_values, &denominator_values) .map(|(&n, &d)| Fraction::new(n, d)) - .sum::>(); + .sum::>(); let numerators = Mle::::new(numerator_values); let denominators = Mle::::new(denominator_values); let top_layer = Layer::LogUpGeneric { @@ -402,15 +376,12 @@ mod tests { #[test] fn logup_with_singles_trace_works() -> Result<(), GkrError> { const N: usize = 1 << 5; - println!("{}", BaseField::from(2).inverse()); - println!("{}", BaseField::from(1) - BaseField::from(2).inverse()); - let mut rng = SmallRng::seed_from_u64(0); let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); let sum = denominator_values .iter() .map(|&d| Fraction::new(SecureField::one(), d)) - .sum::>(); + .sum::>(); let denominators = Mle::::new(denominator_values); let top_layer = Layer::LogUpSingles { denominators: denominators.clone(), @@ -444,7 +415,7 @@ mod tests { let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); let sum = zip(&numerator_values, &denominator_values) .map(|(&n, &d)| Fraction::new(n.into(), d)) - .sum::>(); + .sum::>(); let numerators = Mle::::new(numerator_values); let denominators = Mle::::new(denominator_values); let top_layer = Layer::LogUpMultiplicities { diff --git a/crates/prover/src/core/backend/simd/lookups/gkr.rs b/crates/prover/src/core/backend/simd/lookups/gkr.rs index 15edefa1f..017948dee 100644 --- a/crates/prover/src/core/backend/simd/lookups/gkr.rs +++ b/crates/prover/src/core/backend/simd/lookups/gkr.rs @@ -8,13 +8,14 @@ use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Column, CpuBackend}; +use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::lookups::gkr_prover::{ correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer, }; use crate::core::lookups::mle::Mle; use crate::core::lookups::sumcheck::MultivariatePolyOracle; -use crate::core::lookups::utils::UnivariatePoly; +use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly}; impl GkrOps for SimdBackend { #[allow(clippy::uninit_vec)] @@ -61,14 +62,14 @@ impl GkrOps for SimdBackend { match layer { Layer::GrandProduct(col) => next_grand_product_layer(col), Layer::LogUpGeneric { - numerators: _, - denominators: _, - } => todo!(), + numerators, + denominators, + } => next_logup_generic_layer(numerators, denominators), Layer::LogUpMultiplicities { - numerators: _, - denominators: _, - } => todo!(), - Layer::LogUpSingles { denominators: _ } => todo!(), + numerators, + denominators, + } => next_logup_multiplicities_layer(numerators, denominators), + Layer::LogUpSingles { denominators } => next_logup_singles_layer(denominators), } } @@ -88,18 +89,33 @@ impl GkrOps for SimdBackend { } let n_packed_terms = n_terms / N_LANES; + let packed_lambda = PackedSecureField::broadcast(h.lambda); let (mut eval_at_0, mut eval_at_2) = match &h.input_layer { Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_packed_terms), Layer::LogUpGeneric { - numerators: _, - denominators: _, - } => todo!(), + numerators, + denominators, + } => eval_logup_generic_sum( + eq_evals, + numerators, + denominators, + n_packed_terms, + packed_lambda, + ), Layer::LogUpMultiplicities { - numerators: _, - denominators: _, - } => todo!(), - Layer::LogUpSingles { denominators: _ } => todo!(), + numerators, + denominators, + } => eval_logup_multiplicities_sum( + eq_evals, + numerators, + denominators, + n_packed_terms, + packed_lambda, + ), + Layer::LogUpSingles { denominators } => { + eval_logup_singles_sum(eq_evals, denominators, n_packed_terms, packed_lambda) + } }; eval_at_0 *= h.eq_fixed_var_correction; @@ -130,6 +146,137 @@ fn next_grand_product_layer(layer: &Mle) -> Layer N_LANES`. +fn next_logup_generic_layer( + numerators: &Mle, + denominators: &Mle, +) -> Layer { + assert!(denominators.len() > N_LANES); + assert_eq!(numerators.len(), denominators.len()); + + let next_layer_len = denominators.len() / 2; + let next_layer_packed_len = next_layer_len / N_LANES; + + let mut next_numerators = Vec::with_capacity(next_layer_packed_len); + let mut next_denominators = Vec::with_capacity(next_layer_packed_len); + + for i in 0..next_layer_packed_len { + let (n_even, n_odd) = numerators.data[i * 2].deinterleave(numerators.data[i * 2 + 1]); + let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]); + + let Fraction { + numerator, + denominator, + } = Fraction::new(n_even, d_even) + Fraction::new(n_odd, d_odd); + + next_numerators.push(numerator); + next_denominators.push(denominator); + } + + let next_numerators = SecureColumn { + data: next_numerators, + length: next_layer_len, + }; + + let next_denominators = SecureColumn { + data: next_denominators, + length: next_layer_len, + }; + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), + } +} + +/// Generates the next GKR layer for LogUp. +/// +/// Assumption: `len(denominators) > N_LANES`. +// TODO(andrew): Code duplication of `next_logup_generic_layer`. Consider unifying these. +fn next_logup_multiplicities_layer( + numerators: &Mle, + denominators: &Mle, +) -> Layer { + assert!(denominators.len() > N_LANES); + assert_eq!(numerators.len(), denominators.len()); + + let next_layer_len = denominators.len() / 2; + let next_layer_packed_len = next_layer_len / N_LANES; + + let mut next_numerators = Vec::with_capacity(next_layer_packed_len); + let mut next_denominators = Vec::with_capacity(next_layer_packed_len); + + for i in 0..next_layer_packed_len { + let (n_even, n_odd) = numerators.data[i * 2].deinterleave(numerators.data[i * 2 + 1]); + let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]); + + let Fraction { + numerator, + denominator, + } = Fraction::new(n_even, d_even) + Fraction::new(n_odd, d_odd); + + next_numerators.push(numerator); + next_denominators.push(denominator); + } + + let next_numerators = SecureColumn { + data: next_numerators, + length: next_layer_len, + }; + + let next_denominators = SecureColumn { + data: next_denominators, + length: next_layer_len, + }; + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), + } +} + +/// Generates the next GKR layer for LogUp. +/// +/// Assumption: `len(denominators) > N_LANES`. +fn next_logup_singles_layer(denominators: &Mle) -> Layer { + assert!(denominators.len() > N_LANES); + + let next_layer_len = denominators.len() / 2; + let next_layer_packed_len = next_layer_len / N_LANES; + + let mut next_numerators = Vec::with_capacity(next_layer_packed_len); + let mut next_denominators = Vec::with_capacity(next_layer_packed_len); + + for i in 0..next_layer_packed_len { + let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]); + + let Fraction { + numerator, + denominator, + } = Reciprocal::new(d_even) + Reciprocal::new(d_odd); + + next_numerators.push(numerator); + next_denominators.push(denominator); + } + + let next_numerators = SecureColumn { + data: next_numerators, + length: next_layer_len, + }; + + let next_denominators = SecureColumn { + data: next_denominators, + length: next_layer_len, + }; + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), + } +} + /// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`. /// /// Output of the form: `(eval_at_0, eval_at_2)`. @@ -168,6 +315,172 @@ fn eval_grand_product_sum( ) } +fn eval_logup_generic_sum( + eq_evals: &EqEvals, + numerators: &Mle, + denominators: &Mle, + n_packed_terms: usize, + packed_lambda: PackedSecureField, +) -> (SecureField, SecureField) { + let mut packed_eval_at_0 = PackedSecureField::zero(); + let mut packed_eval_at_2 = PackedSecureField::zero(); + + let inp_numer = &numerators.data; + let inp_denom = &denominators.data; + + for i in 0..n_packed_terms { + // Input polynomials at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` + // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let (inp_numer_at_r0iv0, inp_numer_at_r0iv1) = + inp_numer[i * 2].deinterleave(inp_numer[i * 2 + 1]); + let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) = + inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]); + let (inp_numer_at_r1iv0, inp_numer_at_r1iv1) = inp_numer[(n_packed_terms + i) * 2] + .deinterleave(inp_numer[(n_packed_terms + i) * 2 + 1]); + let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2] + .deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]); + // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` + // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` + let inp_numer_at_r2iv0 = inp_numer_at_r1iv0.double() - inp_numer_at_r0iv0; + let inp_numer_at_r2iv1 = inp_numer_at_r1iv1.double() - inp_numer_at_r0iv1; + let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0; + let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1; + + // Fraction addition polynomials: + // - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)` + // - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`. + // at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let Fraction { + numerator: numer_at_r0iv, + denominator: denom_at_r0iv, + } = Fraction::new(inp_numer_at_r0iv0, inp_denom_at_r0iv0) + + Fraction::new(inp_numer_at_r0iv1, inp_denom_at_r0iv1); + let Fraction { + numerator: numer_at_r2iv, + denominator: denom_at_r2iv, + } = Fraction::new(inp_numer_at_r2iv0, inp_denom_at_r2iv0) + + Fraction::new(inp_numer_at_r2iv1, inp_denom_at_r2iv1); + + let eq_eval_at_0iv = eq_evals.data[i]; + packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv); + packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv); + } + + ( + packed_eval_at_0.pointwise_sum(), + packed_eval_at_2.pointwise_sum(), + ) +} + +// TODO(andrew): Code duplication of `eval_logup_generic_sum`. Consider unifying these. +fn eval_logup_multiplicities_sum( + eq_evals: &EqEvals, + numerators: &Mle, + denominators: &Mle, + n_packed_terms: usize, + packed_lambda: PackedSecureField, +) -> (SecureField, SecureField) { + let mut packed_eval_at_0 = PackedSecureField::zero(); + let mut packed_eval_at_2 = PackedSecureField::zero(); + + let inp_numer = &numerators.data; + let inp_denom = &denominators.data; + + for i in 0..n_packed_terms { + // Input polynomials at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` + // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let (inp_numer_at_r0iv0, inp_numer_at_r0iv1) = + inp_numer[i * 2].deinterleave(inp_numer[i * 2 + 1]); + let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) = + inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]); + let (inp_numer_at_r1iv0, inp_numer_at_r1iv1) = inp_numer[(n_packed_terms + i) * 2] + .deinterleave(inp_numer[(n_packed_terms + i) * 2 + 1]); + let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2] + .deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]); + // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` + // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` + let inp_numer_at_r2iv0 = inp_numer_at_r1iv0.double() - inp_numer_at_r0iv0; + let inp_numer_at_r2iv1 = inp_numer_at_r1iv1.double() - inp_numer_at_r0iv1; + let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0; + let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1; + + // Fraction addition polynomials: + // - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)` + // - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`. + // at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let Fraction { + numerator: numer_at_r0iv, + denominator: denom_at_r0iv, + } = Fraction::new(inp_numer_at_r0iv0, inp_denom_at_r0iv0) + + Fraction::new(inp_numer_at_r0iv1, inp_denom_at_r0iv1); + let Fraction { + numerator: numer_at_r2iv, + denominator: denom_at_r2iv, + } = Fraction::new(inp_numer_at_r2iv0, inp_denom_at_r2iv0) + + Fraction::new(inp_numer_at_r2iv1, inp_denom_at_r2iv1); + + let eq_eval_at_0iv = eq_evals.data[i]; + packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv); + packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv); + } + + ( + packed_eval_at_0.pointwise_sum(), + packed_eval_at_2.pointwise_sum(), + ) +} + +/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_denom(r, t, x, 1) + inp_denom(r, t, x, 0) + +/// lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, x, 1))` at `t=0` and `t=2`. +/// +/// Output of the form: `(eval_at_0, eval_at_2)`. +fn eval_logup_singles_sum( + eq_evals: &EqEvals, + denominators: &Mle, + n_packed_terms: usize, + packed_lambda: PackedSecureField, +) -> (SecureField, SecureField) { + let mut packed_eval_at_0 = PackedSecureField::zero(); + let mut packed_eval_at_2 = PackedSecureField::zero(); + + let inp_denom = &denominators.data; + + for i in 0..n_packed_terms { + // Input polynomial at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` + // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) = + inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]); + let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2] + .deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]); + // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` + // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` + let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0; + let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1; + + // Fraction addition polynomials: + // - `numer(x) = inp_denom(x, 1) + inp_denom(x, 0)` + // - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`. + // at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let Fraction { + numerator: numer_at_r0iv, + denominator: denom_at_r0iv, + } = Reciprocal::new(inp_denom_at_r0iv0) + Reciprocal::new(inp_denom_at_r0iv1); + let Fraction { + numerator: numer_at_r2iv, + denominator: denom_at_r2iv, + } = Reciprocal::new(inp_denom_at_r2iv0) + Reciprocal::new(inp_denom_at_r2iv1); + + let eq_eval_at_0iv = eq_evals.data[i]; + packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv); + packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv); + } + + ( + packed_eval_at_0.pointwise_sum(), + packed_eval_at_2.pointwise_sum(), + ) +} + fn into_simd_layer(cpu_layer: Layer) -> Layer { match cpu_layer { Layer::GrandProduct(mle) => { @@ -195,6 +508,12 @@ fn into_simd_layer(cpu_layer: Layer) -> Layer { #[cfg(test)] mod tests { + use std::iter::zip; + + use num_traits::One; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Column, CpuBackend}; use crate::core::channel::Channel; @@ -203,6 +522,7 @@ mod tests { use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; use crate::core::lookups::mle::Mle; + use crate::core::lookups::utils::Fraction; use crate::core::test_utils::test_channel; #[test] @@ -249,4 +569,116 @@ mod tests { ); Ok(()) } + + #[test] + fn logup_with_generic_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let mut rng = SmallRng::seed_from_u64(0); + let numerators = (0..N).map(|_| rng.gen()).collect::>(); + let denominators = (0..N).map(|_| rng.gen()).collect::>(); + let sum = zip(&numerators, &denominators) + .map(|(&n, &d)| Fraction::new(n, d)) + .sum::>(); + let numerators = Mle::::new(numerators.into_iter().collect()); + let denominators = Mle::::new(denominators.into_iter().collect()); + let input_layer = Layer::LogUpGeneric { + numerators: numerators.clone(), + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [ + numerators.eval_at_point(&ood_point), + denominators.eval_at_point(&ood_point) + ] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } + + #[test] + fn logup_with_multiplicities_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let mut rng = SmallRng::seed_from_u64(0); + let numerators = (0..N).map(|_| rng.gen()).collect::>(); + let denominators = (0..N).map(|_| rng.gen()).collect::>(); + let sum = zip(&numerators, &denominators) + .map(|(&n, &d)| Fraction::new(n.into(), d)) + .sum::>(); + let numerators = Mle::::new(numerators.into_iter().collect()); + let denominators = Mle::::new(denominators.into_iter().collect()); + let input_layer = Layer::LogUpMultiplicities { + numerators: numerators.clone(), + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [ + numerators.eval_at_point(&ood_point), + denominators.eval_at_point(&ood_point) + ] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } + + #[test] + fn logup_with_singles_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let mut rng = SmallRng::seed_from_u64(0); + let denominators = (0..N).map(|_| rng.gen()).collect::>(); + let sum = denominators + .iter() + .map(|&d| Fraction::new(SecureField::one(), d)) + .sum::>(); + let denominators = Mle::::new(denominators.into_iter().collect()); + let input_layer = Layer::LogUpSingles { + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [SecureField::one(), denominators.eval_at_point(&ood_point)] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } } diff --git a/crates/prover/src/core/lookups/utils.rs b/crates/prover/src/core/lookups/utils.rs index 70adb4c64..85ea4c32a 100644 --- a/crates/prover/src/core/lookups/utils.rs +++ b/crates/prover/src/core/lookups/utils.rs @@ -196,13 +196,13 @@ where /// Projective fraction. #[derive(Debug, Clone, Copy)] -pub struct Fraction { - pub numerator: F, - pub denominator: SecureField, +pub struct Fraction { + pub numerator: N, + pub denominator: D, } -impl Fraction { - pub fn new(numerator: F, denominator: SecureField) -> Self { +impl Fraction { + pub fn new(numerator: N, denominator: D) -> Self { Self { numerator, denominator, @@ -210,14 +210,12 @@ impl Fraction { } } -impl Add for Fraction -where - F: Field, - SecureField: ExtensionOf + Field, +impl + Add + Mul + Mul + Copy> Add + for Fraction { - type Output = Fraction; + type Output = Fraction; - fn add(self, rhs: Self) -> Fraction { + fn add(self, rhs: Self) -> Fraction { Fraction { numerator: rhs.denominator * self.numerator + self.denominator * rhs.numerator, denominator: self.denominator * rhs.denominator, @@ -225,11 +223,14 @@ where } } -impl Zero for Fraction { +impl Zero for Fraction +where + Self: Add, +{ fn zero() -> Self { Self { - numerator: SecureField::zero(), - denominator: SecureField::one(), + numerator: N::zero(), + denominator: D::one(), } } @@ -238,13 +239,39 @@ impl Zero for Fraction { } } -impl Sum for Fraction { +impl Sum for Fraction +where + Self: Zero, +{ fn sum>(mut iter: I) -> Self { let first = iter.next().unwrap_or_else(Self::zero); iter.fold(first, |a, b| a + b) } } +/// Represents the fraction `1 / x` +pub struct Reciprocal { + x: T, +} + +impl Reciprocal { + pub fn new(x: T) -> Self { + Self { x } + } +} + +impl + Mul + Copy> Add for Reciprocal { + type Output = Fraction; + + fn add(self, rhs: Self) -> Fraction { + // `1/a + 1/b = (a + b)/(a * b)` + Fraction { + numerator: self.x + rhs.x, + denominator: self.x * rhs.x, + } + } +} + #[cfg(test)] mod tests { use std::iter::zip; @@ -255,7 +282,7 @@ mod tests { use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; - use crate::core::lookups::utils::eq; + use crate::core::lookups::utils::{eq, Fraction}; #[test] fn lagrange_interpolation_works() { @@ -310,4 +337,20 @@ mod tests { eq(&[zero, one], &[zero]); } + + #[test] + fn fraction_addition_works() { + let a = Fraction::new(BaseField::from(1), BaseField::from(3)); + let b = Fraction::new(BaseField::from(2), BaseField::from(6)); + + let Fraction { + numerator, + denominator, + } = a + b; + + assert_eq!( + numerator / denominator, + BaseField::from(2) / BaseField::from(3) + ); + } }