diff --git a/crates/prover/benches/poseidon.rs b/crates/prover/benches/poseidon.rs index 594acb62d..3abe134ab 100644 --- a/crates/prover/benches/poseidon.rs +++ b/crates/prover/benches/poseidon.rs @@ -1,27 +1,12 @@ use criterion::{criterion_group, criterion_main, Criterion, Throughput}; -use stwo_prover::core::backend::simd::SimdBackend; -use stwo_prover::core::channel::{Blake2sChannel, Channel}; -use stwo_prover::core::fields::m31::BaseField; -use stwo_prover::core::fields::IntoSlice; -use stwo_prover::core::prover::prove; -use stwo_prover::core::vcs::blake2_hash::Blake2sHasher; -use stwo_prover::core::vcs::hasher::Hasher; -use stwo_prover::examples::poseidon::{gen_trace, PoseidonAir, PoseidonComponent}; +use stwo_prover::examples::poseidon::prove_poseidon; pub fn simd_poseidon(c: &mut Criterion) { - const LOG_N_ROWS: u32 = 15; + const LOG_N_INSTANCES: u32 = 18; let mut group = c.benchmark_group("poseidon2"); - group.throughput(Throughput::Elements(1u64 << (LOG_N_ROWS + 3))); - group.bench_function(format!("poseidon2 2^{} instances", LOG_N_ROWS + 3), |b| { - b.iter(|| { - let component = PoseidonComponent { - log_n_rows: LOG_N_ROWS, - }; - let trace = gen_trace(component.log_column_size()); - let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); - let air = PoseidonAir { component }; - prove::(&air, channel, trace).unwrap() - }); + group.throughput(Throughput::Elements(1u64 << LOG_N_INSTANCES)); + group.bench_function(format!("poseidon2 2^{} instances", LOG_N_INSTANCES), |b| { + b.iter(|| prove_poseidon(LOG_N_INSTANCES)); }); } diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs new file mode 100644 index 000000000..f86c5fda0 --- /dev/null +++ b/crates/prover/src/constraint_framework/logup.rs @@ -0,0 +1,134 @@ +use itertools::Itertools; +use num_traits::Zero; +use tracing::{span, Level}; + +use crate::core::backend::simd::column::SecureFieldVec; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::{Backend, Column}; +use crate::core::channel::{Blake2sChannel, Channel}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SecureColumn; +use crate::core::fields::FieldExpOps; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::bit_reverse_index; +use crate::core::ColumnVec; + +#[derive(Copy, Clone, Debug)] +pub struct LookupElements { + pub z: SecureField, + pub alpha: SecureField, +} +impl LookupElements { + pub fn draw(channel: &mut Blake2sChannel) -> Self { + let [z, alpha] = channel.draw_felts(2).try_into().unwrap(); + Self { z, alpha } + } +} + +// SIMD backend generator. +pub struct LogupTraceGenerator { + log_size: u32, + trace: Vec>, + denom: SecureFieldVec, + denom_inv: SecureFieldVec, +} +impl LogupTraceGenerator { + pub fn new(log_size: u32) -> Self { + let trace = vec![]; + let denom = SecureFieldVec::zeros(1 << log_size); + let denom_inv = SecureFieldVec::zeros(1 << log_size); + Self { + log_size, + trace, + denom, + denom_inv, + } + } + + pub fn new_col(&mut self) -> LogupColGenerator<'_> { + let log_size = self.log_size; + LogupColGenerator { + gen: self, + numerator: SecureColumn::::zeros(1 << log_size), + } + } + + pub fn finalize( + mut self, + ) -> ( + ColumnVec>, + SecureField, + ) { + let claimed_xor_sum = eval_order_prefix_sum(self.trace.last_mut().unwrap(), self.log_size); + + let trace = self + .trace + .into_iter() + .flat_map(|eval| { + eval.columns.map(|c| { + CircleEvaluation::::new( + CanonicCoset::new(self.log_size).circle_domain(), + c, + ) + }) + }) + .collect_vec(); + (trace, claimed_xor_sum) + } +} + +pub struct LogupColGenerator<'a> { + gen: &'a mut LogupTraceGenerator, + numerator: SecureColumn, +} +impl<'a> LogupColGenerator<'a> { + pub fn write_frac(&mut self, vec_row: usize, p: PackedSecureField, q: PackedSecureField) { + unsafe { + self.numerator.set_packed(vec_row, p); + *self.gen.denom.data.get_unchecked_mut(vec_row) = q; + } + } + + pub fn finalize_col(mut self) { + FieldExpOps::batch_inverse(&self.gen.denom.data, &mut self.gen.denom_inv.data); + + #[allow(clippy::needless_range_loop)] + for vec_row in 0..(1 << (self.gen.log_size - LOG_N_LANES)) { + unsafe { + let value = self.numerator.packed_at(vec_row) + * *self.gen.denom_inv.data.get_unchecked(vec_row); + let prev_value = self + .gen + .trace + .last() + .map(|col| col.packed_at(vec_row)) + .unwrap_or_else(PackedSecureField::zero); + self.numerator.set_packed(vec_row, value + prev_value) + }; + } + + self.gen.trace.push(self.numerator) + } +} + +// TODO(spapini): Consider adding optional Ops. +pub fn eval_order_prefix_sum(col: &mut SecureColumn, log_size: u32) -> SecureField { + let _span = span!(Level::INFO, "Prefix sum").entered(); + + let mut cur = SecureField::zero(); + for i in 0..(1 << log_size) { + let index = if i & 1 == 0 { + i / 2 + } else { + (1 << (log_size - 1)) + ((1 << log_size) - 1 - i) / 2 + }; + let index = bit_reverse_index(index, log_size); + cur += col.at(index); + col.set(index, cur); + } + cur +} diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index e4cac72b0..2d5fc1926 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -2,6 +2,8 @@ mod assert; pub mod constant_cols; mod domain; +mod info; +pub mod logup; mod point; use std::fmt::Debug; @@ -9,6 +11,7 @@ use std::ops::{Add, AddAssign, Mul, Sub}; pub use assert::{assert_constraints, AssertEvaluator}; pub use domain::DomainEvaluator; +pub use info::InfoEvaluator; use num_traits::{One, Zero}; pub use point::PointEvaluator; diff --git a/crates/prover/src/core/fields/secure_column.rs b/crates/prover/src/core/fields/secure_column.rs index d5c9767fd..f4caa67ad 100644 --- a/crates/prover/src/core/fields/secure_column.rs +++ b/crates/prover/src/core/fields/secure_column.rs @@ -2,7 +2,6 @@ use super::m31::BaseField; use super::qm31::SecureField; use super::{ExtensionOf, FieldOps}; use crate::core::backend::{Col, Column, CpuBackend}; -use crate::core::utils::IteratorMutExt; pub const SECURE_EXTENSION_DEGREE: usize = >::EXTENSION_DEGREE; @@ -14,13 +13,6 @@ pub struct SecureColumn> { pub columns: [Col; SECURE_EXTENSION_DEGREE], } impl SecureColumn { - pub fn set(&mut self, index: usize, value: SecureField) { - self.columns - .iter_mut() - .map(|c| &mut c[index]) - .assign(value.to_m31_array()); - } - // TODO(spapini): Remove when we no longer use CircleEvaluation. pub fn to_vec(&self) -> Vec { (0..self.len()).map(|i| self.at(i)).collect() @@ -50,6 +42,12 @@ impl> SecureColumn { columns: self.columns.clone().map(|c| c.to_cpu()), } } + + pub fn set(&mut self, index: usize, value: SecureField) { + for i in 0..SECURE_EXTENSION_DEGREE { + self.columns[i].set(index, value.to_m31_array()[i]); + } + } } pub struct SecureColumnIter<'a> { diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 2b32e6d93..35583b22c 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -5,23 +5,31 @@ use std::ops::{Add, AddAssign, Mul, Sub}; use itertools::Itertools; use tracing::{span, Level}; +use crate::constraint_framework::constant_cols::gen_is_first; +use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements}; use crate::constraint_framework::{DomainEvaluator, EvalAtRow, PointEvaluator}; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::mask::fixed_mask_points; use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace}; use crate::core::backend::simd::column::BaseFieldVec; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Col, Column, ColumnOps}; -use crate::core::channel::Blake2sChannel; +use crate::core::channel::{Blake2sChannel, Channel}; use crate::core::circle::CirclePoint; use crate::core::constraints::coset_vanishing; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::fields::{FieldExpOps, FieldOps}; -use crate::core::pcs::TreeVec; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::fields::{FieldExpOps, FieldOps, IntoSlice}; +use crate::core::pcs::{CommitmentSchemeProver, TreeVec}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; +use crate::core::prover::{prove_without_commit, StarkProof, LOG_BLOWUP_FACTOR}; +use crate::core::utils::shifted_secure_combination; +use crate::core::vcs::blake2_hash::Blake2sHasher; +use crate::core::vcs::hasher::Hasher; use crate::core::{ColumnVec, InteractionElements, LookupValues}; use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; @@ -104,7 +112,10 @@ impl Component for PoseidonComponent { } fn trace_log_degree_bounds(&self) -> TreeVec> { - TreeVec::new(vec![vec![self.log_column_size(); N_COLUMNS]]) + TreeVec::new(vec![ + vec![self.log_column_size(); N_COLUMNS], + vec![self.log_column_size(); N_INSTANCES_PER_ROW * SECURE_EXTENSION_DEGREE], + ]) } fn mask_points( @@ -113,6 +124,7 @@ impl Component for PoseidonComponent { ) -> TreeVec>>> { TreeVec::new(vec![ fixed_mask_points(&vec![vec![0_usize]; N_COLUMNS], point), + vec![vec![]; N_INSTANCES_PER_ROW * SECURE_EXTENSION_DEGREE], vec![vec![point]], ]) } @@ -261,13 +273,29 @@ impl AirProver for PoseidonAir { } } +pub struct LookupData { + initial_state: [[BaseFieldVec; N_STATE]; N_INSTANCES_PER_ROW], + final_state: [[BaseFieldVec; N_STATE]; N_INSTANCES_PER_ROW], +} pub fn gen_trace( log_size: u32, -) -> ColumnVec> { +) -> ( + ColumnVec>, + LookupData, +) { assert!(log_size >= LOG_N_LANES); let mut trace = (0..N_COLUMNS) .map(|_| Col::::zeros(1 << log_size)) .collect_vec(); + let mut lookup_data = LookupData { + initial_state: std::array::from_fn(|_| { + std::array::from_fn(|_| BaseFieldVec::zeros(1 << log_size)) + }), + final_state: std::array::from_fn(|_| { + std::array::from_fn(|_| BaseFieldVec::zeros(1 << log_size)) + }), + }; + for vec_index in 0..(1 << (log_size - LOG_N_LANES)) { // Initial state. let mut col_index = 0; @@ -281,6 +309,10 @@ pub fn gen_trace( trace[col_index].data[vec_index] = s; col_index += 1; }); + lookup_data.initial_state[rep_i] + .iter_mut() + .zip(state) + .for_each(|(res, state)| res.data[vec_index] = state); // 4 full rounds. (0..N_HALF_FULL_ROUNDS).for_each(|round| { @@ -318,13 +350,61 @@ pub fn gen_trace( col_index += 1; }); }); + + lookup_data.final_state[rep_i] + .iter_mut() + .zip(state) + .for_each(|(res, state)| res.data[vec_index] = state); } } let domain = CanonicCoset::new(log_size).circle_domain(); - trace + let trace = trace .into_iter() .map(|eval| CircleEvaluation::::new(domain, eval)) - .collect_vec() + .collect_vec(); + (trace, lookup_data) +} + +pub fn gen_interaction_trace( + log_size: u32, + lookup_data: LookupData, + lookup_elements: LookupElements, +) -> ( + ColumnVec>, + SecureField, +) { + let _span = span!(Level::INFO, "Generate interaction trace").entered(); + let LookupElements { z, alpha } = lookup_elements; + let alpha = PackedSecureField::broadcast(alpha); + let broadcast = PackedSecureField::broadcast(z); + let z = broadcast; + let mut logup_gen = LogupTraceGenerator::new(log_size); + + #[allow(clippy::needless_range_loop)] + for rep_i in 0..N_INSTANCES_PER_ROW { + let mut col_gen = logup_gen.new_col(); + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let q0 = shifted_secure_combination( + &lookup_data.initial_state[rep_i] + .each_ref() + .map(|s| s.data[vec_row]), + alpha, + z, + ); + let q1 = shifted_secure_combination( + &lookup_data.final_state[rep_i] + .each_ref() + .map(|s| s.data[vec_row]), + alpha, + z, + ); + // 1 / q0 - 1 / q1 = (q1 - q0) / (q0 * q1). + col_gen.write_frac(vec_row, q1 - q0, q0 * q1); + } + col_gen.finalize_col(); + } + + logup_gen.finalize() } impl ComponentTraceGenerator for PoseidonComponent { @@ -429,33 +509,89 @@ impl ComponentProver for PoseidonComponent { } } +pub fn prove_poseidon(log_n_instances: u32) -> (PoseidonAir, StarkProof) { + assert!(log_n_instances >= N_LOG_INSTANCES_PER_ROW as u32); + let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32; + + // Precompute twiddles. + let span = span!(Level::INFO, "Precompute twiddles").entered(); + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(log_n_rows + LOG_EXPAND + LOG_BLOWUP_FACTOR) + .circle_domain() + .half_coset, + ); + span.exit(); + + // Setup protocol. + let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + let commitment_scheme = &mut CommitmentSchemeProver::new(LOG_BLOWUP_FACTOR); + + // Trace. + let span = span!(Level::INFO, "Trace").entered(); + let span1 = span!(Level::INFO, "Generation").entered(); + let (trace, interaction_data) = gen_trace(log_n_rows); + span1.exit(); + commitment_scheme.commit_on_evals(trace, channel, &twiddles); + span.exit(); + + // Draw lookup element. + let lookup_elements = LookupElements::draw(channel); + + // Interaction trace. + let span = span!(Level::INFO, "Interaction").entered(); + let span1 = span!(Level::INFO, "Generation").entered(); + let (trace, _claimed_logup_sum) = + gen_interaction_trace(log_n_rows, interaction_data, lookup_elements); + span1.exit(); + commitment_scheme.commit_on_evals(trace, channel, &twiddles); + span.exit(); + + // Constant trace. + let span = span!(Level::INFO, "Constant").entered(); + commitment_scheme.commit_on_evals(vec![gen_is_first(log_n_rows)], channel, &twiddles); + span.exit(); + + // Prove constraints. + let component = PoseidonComponent { log_n_rows }; + let air = PoseidonAir { component }; + let proof = prove_without_commit::( + &air, + channel, + &InteractionElements::default(), + &twiddles, + commitment_scheme, + ) + .unwrap(); + + (air, proof) +} + #[cfg(test)] mod tests { use std::env; use itertools::Itertools; use num_traits::One; - use tracing::{span, Level}; - use super::N_LOG_INSTANCES_PER_ROW; use crate::constraint_framework::assert_constraints; use crate::constraint_framework::constant_cols::gen_is_first; + use crate::constraint_framework::logup::LookupElements; use crate::core::air::AirExt; - use crate::core::backend::simd::SimdBackend; use crate::core::channel::{Blake2sChannel, Channel}; use crate::core::fields::m31::BaseField; use crate::core::fields::IntoSlice; - use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, TreeVec}; - use crate::core::poly::circle::{CanonicCoset, PolyOps}; - use crate::core::prover::{prove_without_commit, verify_without_commit, LOG_BLOWUP_FACTOR}; + use crate::core::pcs::{CommitmentSchemeVerifier, TreeVec}; + use crate::core::poly::circle::CanonicCoset; + use crate::core::prover::verify_without_commit; use crate::core::vcs::blake2_hash::Blake2sHasher; use crate::core::vcs::hasher::Hasher; use crate::core::InteractionElements; use crate::examples::poseidon::{ - apply_internal_round_matrix, apply_m4, gen_trace, PoseidonAir, PoseidonComponent, - PoseidonEval, LOG_EXPAND, + apply_internal_round_matrix, apply_m4, gen_interaction_trace, gen_trace, prove_poseidon, + PoseidonEval, }; use crate::math::matrix::{RowMajorMatrix, SquareMatrix}; + use crate::qm31; #[test] fn test_apply_m4() { @@ -497,14 +633,20 @@ mod tests { #[test] fn test_poseidon_constraints() { const LOG_N_ROWS: u32 = 8; - let component = PoseidonComponent { - log_n_rows: LOG_N_ROWS, + + // Trace. + let (trace0, interaction_data) = gen_trace(LOG_N_ROWS); + let lookup_elements = LookupElements { + z: qm31!(1, 2, 3, 4), + alpha: qm31!(5, 6, 7, 8), }; - let trace = gen_trace(component.log_column_size()); - let trace_polys = TreeVec::new(vec![trace - .into_iter() - .map(|c| c.interpolate()) - .collect_vec()]); + let (trace1, _claimed_logup_sum) = + gen_interaction_trace(LOG_N_ROWS, interaction_data, lookup_elements); + let trace2 = vec![gen_is_first(LOG_N_ROWS)]; + + let traces = TreeVec::new(vec![trace0, trace1, trace2]); + let trace_polys = + traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |eval| { PoseidonEval { eval }.eval(); }); @@ -522,45 +664,9 @@ mod tests { .unwrap_or_else(|_| "10".to_string()) .parse::() .unwrap(); - let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32; - // Precompute twiddles. - let span = span!(Level::INFO, "Precompute twiddles").entered(); - let twiddles = SimdBackend::precompute_twiddles( - CanonicCoset::new(log_n_rows + LOG_EXPAND + LOG_BLOWUP_FACTOR) - .circle_domain() - .half_coset, - ); - span.exit(); - - // Setup protocol. - let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); - let commitment_scheme = &mut CommitmentSchemeProver::new(LOG_BLOWUP_FACTOR); - - // Trace. - let span = span!(Level::INFO, "Trace").entered(); - let span1 = span!(Level::INFO, "Generation").entered(); - let trace = gen_trace(log_n_rows); - span1.exit(); - commitment_scheme.commit_on_evals(trace, channel, &twiddles); - span.exit(); - - // Constant trace. - let span = span!(Level::INFO, "Constant").entered(); - commitment_scheme.commit_on_evals(vec![gen_is_first(log_n_rows)], channel, &twiddles); - span.exit(); - - // Prove constraints. - let component = PoseidonComponent { log_n_rows }; - let air = PoseidonAir { component }; - let proof = prove_without_commit::( - &air, - channel, - &InteractionElements::default(), - &twiddles, - commitment_scheme, - ) - .unwrap(); + // Prove. + let (air, proof) = prove_poseidon(log_n_instances); // Verify. let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); @@ -570,8 +676,10 @@ mod tests { let sizes = air.column_log_sizes(); // Trace columns. commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); + // Interaction columns. + commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); // Constant columns. - commitment_scheme.commit(proof.commitments[1], &[log_n_rows], channel); + commitment_scheme.commit(proof.commitments[2], &[air.component.log_n_rows], channel); verify_without_commit( &air, diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index 6c9ddaa02..25884902f 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -1,4 +1,5 @@ #![feature( + array_methods, array_chunks, iter_array_chunks, exact_size_is_empty,