From 1f9187e60555739ebf26877f8fe0e406c557c368 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Sun, 8 Sep 2024 21:04:08 -1000 Subject: [PATCH] Add build_trace functions for MLE eval component (#803) --- .../src/examples/xor/gkr_lookups/mle_eval.rs | 286 ++++++++++-------- 1 file changed, 155 insertions(+), 131 deletions(-) diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 5a5d60535..72dc133a5 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -3,20 +3,27 @@ #![allow(dead_code)] use std::array; +use std::iter::zip; -use itertools::Itertools; +use itertools::{chain, zip_eq, Itertools}; use num_traits::{One, Zero}; use crate::constraint_framework::EvalAtRow; +use crate::core::backend::simd::column::SecureColumn; +use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; +use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; +use crate::core::backend::{Col, Column}; use crate::core::circle::{CirclePoint, Coset}; use crate::core::constraints::{coset_vanishing, point_vanishing}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; use crate::core::fields::{Field, FieldExpOps}; +use crate::core::lookups::gkr_prover::GkrOps; +use crate::core::lookups::mle::Mle; use crate::core::lookups::utils::eq; -use crate::core::poly::circle::{CanonicCoset, SecureEvaluation}; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureEvaluation}; use crate::core::poly::BitReversedOrder; use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; @@ -24,24 +31,27 @@ use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; /// /// `mle_coeffs_col_eval` should be the evaluation of the column containing the coefficients of the /// MLE in the multilinear Lagrange basis. `mle_claim_shift` should equal `claim / 2^N_VARIABLES`. +#[allow(clippy::too_many_arguments)] pub fn eval_mle_eval_constraints( - mle_interaction: usize, - const_interaction: usize, + interaction: usize, eval: &mut E, mle_coeffs_col_eval: E::EF, mle_eval_point: MleEvalPoint, mle_claim_shift: SecureField, carry_quotients_col_eval: E::EF, + is_first: E::F, + is_second: E::F, ) { let eq_col_eval = eval_eq_constraints( - mle_interaction, - const_interaction, + interaction, eval, mle_eval_point, carry_quotients_col_eval, + is_first, + is_second, ); let terms_col_eval = mle_coeffs_col_eval * eq_col_eval; - eval_prefix_sum_constraints(mle_interaction, eval, terms_col_eval, mle_claim_shift) + eval_prefix_sum_constraints(interaction, eval, terms_col_eval, mle_claim_shift) } #[derive(Debug, Clone, Copy)] @@ -87,13 +97,13 @@ impl MleEvalPoint { /// See (Section 5.1). fn eval_eq_constraints( eq_interaction: usize, - const_interaction: usize, eval: &mut E, mle_eval_point: MleEvalPoint, carry_quotients_col_eval: E::EF, + is_first: E::F, + is_second: E::F, ) -> E::EF { let [curr, next_next] = eval.next_extension_interaction_mask(eq_interaction, [0, 2]); - let [is_first, is_second] = eval.next_interaction_mask(const_interaction, [0, -1]); // Check the initial value on half_coset0 and final value on half_coset1. // Combining these constraints is safe because `is_first` and `is_second` are never @@ -122,6 +132,45 @@ fn eval_prefix_sum_constraints( eval.add_constraint(curr - prev - row_diff + cumulative_sum_shift); } +/// Generates a trace. +/// +/// Trace structure: +/// +/// ```text +/// --------------------------------------------------------- +/// | EqEvals (basis) | MLE terms (prefix sum) | +/// --------------------------------------------------------- +/// | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 | +/// --------------------------------------------------------- +/// ``` +pub fn build_trace( + mle: &Mle, + eval_point: &[SecureField], + claim: SecureField, +) -> Vec> { + let eq_evals = SimdBackend::gen_eq_evals(eval_point, SecureField::one()).into_evals(); + let mle_terms = hadamard_product(mle, &eq_evals); + + let eq_evals_cols = eq_evals.into_secure_column_by_coords().columns; + let mle_terms_cols = mle_terms.into_secure_column_by_coords().columns; + + #[cfg(test)] + debug_assert_eq!(claim, mle.eval_at_point(eval_point)); + let shift = claim / BaseField::from(mle.len()); + let packed_shift_coords = PackedSecureField::broadcast(shift).into_packed_m31s(); + let mut shifted_mle_terms_cols = mle_terms_cols; + zip(&mut shifted_mle_terms_cols, packed_shift_coords) + .for_each(|(col, shift_coord)| col.data.iter_mut().for_each(|v| *v -= shift_coord)); + let shifted_prefix_sum_cols = shifted_mle_terms_cols.map(inclusive_prefix_sum); + + let log_trace_domain_size = mle.n_variables() as u32; + let trace_domain = CanonicCoset::new(log_trace_domain_size).circle_domain(); + + chain![eq_evals_cols, shifted_prefix_sum_cols] + .map(|c| CircleEvaluation::new(trace_domain, c)) + .collect() +} + /// Returns succinct Eq carry quotients column. /// /// Given column `c(P)` defined on a [`CircleDomain`] `D = +-C`, and an MLE eval point @@ -131,10 +180,11 @@ fn eval_prefix_sum_constraints( /// /// [`CircleDomain`]: crate::core::poly::circle::CircleDomain fn gen_carry_quotient_col( - eval_point: &MleEvalPoint, + eval_point: &[SecureField; N_VARIABLES], ) -> SecureEvaluation { + let mle_eval_point = MleEvalPoint::new(*eval_point); let (half_coset0_carry_quotients, half_coset1_carry_quotients) = - gen_half_coset_carry_quotients(eval_point); + gen_half_coset_carry_quotients(&mle_eval_point); let log_size = N_VARIABLES as u32; let size = 1 << log_size; @@ -253,12 +303,24 @@ fn gen_half_coset_carry_quotients( (half_coset0_carry_quotients, half_coset1_carry_quotients) } +/// Returns the element-wise product of `a` and `b`. +fn hadamard_product( + a: &Col, + b: &Col, +) -> Col { + assert_eq!(a.len(), b.len()); + SecureColumn { + data: zip_eq(&a.data, &b.data).map(|(&a, &b)| a * b).collect(), + length: a.len(), + } +} + #[cfg(test)] mod tests { use std::array; use std::iter::{repeat, zip}; - use itertools::{chain, zip_eq, Itertools}; + use itertools::{chain, Itertools}; use num_traits::One; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; @@ -269,58 +331,59 @@ mod tests { }; use crate::constraint_framework::constant_columns::{gen_is_first, gen_is_step_with_offset}; use crate::constraint_framework::{assert_constraints, EvalAtRow}; - use crate::core::backend::simd::column::SecureColumn; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; - use crate::core::backend::{Col, Column}; use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; - use crate::core::lookups::gkr_prover::GkrOps; use crate::core::lookups::mle::Mle; use crate::core::pcs::TreeVec; - use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps, SecureEvaluation}; use crate::core::poly::BitReversedOrder; use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; - use crate::examples::xor::gkr_lookups::mle_eval::eval_step_selector_with_offset; + use crate::examples::xor::gkr_lookups::mle_eval::{ + build_trace, eval_step_selector_with_offset, + }; #[test] fn test_mle_eval_constraints_with_log_size_5() { const N_VARIABLES: usize = 5; - const EVAL_TRACE: usize = 0; - const CARRY_QUOTIENTS_TRACE: usize = 1; - const CONST_TRACE: usize = 2; + const COEFFS_COL_TRACE: usize = 0; + const MLE_EVAL_TRACE: usize = 1; + const AUX_TRACE: usize = 2; let mut rng = SmallRng::seed_from_u64(0); let log_size = N_VARIABLES as u32; let size = 1 << log_size; - let mle = Mle::new((0..size).map(|_| rng.gen::()).collect()); + let mle_coeffs = (0..size).map(|_| rng.gen::()).collect(); + let mle = Mle::::new(mle_coeffs); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); - let mle_eval_point = MleEvalPoint::new(eval_point); - let base_trace = gen_base_trace(&mle, &eval_point); let claim = mle.eval_at_point(&eval_point); + let mle_eval_point = MleEvalPoint::new(eval_point); + let mle_eval_trace = build_trace(&mle, &eval_point, claim); + let mle_coeffs_col_trace = build_mle_coeffs_trace(mle); let claim_shift = claim / BaseField::from(size); - let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) - .into_coordinate_evals() - .to_vec(); - let constants_trace = gen_constants_trace::(); - let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); + let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); + let is_first_col = [gen_is_first(log_size)]; + let aux_trace = chain![carry_quotients_col, is_first_col].collect(); + let traces = TreeVec::new(vec![mle_coeffs_col_trace, mle_eval_trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(log_size); assert_constraints(&trace_polys, trace_domain, |mut eval| { - let [mle_coeff_col_eval] = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); - let [carry_quotients_col_eval] = - eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); + let [mle_coeff_col_eval] = eval.next_extension_interaction_mask(COEFFS_COL_TRACE, [0]); + let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); + let [is_first_eval, is_second_eval] = eval.next_interaction_mask(AUX_TRACE, [0, -1]); eval_mle_eval_constraints( - EVAL_TRACE, - CONST_TRACE, + MLE_EVAL_TRACE, &mut eval, mle_coeff_col_eval, mle_eval_point, claim_shift, carry_quotients_col_eval, + is_first_eval, + is_second_eval, ) }); } @@ -329,32 +392,30 @@ mod tests { #[ignore = "SimdBackend `MIN_FFT_LOG_SIZE` is 5"] fn eq_constraints_with_4_variables() { const N_VARIABLES: usize = 4; - const EVAL_TRACE: usize = 0; - const CARRY_QUOTIENTS_TRACE: usize = 1; - const CONST_TRACE: usize = 2; + const EQ_EVAL_TRACE: usize = 0; + const AUX_TRACE: usize = 1; let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); let mle_eval_point = MleEvalPoint::new(eval_point); - let base_trace = gen_base_trace(&mle, &eval_point); - let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) - .into_coordinate_evals() - .to_vec(); - let constants_trace = gen_constants_trace::(); - let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); + let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); + let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); + let is_first_col = [gen_is_first(N_VARIABLES as u32)]; + let aux_trace = chain![carry_quotients_col, is_first_col].collect(); + let traces = TreeVec::new(vec![trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(eval_point.len() as u32); assert_constraints(&trace_polys, trace_domain, |mut eval| { - let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); - let [carry_quotients_col_eval] = - eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); + let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); + let [is_first, is_second] = eval.next_interaction_mask(AUX_TRACE, [0, -1]); eval_eq_constraints( - EVAL_TRACE, - CONST_TRACE, + EQ_EVAL_TRACE, &mut eval, mle_eval_point, carry_quotients_col_eval, + is_first, + is_second, ); }); } @@ -362,32 +423,30 @@ mod tests { #[test] fn eq_constraints_with_5_variables() { const N_VARIABLES: usize = 5; - const EVAL_TRACE: usize = 0; - const CARRY_QUOTIENTS_TRACE: usize = 1; - const CONST_TRACE: usize = 2; + const EQ_EVAL_TRACE: usize = 0; + const AUX_TRACE: usize = 1; let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); let mle_eval_point = MleEvalPoint::new(eval_point); - let base_trace = gen_base_trace(&mle, &eval_point); - let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) - .into_coordinate_evals() - .to_vec(); - let constants_trace = gen_constants_trace::(); - let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); + let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); + let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); + let is_first_col = [gen_is_first(N_VARIABLES as u32)]; + let aux_trace = chain![carry_quotients_col, is_first_col].collect(); + let traces = TreeVec::new(vec![trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(eval_point.len() as u32); assert_constraints(&trace_polys, trace_domain, |mut eval| { - let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); - let [carry_quotients_col_eval] = - eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); + let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); + let [is_first, is_second] = eval.next_interaction_mask(AUX_TRACE, [0, -1]); eval_eq_constraints( - EVAL_TRACE, - CONST_TRACE, + EQ_EVAL_TRACE, &mut eval, mle_eval_point, carry_quotients_col_eval, + is_first, + is_second, ); }); } @@ -395,32 +454,30 @@ mod tests { #[test] fn eq_constraints_with_8_variables() { const N_VARIABLES: usize = 8; - const EVAL_TRACE: usize = 0; - const CARRY_QUOTIENTS_TRACE: usize = 1; - const CONST_TRACE: usize = 2; + const EQ_EVAL_TRACE: usize = 0; + const AUX_TRACE: usize = 1; let mut rng = SmallRng::seed_from_u64(0); let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); let mle_eval_point = MleEvalPoint::new(eval_point); - let base_trace = gen_base_trace(&mle, &eval_point); - let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) - .into_coordinate_evals() - .to_vec(); - let constants_trace = gen_constants_trace::(); - let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); + let trace = build_trace(&mle, &eval_point, mle.eval_at_point(&eval_point)); + let carry_quotients_col = gen_carry_quotient_col(&eval_point).into_coordinate_evals(); + let is_first_col = [gen_is_first(N_VARIABLES as u32)]; + let aux_trace = chain![carry_quotients_col, is_first_col].collect(); + let traces = TreeVec::new(vec![trace, aux_trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); let trace_domain = CanonicCoset::new(eval_point.len() as u32); assert_constraints(&trace_polys, trace_domain, |mut eval| { - let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); - let [carry_quotients_col_eval] = - eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); + let [carry_quotients_col_eval] = eval.next_extension_interaction_mask(AUX_TRACE, [0]); + let [is_first, is_second] = eval.next_interaction_mask(AUX_TRACE, [0, -1]); eval_eq_constraints( - EVAL_TRACE, - CONST_TRACE, + EQ_EVAL_TRACE, &mut eval, mle_eval_point, carry_quotients_col_eval, + is_first, + is_second, ); }); } @@ -461,8 +518,9 @@ mod tests { fn eval_carry_quotient_col_works() { const N_VARIABLES: usize = 5; let mut rng = SmallRng::seed_from_u64(0); - let mle_eval_point = MleEvalPoint::::new(array::from_fn(|_| rng.gen())); - let col_eval = gen_carry_quotient_col(&mle_eval_point); + let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let mle_eval_point = MleEvalPoint::new(eval_point); + let col_eval = gen_carry_quotient_col(&eval_point); let twiddles = SimdBackend::precompute_twiddles(col_eval.domain.half_coset); let col_poly = col_eval.interpolate_with_twiddles(&twiddles); let p = SECURE_FIELD_CIRCLE_GEN; @@ -472,45 +530,6 @@ mod tests { assert_eq!(eval, col_poly.eval_at_point(p)); } - /// Generates a trace. - /// - /// Trace structure: - /// - /// ```text - /// ------------------------------------------------------------------------------------- - /// | MLE coeffs | EqEvals (basis) | MLE terms (prefix sum) | - /// ------------------------------------------------------------------------------------- - /// | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c9 | c9 | c10 | c11 | - /// ------------------------------------------------------------------------------------- - /// ``` - fn gen_base_trace( - mle: &Mle, - eval_point: &[SecureField], - ) -> Vec> { - let mle_coeffs = mle.clone().into_evals(); - let eq_evals = SimdBackend::gen_eq_evals(eval_point, SecureField::one()).into_evals(); - let mle_terms = hadamard_product(&mle_coeffs, &eq_evals); - - let mle_coeff_cols = mle_coeffs.into_secure_column_by_coords().columns; - let eq_evals_cols = eq_evals.into_secure_column_by_coords().columns; - let mle_terms_cols = mle_terms.into_secure_column_by_coords().columns; - - let claim = mle.eval_at_point(eval_point); - let shift = claim / BaseField::from(mle.len()); - let packed_shifts = PackedSecureField::broadcast(shift).into_packed_m31s(); - let mut shifted_mle_terms_cols = mle_terms_cols.clone(); - zip(&mut shifted_mle_terms_cols, packed_shifts) - .for_each(|(col, shift)| col.data.iter_mut().for_each(|v| *v -= shift)); - let shifted_prefix_sum_cols = shifted_mle_terms_cols.map(inclusive_prefix_sum); - - let log_trace_domain_size = mle.n_variables() as u32; - let trace_domain = CanonicCoset::new(log_trace_domain_size).circle_domain(); - - chain![mle_coeff_cols, eq_evals_cols, shifted_prefix_sum_cols] - .map(|c| CircleEvaluation::new(trace_domain, c)) - .collect() - } - /// Generates a trace. /// /// Trace structure: @@ -551,21 +570,26 @@ mod tests { .collect() } - /// Returns the element-wise product of `a` and `b`. - fn hadamard_product( - a: &Col, - b: &Col, - ) -> Col { - assert_eq!(a.len(), b.len()); - SecureColumn { - data: zip_eq(&a.data, &b.data).map(|(&a, &b)| a * b).collect(), - length: a.len(), - } - } - - fn gen_constants_trace( + /// Generates a trace. + /// + /// Trace structure: + /// + /// ```text + /// ----------------------------- + /// | MLE coeffs col | + /// ----------------------------- + /// | c0 | c1 | c2 | c3 | + /// ----------------------------- + /// ``` + fn build_mle_coeffs_trace( + mle: Mle, ) -> Vec> { - let log_size = N_VARIABLES as u32; - vec![gen_is_first(log_size)] + let log_size = mle.n_variables() as u32; + let trace_domain = CanonicCoset::new(log_size).circle_domain(); + let mle_coeffs_col_by_coords = mle.into_evals().into_secure_column_by_coords(); + SecureEvaluation::new(trace_domain, mle_coeffs_col_by_coords) + .into_coordinate_evals() + .into_iter() + .collect() } }