diff --git a/crates/prover/src/builder/logup.rs b/crates/prover/src/builder/logup.rs index f86c5fda0..17f09b733 100644 --- a/crates/prover/src/builder/logup.rs +++ b/crates/prover/src/builder/logup.rs @@ -1,7 +1,8 @@ use itertools::Itertools; -use num_traits::Zero; +use num_traits::{One, Zero}; use tracing::{span, Level}; +use super::EvalAtRow; use crate::core::backend::simd::column::SecureFieldVec; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::qm31::PackedSecureField; @@ -14,10 +15,91 @@ use crate::core::fields::secure_column::SecureColumn; use crate::core::fields::FieldExpOps; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; use crate::core::poly::BitReversedOrder; -use crate::core::utils::bit_reverse_index; +use crate::core::utils::{bit_reverse_index, shifted_secure_combination}; use crate::core::ColumnVec; -#[derive(Copy, Clone, Debug)] +pub struct LogupAtRow { + pub interaction: usize, + pub queue: [(E::EF, E::EF); BATCH_SIZE], + pub queue_size: usize, + pub claimed_sum: SecureField, + pub prev_mask: E::EF, + pub is_first: E::F, +} +impl LogupAtRow { + pub fn new(interaction: usize, claimed_sum: SecureField, is_first: E::F) -> Self { + Self { + interaction, + queue: [(E::EF::zero(), E::EF::zero()); BATCH_SIZE], + queue_size: 0, + claimed_sum, + prev_mask: E::EF::zero(), + is_first, + } + } + pub fn push_lookup( + &mut self, + eval: &mut E, + numerator: E::EF, + values: &[E::F], + lookup_elements: LookupElements, + ) { + let shifted_value = shifted_secure_combination( + values, + E::EF::zero() + lookup_elements.alpha, + E::EF::zero() + lookup_elements.z, + ); + self.push_frac(eval, numerator, shifted_value); + } + + pub fn push_frac(&mut self, eval: &mut E, p: E::EF, q: E::EF) { + if self.queue_size < BATCH_SIZE { + self.queue[self.queue_size] = (p, q); + self.queue_size += 1; + return; + } + + // Compute sum_i pi/qi over batch, as a fraction, p/q. + let (num, denom) = self + .queue + .iter() + .copied() + .fold((E::EF::zero(), E::EF::one()), |(p0, q0), (pi, qi)| { + (p0 * qi + pi * q0, qi * q0) + }); + + self.queue[0] = (p, q); + self.queue_size = 1; + + // Add a constraint that p / q = diff. + let cur = E::combine_ef(std::array::from_fn(|_| { + eval.next_interaction_mask(1, [0])[0] + })); + let diff = cur - self.prev_mask; + self.prev_mask = cur; + eval.add_constraint(diff * denom - num); + } + + pub fn finalize(self, eval: &mut E) { + let (p, q) = self.queue[0..self.queue_size] + .iter() + .copied() + .fold((E::EF::zero(), E::EF::one()), |(p0, q0), (pi, qi)| { + (p0 * qi + pi * q0, qi * q0) + }); + + let cumulative_mask_values = + std::array::from_fn(|_| eval.next_interaction_mask(self.interaction, [0, -1])); + let cur = E::combine_ef(cumulative_mask_values.map(|[cur, _prev]| cur)); + let up = E::combine_ef(cumulative_mask_values.map(|[_cur, prev]| prev)); + let up = up - self.is_first * self.claimed_sum; + let diff = cur - up - self.prev_mask; + + eval.add_constraint(diff * q - p); + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct LookupElements { pub z: SecureField, pub alpha: SecureField, diff --git a/crates/prover/src/builder/mod.rs b/crates/prover/src/builder/mod.rs index edd26096d..29ba4c46b 100644 --- a/crates/prover/src/builder/mod.rs +++ b/crates/prover/src/builder/mod.rs @@ -6,7 +6,7 @@ pub mod logup; mod point; use std::fmt::Debug; -use std::ops::{Add, AddAssign, Mul, Sub}; +use std::ops::{Add, AddAssign, Mul, Neg, Sub}; pub use assert::AssertEvaluator; pub use domain::DomainEvaluator; @@ -34,6 +34,7 @@ pub trait EvalAtRow { + Copy + Debug + Zero + + Neg + Add + Sub + Mul diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 9ff4b162b..040a17748 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -3,13 +3,13 @@ use std::ops::{Add, AddAssign, Mul, Sub}; use itertools::Itertools; +use num_traits::{One, Zero}; use tracing::{span, Level}; use crate::builder::constant_cols::gen_is_first; -use crate::builder::logup::{LogupTraceGenerator, LookupElements}; -use crate::builder::{DomainEvaluator, EvalAtRow, PointEvaluator}; +use crate::builder::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; +use crate::builder::{DomainEvaluator, EvalAtRow, InfoEvaluator, PointEvaluator}; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use crate::core::air::mask::fixed_mask_points; use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace}; use crate::core::backend::simd::column::BaseFieldVec; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; @@ -21,7 +21,6 @@ use crate::core::circle::CirclePoint; use crate::core::constraints::coset_vanishing; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; use crate::core::fields::{FieldExpOps, FieldOps, IntoSlice}; use crate::core::pcs::{CommitmentSchemeProver, TreeVec}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; @@ -51,6 +50,8 @@ const INTERNAL_ROUND_CONSTS: [BaseField; N_PARTIAL_ROUNDS] = #[derive(Clone)] pub struct PoseidonComponent { pub log_n_rows: u32, + pub lookup_elements: LookupElements, + pub claimed_sum: SecureField, } impl PoseidonComponent { @@ -98,9 +99,22 @@ impl AirTraceGenerator for PoseidonAir { } } +pub fn poseidon_info() -> InfoEvaluator { + let mut counter = PoseidonEval { + eval: InfoEvaluator::default(), + lookup_elements: LookupElements { + z: SecureField::one(), + alpha: SecureField::one(), + }, + logup: LogupAtRow::new(1, SecureField::zero(), BaseField::zero()), + }; + counter.eval.next_interaction_mask(2, [0]); + counter.eval() +} + impl Component for PoseidonComponent { fn n_constraints(&self) -> usize { - (N_COLUMNS_PER_REP - N_STATE) * N_INSTANCES_PER_ROW + poseidon_info().n_constraints } fn max_constraint_log_degree_bound(&self) -> u32 { @@ -112,21 +126,32 @@ impl Component for PoseidonComponent { } fn trace_log_degree_bounds(&self) -> TreeVec> { - TreeVec::new(vec![ - vec![self.log_column_size(); N_COLUMNS], - vec![self.log_column_size(); N_INSTANCES_PER_ROW * SECURE_EXTENSION_DEGREE], - ]) + TreeVec::new( + poseidon_info() + .mask_offsets + .iter() + .map(|tree_masks| vec![self.log_n_rows; tree_masks.len()]) + .collect(), + ) } fn mask_points( &self, point: CirclePoint, ) -> TreeVec>>> { - TreeVec::new(vec![ - fixed_mask_points(&vec![vec![0_usize]; N_COLUMNS], point), - vec![vec![]; N_INSTANCES_PER_ROW * SECURE_EXTENSION_DEGREE], - vec![vec![point]], - ]) + let trace_step = CanonicCoset::new(self.log_n_rows).step(); + let counter = poseidon_info(); + counter.mask_offsets.map(|tree_mask| { + tree_mask + .iter() + .map(|col_mask| { + col_mask + .iter() + .map(|off| point + trace_step.mul_signed(*off).into_ef()) + .collect() + }) + .collect() + }) } fn interaction_element_ids(&self) -> Vec { @@ -144,13 +169,16 @@ impl Component for PoseidonComponent { let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset; let denom = coset_vanishing(constraint_zero_domain, point); let denom_inverse = denom.inverse(); - let mut eval = PoseidonEval { - eval: PointEvaluator::new(mask.as_ref(), evaluation_accumulator, denom_inverse), + + let mut eval = PointEvaluator::new(mask.as_ref(), evaluation_accumulator, denom_inverse); + let [is_first] = eval.next_interaction_mask(2, [0]); + let poseidon_eval = PoseidonEval { + eval, + logup: LogupAtRow::new(1, self.claimed_sum, is_first), + lookup_elements: self.lookup_elements, }; - for _ in 0..N_INSTANCES_PER_ROW { - eval.eval(); - } - assert_eq!(eval.eval.col_index[0], N_COLUMNS); + let eval = poseidon_eval.eval(); + assert_eq!(eval.col_index[0], N_COLUMNS); } } @@ -225,49 +253,64 @@ fn pow5(x: F) -> F { struct PoseidonEval { eval: E, + logup: LogupAtRow<2, E>, + lookup_elements: LookupElements, } impl PoseidonEval { - fn eval(&mut self) { - let mut state: [_; N_STATE] = std::array::from_fn(|_| self.eval.next_mask()); + fn eval(mut self) -> E { + for _ in 0..N_INSTANCES_PER_ROW { + let mut state: [_; N_STATE] = std::array::from_fn(|_| self.eval.next_mask()); + + // Require state lookup. + self.logup + .push_lookup(&mut self.eval, E::EF::one(), &state, self.lookup_elements); - // 4 full rounds. - (0..N_HALF_FULL_ROUNDS).for_each(|round| { - (0..N_STATE).for_each(|i| { - state[i] += EXTERNAL_ROUND_CONSTS[round][i]; + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += EXTERNAL_ROUND_CONSTS[round][i]; + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter_mut().for_each(|s| { + let m = self.eval.next_mask(); + self.eval.add_constraint(*s - m); + *s = m; + }); }); - apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); - state.iter_mut().for_each(|s| { + + // Partial rounds. + (0..N_PARTIAL_ROUNDS).for_each(|round| { + state[0] += INTERNAL_ROUND_CONSTS[round]; + apply_internal_round_matrix(&mut state); + state[0] = pow5(state[0]); let m = self.eval.next_mask(); - self.eval.add_constraint(*s - m); - *s = m; - }); - }); - - // Partial rounds. - (0..N_PARTIAL_ROUNDS).for_each(|round| { - state[0] += INTERNAL_ROUND_CONSTS[round]; - apply_internal_round_matrix(&mut state); - state[0] = pow5(state[0]); - let m = self.eval.next_mask(); - self.eval.add_constraint(state[0] - m); - state[0] = m; - }); - - // 4 full rounds. - (0..N_HALF_FULL_ROUNDS).for_each(|round| { - (0..N_STATE).for_each(|i| { - state[i] += EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i]; + self.eval.add_constraint(state[0] - m); + state[0] = m; }); - apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); - state.iter_mut().for_each(|s| { - let m = self.eval.next_mask(); - self.eval.add_constraint(*s - m); - *s = m; + + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i]; + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter_mut().for_each(|s| { + let m = self.eval.next_mask(); + self.eval.add_constraint(*s - m); + *s = m; + }); }); - }); + + // Provide state lookup. + self.logup + .push_lookup(&mut self.eval, -E::EF::one(), &state, self.lookup_elements); + } + + self.logup.finalize(&mut self.eval); + self.eval } } @@ -316,7 +359,7 @@ pub fn gen_trace( lookup_data.initial_state[rep_i] .iter_mut() .zip(state) - .for_each(|(res, state)| res.data[vec_index] = state); + .for_each(|(res, si)| res.data[vec_index] = si); // 4 full rounds. (0..N_HALF_FULL_ROUNDS).for_each(|round| { @@ -358,7 +401,7 @@ pub fn gen_trace( lookup_data.final_state[rep_i] .iter_mut() .zip(state) - .for_each(|(res, state)| res.data[vec_index] = state); + .for_each(|(res, si)| res.data[vec_index] = si); } } let domain = CanonicCoset::new(log_size).circle_domain(); @@ -484,19 +527,21 @@ impl ComponentProver for PoseidonComponent { pows.reverse(); for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES)) { - let mut evaluator = PoseidonEval { - eval: DomainEvaluator::new( - &trace_eval_ref, - vec_row, - &pows, - self.log_n_rows, - self.log_n_rows + LOG_EXPAND, - ), + let mut eval = DomainEvaluator::new( + &trace_eval_ref, + vec_row, + &pows, + self.log_n_rows, + self.log_n_rows + LOG_EXPAND, + ); + let [is_first] = eval.next_interaction_mask(2, [0]); + let poseidon_eval = PoseidonEval { + eval, + logup: LogupAtRow::new(1, self.claimed_sum, is_first), + lookup_elements: self.lookup_elements, }; - for _ in 0..N_INSTANCES_PER_ROW { - evaluator.eval(); - } - let row_res = evaluator.eval.row_res; + let eval = poseidon_eval.eval(); + let row_res = eval.row_res; unsafe { accum.col.set_packed( @@ -504,7 +549,7 @@ impl ComponentProver for PoseidonComponent { accum.col.packed_at(vec_row) + row_res * denom_inverses.data[vec_row], ) } - assert_eq!(evaluator.eval.constraint_index, n_constraints); + assert_eq!(eval.constraint_index, n_constraints); } } @@ -544,8 +589,7 @@ pub fn prove_poseidon(log_n_instances: u32) -> (PoseidonAir, StarkProof) { // Interaction trace. let span = span!(Level::INFO, "Interaction").entered(); let span1 = span!(Level::INFO, "Generation").entered(); - let (trace, _claimed_logup_sum) = - gen_interaction_trace(log_n_rows, interaction_data, lookup_elements); + let (trace, claimed_sum) = gen_interaction_trace(log_n_rows, interaction_data, lookup_elements); span1.exit(); commitment_scheme.commit_on_evals(trace, channel, &twiddles); span.exit(); @@ -556,7 +600,11 @@ pub fn prove_poseidon(log_n_instances: u32) -> (PoseidonAir, StarkProof) { span.exit(); // Prove constraints. - let component = PoseidonComponent { log_n_rows }; + let component = PoseidonComponent { + log_n_rows, + lookup_elements, + claimed_sum, + }; let air = PoseidonAir { component }; let proof = prove_without_commit::( &air, @@ -577,6 +625,7 @@ mod tests { use itertools::Itertools; use num_traits::One; + use crate::builder::logup::LookupElements; use crate::core::air::AirExt; use crate::core::channel::{Blake2sChannel, Channel}; use crate::core::fields::m31::BaseField; @@ -643,6 +692,7 @@ mod tests { let (air, proof) = prove_poseidon(log_n_instances); // Verify. + // TODO: Create Air instance independently. let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); let commitment_scheme = &mut CommitmentSchemeVerifier::new(); @@ -650,6 +700,10 @@ mod tests { let sizes = air.column_log_sizes(); // Trace columns. commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); + // Draw lookup element. + let lookup_elements = LookupElements::draw(channel); + assert_eq!(lookup_elements, air.component.lookup_elements); + // TODO(spapini): Check claimed sum against first and last instances. // Interaction columns. commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); // Constant columns.