diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index 5565186f6..98ee61355 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -1,8 +1,9 @@ use std::ops::{Add, Mul, Sub}; use itertools::Itertools; -use num_traits::Zero; +use num_traits::{One, Zero}; +use super::EvalAtRow; use crate::core::backend::simd::column::SecureColumn; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; @@ -19,8 +20,102 @@ use crate::core::poly::BitReversedOrder; use crate::core::utils::shifted_secure_combination; use crate::core::ColumnVec; +/// Evaluates constraints for batched logups. +/// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum. +/// BATCH_SIZE is the number of fractions to batch together. The degree of the resulting constraints +/// will be BATCH_SIZE + 1. +pub struct LogupAtRow { + /// The index of the interaction used for the cumulative sum columns. + pub interaction: usize, + /// Queue of fractions waiting to be batched together. + pub queue: [(E::EF, E::EF); BATCH_SIZE], + /// Number of fractions in the queue. + pub queue_size: usize, + /// The claimed sum of all the fractions. + pub claimed_sum: SecureField, + /// The evaluation of the last cumulative sum column. + pub prev_col_cumsum: E::EF, + /// The value of the `is_first` constant column at current row. + /// See [`super::constant_columns::gen_is_first()`]. + pub is_first: E::F, +} +impl LogupAtRow { + pub fn new(interaction: usize, claimed_sum: SecureField, is_first: E::F) -> Self { + Self { + interaction, + queue: [(E::EF::zero(), E::EF::zero()); BATCH_SIZE], + queue_size: 0, + claimed_sum, + prev_col_cumsum: E::EF::zero(), + is_first, + } + } + pub fn push_lookup( + &mut self, + eval: &mut E, + numerator: E::EF, + values: &[E::F], + lookup_elements: LookupElements, + ) { + let shifted_value = shifted_secure_combination( + values, + E::EF::from(lookup_elements.alpha), + E::EF::from(lookup_elements.z), + ); + self.push_frac(eval, numerator, shifted_value); + } + + pub fn push_frac(&mut self, eval: &mut E, numerator: E::EF, denominator: E::EF) { + if self.queue_size < BATCH_SIZE { + self.queue[self.queue_size] = (numerator, denominator); + self.queue_size += 1; + return; + } + + // Compute sum_i pi/qi over batch, as a fraction, p/q. + let (num, denom) = self + .queue + .iter() + .copied() + .fold((E::EF::zero(), E::EF::one()), |(p0, q0), (pi, qi)| { + (p0 * qi + pi * q0, qi * q0) + }); + + self.queue[0] = (numerator, denominator); + self.queue_size = 1; + + // Add a constraint that p / q = diff. + let cur_cumsum = E::combine_ef(std::array::from_fn(|_| { + eval.next_interaction_mask(1, [0])[0] + })); + let diff = cur_cumsum - self.prev_col_cumsum; + self.prev_col_cumsum = cur_cumsum; + eval.add_constraint(diff * denom - num); + } + + pub fn finalize(self, eval: &mut E) { + let (num, denom) = self.queue[0..self.queue_size] + .iter() + .copied() + .fold((E::EF::zero(), E::EF::one()), |(p0, q0), (pi, qi)| { + (p0 * qi + pi * q0, qi * q0) + }); + + let cumsum_mask = + std::array::from_fn(|_| eval.next_interaction_mask(self.interaction, [0, -1])); + let cur_cumsum = E::combine_ef(cumsum_mask.map(|[cur_row, _prev_row]| cur_row)); + let prev_row_cumsum = E::combine_ef(cumsum_mask.map(|[_cur_row, prev_row]| prev_row)); + + // Fix `prev_row_cumsum` by subtracting `claimed_sum` if this is the first row. + let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first * self.claimed_sum; + let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum; + + eval.add_constraint(diff * denom - num); + } +} + /// Interaction elements for the logup protocol. -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct LookupElements { pub z: SecureField, pub alpha: SecureField, diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 12b2f290e..ef47c8e71 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -8,7 +8,7 @@ mod simd_domain; use std::array; use std::fmt::Debug; -use std::ops::{Add, AddAssign, Mul, Sub}; +use std::ops::{Add, AddAssign, Mul, Neg, Sub}; pub use assert::{assert_constraints, AssertEvaluator}; pub use info::InfoEvaluator; @@ -45,13 +45,15 @@ pub trait EvalAtRow { + Copy + Debug + Zero + + Neg + Add + Sub + Mul + Add + Mul + Sub - + Mul; + + Mul + + From; /// Returns the next mask value for the first interaction at offset 0. fn next_trace_mask(&mut self) -> Self::F { diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 963887768..350d45cf6 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -4,15 +4,15 @@ use std::array; use std::ops::{Add, AddAssign, Mul, Sub}; use itertools::Itertools; +use num_traits::{One, Zero}; #[cfg(feature = "parallel")] use rayon::prelude::*; use tracing::{span, Level}; use crate::constraint_framework::constant_columns::gen_is_first; -use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements}; -use crate::constraint_framework::{EvalAtRow, PointEvaluator, SimdDomainEvaluator}; +use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; +use crate::constraint_framework::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator}; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use crate::core::air::mask::fixed_mask_points; use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace}; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::simd::m31::{PackedBaseField, PackedM31, LOG_N_LANES}; @@ -24,7 +24,6 @@ use crate::core::circle::CirclePoint; use crate::core::constraints::coset_vanishing; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; use crate::core::fields::{FieldExpOps, IntoSlice}; use crate::core::pcs::{CommitmentSchemeProver, TreeVec}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; @@ -54,6 +53,8 @@ const INTERNAL_ROUND_CONSTS: [BaseField; N_PARTIAL_ROUNDS] = #[derive(Clone)] pub struct PoseidonComponent { pub log_n_rows: u32, + pub lookup_elements: LookupElements, + pub claimed_sum: SecureField, } impl PoseidonComponent { @@ -105,9 +106,23 @@ impl AirTraceGenerator for PoseidonAir { } } +pub fn poseidon_info() -> InfoEvaluator { + let mut eval = InfoEvaluator::default(); + let [is_first] = eval.next_interaction_mask(2, [0]); + let counter = PoseidonEval { + eval, + lookup_elements: LookupElements { + z: SecureField::one(), + alpha: SecureField::one(), + }, + logup: LogupAtRow::new(1, SecureField::zero(), is_first), + }; + counter.eval() +} + impl Component for PoseidonComponent { fn n_constraints(&self) -> usize { - (N_COLUMNS_PER_REP - N_STATE) * N_INSTANCES_PER_ROW + poseidon_info().n_constraints } fn max_constraint_log_degree_bound(&self) -> u32 { @@ -115,22 +130,32 @@ impl Component for PoseidonComponent { } fn trace_log_degree_bounds(&self) -> TreeVec> { - TreeVec::new(vec![ - vec![self.log_column_size(); N_COLUMNS], - vec![self.log_column_size(); N_INSTANCES_PER_ROW * SECURE_EXTENSION_DEGREE], - vec![self.log_column_size()], - ]) + TreeVec::new( + poseidon_info() + .mask_offsets + .iter() + .map(|tree_masks| vec![self.log_n_rows; tree_masks.len()]) + .collect(), + ) } fn mask_points( &self, point: CirclePoint, ) -> TreeVec>>> { - TreeVec::new(vec![ - fixed_mask_points(&vec![vec![0_usize]; N_COLUMNS], point), - vec![vec![]; N_INSTANCES_PER_ROW * SECURE_EXTENSION_DEGREE], - vec![vec![point]], - ]) + let trace_step = CanonicCoset::new(self.log_n_rows).step(); + let counter = poseidon_info(); + counter.mask_offsets.map(|tree_mask| { + tree_mask + .iter() + .map(|col_mask| { + col_mask + .iter() + .map(|off| point + trace_step.mul_signed(*off).into_ef()) + .collect() + }) + .collect() + }) } fn evaluate_constraint_quotients_at_point( @@ -144,13 +169,16 @@ impl Component for PoseidonComponent { let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset; let denom = coset_vanishing(constraint_zero_domain, point); let denom_inverse = denom.inverse(); - let mut poseidon_eval = PoseidonEval { - eval: PointEvaluator::new(mask.as_ref(), evaluation_accumulator, denom_inverse), + + let mut eval = PointEvaluator::new(mask.as_ref(), evaluation_accumulator, denom_inverse); + let [is_first] = eval.next_interaction_mask(2, [0]); + let poseidon_eval = PoseidonEval { + eval, + logup: LogupAtRow::new(1, self.claimed_sum, is_first), + lookup_elements: self.lookup_elements, }; - for _ in 0..N_INSTANCES_PER_ROW { - poseidon_eval.eval(); - } - assert_eq!(poseidon_eval.eval.col_index[0], N_COLUMNS); + let eval = poseidon_eval.eval(); + assert_eq!(eval.col_index[0], N_COLUMNS); } } @@ -225,49 +253,64 @@ fn pow5(x: F) -> F { struct PoseidonEval { eval: E, + logup: LogupAtRow<2, E>, + lookup_elements: LookupElements, } impl PoseidonEval { - fn eval(&mut self) { - let mut state: [_; N_STATE] = std::array::from_fn(|_| self.eval.next_trace_mask()); + fn eval(mut self) -> E { + for _ in 0..N_INSTANCES_PER_ROW { + let mut state: [_; N_STATE] = std::array::from_fn(|_| self.eval.next_trace_mask()); + + // Require state lookup. + self.logup + .push_lookup(&mut self.eval, E::EF::one(), &state, self.lookup_elements); - // 4 full rounds. - (0..N_HALF_FULL_ROUNDS).for_each(|round| { - (0..N_STATE).for_each(|i| { - state[i] += EXTERNAL_ROUND_CONSTS[round][i]; + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += EXTERNAL_ROUND_CONSTS[round][i]; + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter_mut().for_each(|s| { + let m = self.eval.next_trace_mask(); + self.eval.add_constraint(*s - m); + *s = m; + }); }); - apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); - state.iter_mut().for_each(|s| { + + // Partial rounds. + (0..N_PARTIAL_ROUNDS).for_each(|round| { + state[0] += INTERNAL_ROUND_CONSTS[round]; + apply_internal_round_matrix(&mut state); + state[0] = pow5(state[0]); let m = self.eval.next_trace_mask(); - self.eval.add_constraint(*s - m); - *s = m; + self.eval.add_constraint(state[0] - m); + state[0] = m; }); - }); - - // Partial rounds. - (0..N_PARTIAL_ROUNDS).for_each(|round| { - state[0] += INTERNAL_ROUND_CONSTS[round]; - apply_internal_round_matrix(&mut state); - state[0] = pow5(state[0]); - let m = self.eval.next_trace_mask(); - self.eval.add_constraint(state[0] - m); - state[0] = m; - }); - // 4 full rounds. - (0..N_HALF_FULL_ROUNDS).for_each(|round| { - (0..N_STATE).for_each(|i| { - state[i] += EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i]; - }); - apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); - state.iter_mut().for_each(|s| { - let m = self.eval.next_trace_mask(); - self.eval.add_constraint(*s - m); - *s = m; + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i]; + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter_mut().for_each(|s| { + let m = self.eval.next_trace_mask(); + self.eval.add_constraint(*s - m); + *s = m; + }); }); - }); + + // Provide state lookup. + self.logup + .push_lookup(&mut self.eval, -E::EF::one(), &state, self.lookup_elements); + } + + self.logup.finalize(&mut self.eval); + self.eval } } @@ -317,7 +360,7 @@ pub fn gen_trace( lookup_data.initial_state[rep_i] .iter_mut() .zip(state) - .for_each(|(res, state)| res.data[vec_index] = state); + .for_each(|(res, si)| res.data[vec_index] = si); // 4 full rounds. (0..N_HALF_FULL_ROUNDS).for_each(|round| { @@ -359,7 +402,7 @@ pub fn gen_trace( lookup_data.final_state[rep_i] .iter_mut() .zip(state) - .for_each(|(res, state)| res.data[vec_index] = state); + .for_each(|(res, si)| res.data[vec_index] = si); } } let domain = CanonicCoset::new(log_size).circle_domain(); @@ -493,24 +536,26 @@ impl ComponentProver for PoseidonComponent { iter.for_each(|(chunk_offset, mut col_chunk)| { for offset in 0..CHUNK_SIZE { let vec_row = chunk_offset + offset; - let mut evaluator = PoseidonEval { - eval: SimdDomainEvaluator::new( - &trace_eval_ref, - vec_row, - &pows, - self.log_n_rows, - self.log_n_rows + LOG_EXPAND, - ), + let mut eval = SimdDomainEvaluator::new( + &trace_eval_ref, + vec_row, + &pows, + self.log_n_rows, + self.log_n_rows + LOG_EXPAND, + ); + let [is_first] = eval.next_interaction_mask(2, [0]); + let poseidon_eval = PoseidonEval { + eval, + logup: LogupAtRow::new(1, self.claimed_sum, is_first), + lookup_elements: self.lookup_elements, }; - for _ in 0..N_INSTANCES_PER_ROW { - evaluator.eval(); - } + let eval = poseidon_eval.eval(); + let row_res = eval.row_res; let packed_denom_inv = packed_denoms_inv[vec_row >> (zero_domain.log_size() - LOG_N_LANES)]; - let quotient = evaluator.eval.row_res * packed_denom_inv; + let quotient = row_res * packed_denom_inv; unsafe { col_chunk.set_packed(offset, col_chunk.packed_at(offset) + quotient) }; - assert_eq!(evaluator.eval.constraint_index, n_constraints); } }); } @@ -550,8 +595,7 @@ pub fn prove_poseidon(log_n_instances: u32) -> (PoseidonAir, StarkProof) { // Interaction trace. let span = span!(Level::INFO, "Interaction").entered(); - let (trace, _claimed_logup_sum) = - gen_interaction_trace(log_n_rows, lookup_data, lookup_elements); + let (trace, claimed_sum) = gen_interaction_trace(log_n_rows, lookup_data, lookup_elements); let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals(trace); tree_builder.commit(channel); @@ -565,7 +609,11 @@ pub fn prove_poseidon(log_n_instances: u32) -> (PoseidonAir, StarkProof) { span.exit(); // Prove constraints. - let component = PoseidonComponent { log_n_rows }; + let component = PoseidonComponent { + log_n_rows, + lookup_elements, + claimed_sum, + }; let air = PoseidonAir { component }; let proof = prove::( &air, @@ -585,9 +633,9 @@ mod tests { use itertools::Itertools; use num_traits::One; - use crate::constraint_framework::assert_constraints; use crate::constraint_framework::constant_columns::gen_is_first; - use crate::constraint_framework::logup::LookupElements; + use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; + use crate::constraint_framework::{assert_constraints, EvalAtRow}; use crate::core::air::AirExt; use crate::core::channel::{Blake2sChannel, Channel}; use crate::core::fields::m31::BaseField; @@ -652,15 +700,21 @@ mod tests { z: qm31!(1, 2, 3, 4), alpha: qm31!(5, 6, 7, 8), }; - let (trace1, _claimed_logup_sum) = + let (trace1, claimed_sum) = gen_interaction_trace(LOG_N_ROWS, interaction_data, lookup_elements); let trace2 = vec![gen_is_first(LOG_N_ROWS)]; let traces = TreeVec::new(vec![trace0, trace1, trace2]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); - assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |eval| { - PoseidonEval { eval }.eval(); + assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |mut eval| { + let [is_first] = eval.next_interaction_mask(2, [0]); + PoseidonEval { + eval, + logup: LogupAtRow::new(1, claimed_sum, is_first), + lookup_elements, + } + .eval(); }); } @@ -681,6 +735,7 @@ mod tests { let (air, proof) = prove_poseidon(log_n_instances); // Verify. + // TODO: Create Air instance independently. let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); let commitment_scheme = &mut CommitmentSchemeVerifier::new(); @@ -689,6 +744,10 @@ mod tests { let sizes = air.column_log_sizes(); // Trace columns. commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); + // Draw lookup element. + let lookup_elements = LookupElements::draw(channel); + assert_eq!(lookup_elements, air.component.lookup_elements); + // TODO(spapini): Check claimed sum against first and last instances. // Interaction columns. commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); // Constant columns.