diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 0128005c0..3c9de2969 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -23,7 +23,7 @@ starknet-crypto = "0.6.2" starknet-ff = "0.3.7" thiserror.workspace = true tracing.workspace = true -rayon = {version = "1.10.0", optional = true} +rayon = { version = "1.10.0", optional = true } [dev-dependencies] aligned = "0.4.2" @@ -82,3 +82,7 @@ harness = false [[bench]] name = "quotients" harness = false + +[[bench]] +name = "poseidon" +harness = false diff --git a/crates/prover/benches/poseidon.rs b/crates/prover/benches/poseidon.rs new file mode 100644 index 000000000..7f219663b --- /dev/null +++ b/crates/prover/benches/poseidon.rs @@ -0,0 +1,32 @@ +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::channel::{Blake2sChannel, Channel}; +use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::fields::IntoSlice; +use stwo_prover::core::prover::prove; +use stwo_prover::core::vcs::blake2_hash::Blake2sHasher; +use stwo_prover::core::vcs::hasher::Hasher; +use stwo_prover::examples::poseidon::{gen_trace, PoseidonAir, PoseidonComponent}; + +pub fn simd_poseidon(c: &mut Criterion) { + const LOG_N_ROWS: u32 = 15; + let mut group = c.benchmark_group("poseidon2"); + group.throughput(Throughput::Elements(1u64 << (LOG_N_ROWS + 3))); + group.bench_function(format!("poseidon2 2^{} instances", LOG_N_ROWS + 3), |b| { + b.iter(|| { + let component = PoseidonComponent { + log_n_instances: LOG_N_ROWS, + }; + let trace = gen_trace(component.log_column_size()); + let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + let air = PoseidonAir { component }; + prove::(&air, channel, trace).unwrap() + }); + }); +} + +criterion_group!( + name = bit_rev; + config = Criterion::default().sample_size(10); + targets = simd_poseidon); +criterion_main!(bit_rev); diff --git a/crates/prover/src/core/backend/simd/m31.rs b/crates/prover/src/core/backend/simd/m31.rs index b28ca98e1..654a6ac34 100644 --- a/crates/prover/src/core/backend/simd/m31.rs +++ b/crates/prover/src/core/backend/simd/m31.rs @@ -119,6 +119,13 @@ impl AddAssign for PackedM31 { } } +impl AddAssign for PackedM31 { + #[inline(always)] + fn add_assign(&mut self, rhs: M31) { + *self = *self + PackedM31::broadcast(rhs); + } +} + impl Mul for PackedM31 { type Output = Self; @@ -142,6 +149,15 @@ impl Mul for PackedM31 { } } +impl Mul for PackedM31 { + type Output = Self; + + #[inline(always)] + fn mul(self, rhs: BaseField) -> Self::Output { + self * PackedM31::broadcast(rhs) + } +} + impl MulAssign for PackedM31 { #[inline(always)] fn mul_assign(&mut self, rhs: Self) { diff --git a/crates/prover/src/core/fft.rs b/crates/prover/src/core/fft.rs index af464ac9e..630fbe738 100644 --- a/crates/prover/src/core/fft.rs +++ b/crates/prover/src/core/fft.rs @@ -1,13 +1,20 @@ +use std::ops::{Add, AddAssign, Mul, Sub}; + use super::fields::m31::BaseField; -use super::fields::ExtensionOf; -pub fn butterfly>(v0: &mut F, v1: &mut F, twid: BaseField) { +pub fn butterfly(v0: &mut F, v1: &mut F, twid: BaseField) +where + F: Copy + AddAssign + Sub + Mul, +{ let tmp = *v1 * twid; *v1 = *v0 - tmp; *v0 += tmp; } -pub fn ibutterfly>(v0: &mut F, v1: &mut F, itwid: BaseField) { +pub fn ibutterfly(v0: &mut F, v1: &mut F, itwid: BaseField) +where + F: Copy + AddAssign + Add + Sub + Mul, +{ let tmp = *v0; *v0 = tmp + *v1; *v1 = (tmp - *v1) * itwid; diff --git a/crates/prover/src/examples/mod.rs b/crates/prover/src/examples/mod.rs index a81ec8ec4..cec9debd2 100644 --- a/crates/prover/src/examples/mod.rs +++ b/crates/prover/src/examples/mod.rs @@ -1,2 +1,3 @@ pub mod fibonacci; +pub mod poseidon; pub mod wide_fibonacci; diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs new file mode 100644 index 000000000..6a6f6c2c8 --- /dev/null +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -0,0 +1,498 @@ +//! AIR for Poseidon2 hash function from . + +use std::ops::{Add, AddAssign, Mul, Sub}; + +use itertools::Itertools; +use num_traits::Zero; +use tracing::{span, Level}; + +use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; +use crate::core::air::mask::fixed_mask_points; +use crate::core::air::{ + Air, AirProver, AirTraceVerifier, AirTraceWriter, Component, ComponentProver, ComponentTrace, + ComponentTraceWriter, +}; +use crate::core::backend::simd::column::BaseFieldVec; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::{Col, Column, ColumnOps}; +use crate::core::channel::Blake2sChannel; +use crate::core::circle::CirclePoint; +use crate::core::constraints::coset_vanishing; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::{FieldExpOps, FieldOps}; +use crate::core::pcs::TreeVec; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; +use crate::core::poly::BitReversedOrder; +use crate::core::{ColumnVec, InteractionElements}; + +const N_INSTANCES_PER_ROW: usize = 8; +const N_STATE: usize = 16; +const N_PARTIAL_ROUNDS: usize = 14; +const N_HALF_FULL_ROUNDS: usize = 4; +const FULL_ROUNDS: usize = 2 * N_HALF_FULL_ROUNDS; +const N_COLUMNS_PER_REP: usize = N_STATE * (1 + FULL_ROUNDS) + N_PARTIAL_ROUNDS; +const N_COLUMNS: usize = N_INSTANCES_PER_ROW * N_COLUMNS_PER_REP; +const LOG_EXPAND: u32 = 2; +// TODO(spapini): Pick better constants. +const EXTERNAL_ROUND_CONSTS: [[BaseField; N_STATE]; 2 * N_HALF_FULL_ROUNDS] = + [[BaseField::from_u32_unchecked(1234); N_STATE]; 2 * N_HALF_FULL_ROUNDS]; +const INTERNAL_ROUND_CONSTS: [BaseField; N_PARTIAL_ROUNDS] = + [BaseField::from_u32_unchecked(1234); N_PARTIAL_ROUNDS]; + +pub struct PoseidonComponent { + pub log_n_instances: u32, +} + +impl PoseidonComponent { + pub fn log_column_size(&self) -> u32 { + self.log_n_instances + } + + pub fn n_columns(&self) -> usize { + N_COLUMNS + } +} + +pub struct PoseidonAir { + pub component: PoseidonComponent, +} + +impl Air for PoseidonAir { + fn components(&self) -> Vec<&dyn Component> { + vec![&self.component] + } +} + +impl AirTraceVerifier for PoseidonAir { + fn interaction_elements(&self, _channel: &mut Blake2sChannel) -> InteractionElements { + InteractionElements::default() + } +} + +impl AirTraceWriter for PoseidonAir { + fn interact( + &self, + _trace: &ColumnVec>, + _elements: &InteractionElements, + ) -> Vec> { + vec![] + } + + fn to_air_prover(&self) -> &impl AirProver { + self + } +} + +impl Component for PoseidonComponent { + fn n_constraints(&self) -> usize { + (N_COLUMNS_PER_REP - N_STATE) * N_INSTANCES_PER_ROW + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_column_size() + LOG_EXPAND + } + + fn n_interaction_phases(&self) -> u32 { + 1 + } + + fn trace_log_degree_bounds(&self) -> TreeVec> { + TreeVec::new(vec![vec![self.log_column_size(); N_COLUMNS], vec![]]) + } + + fn mask_points( + &self, + point: CirclePoint, + ) -> TreeVec>>> { + TreeVec::new(vec![ + fixed_mask_points(&vec![vec![0_usize]; N_COLUMNS], point), + vec![], + ]) + } + + fn interaction_element_ids(&self) -> Vec { + vec![] + } + + fn evaluate_constraint_quotients_at_point( + &self, + point: CirclePoint, + mask: &ColumnVec>, + evaluation_accumulator: &mut PointEvaluationAccumulator, + _interaction_elements: &InteractionElements, + ) { + let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset; + let denom = coset_vanishing(constraint_zero_domain, point); + let denom_inverse = denom.inverse(); + let mut eval = PoseidonEvalAtPoint { + mask, + evaluation_accumulator, + col_index: 0, + denom_inverse, + }; + for _ in 0..N_INSTANCES_PER_ROW { + eval.eval(); + } + assert_eq!(eval.col_index, N_COLUMNS); + } +} + +// Applies the external round matrix. +// See https://eprint.iacr.org/2023/323.pdf 5.1 and Appendix B. +fn apply_external_round_matrix(state: &mut [F; 16]) +where + F: Copy + AddAssign + Add + Sub + Mul, +{ + // Applies M4 from the paper. + let apply_m4 = |x: [F; 4]| { + let t0 = x[0] + x[1]; + let t02 = t0 + t0; + let t1 = x[2] + x[3]; + let t12 = t1 + t1; + let t2 = x[1] + x[1] + t1; + let t3 = x[3] + x[3] + t0; + let t4 = t12 + t12 + t3; + let t5 = t02 + t02 + t2; + let t6 = t3 + t5; + let t7 = t2 + t4; + [t6, t5, t7, t4] + }; + + // Applies circ(2M4, M4, M4, M4). + for i in 0..4 { + [ + state[4 * i], + state[4 * i + 1], + state[4 * i + 2], + state[4 * i + 3], + ] = apply_m4([ + state[4 * i], + state[4 * i + 1], + state[4 * i + 2], + state[4 * i + 3], + ]); + } + for j in 0..4 { + let s = state[j] + state[j + 4] + state[j + 8] + state[j + 12]; + for i in 0..4 { + state[4 * i + j] += s; + } + } +} + +// Applies the internal round matrix. +// mu_i = 2^{i+1} + 1. +// See https://eprint.iacr.org/2023/323.pdf 5.2 . +fn apply_internal_round_matrix(state: &mut [F; 16]) +where + F: Copy + AddAssign + Add + Sub + Mul, +{ + // TODO(spapini): Check that these coefficients are good according to section 5.3 of Poseidon2 + // paper. + let sum = state[1..].iter().fold(state[0], |acc, s| acc + *s); + state.iter_mut().enumerate().for_each(|(i, s)| { + // TODO(spapini): Change to rotations. + *s = *s * BaseField::from_u32_unchecked(1 << (i + 1)) + sum; + }); +} + +struct PoseidonEvalAtPoint<'a> { + mask: &'a ColumnVec>, + evaluation_accumulator: &'a mut PointEvaluationAccumulator, + col_index: usize, + denom_inverse: SecureField, +} +impl<'a> PoseidonEval for PoseidonEvalAtPoint<'a> { + type F = SecureField; + + fn next_mask(&mut self) -> Self::F { + let res = self.mask[self.col_index][0]; + self.col_index += 1; + res + } + fn add_constraint(&mut self, constraint: Self::F) { + self.evaluation_accumulator + .accumulate(constraint * self.denom_inverse); + } +} + +fn pow5(x: F) -> F { + let x2 = x * x; + let x4 = x2 * x2; + x4 * x +} + +// TODO(spapini): Round constants. +trait PoseidonEval { + type F: FieldExpOps + + Copy + + AddAssign + + Add + + Sub + + Mul + + AddAssign; + + fn next_mask(&mut self) -> Self::F; + fn add_constraint(&mut self, constraint: Self::F); + + fn eval(&mut self) { + let mut state: [_; N_STATE] = std::array::from_fn(|_| self.next_mask()); + + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += EXTERNAL_ROUND_CONSTS[round][i]; + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter_mut().for_each(|s| { + let m = self.next_mask(); + self.add_constraint(*s - m); + *s = m; + }); + }); + + // Partial rounds. + (0..N_PARTIAL_ROUNDS).for_each(|round| { + state[0] += INTERNAL_ROUND_CONSTS[round]; + apply_internal_round_matrix(&mut state); + state[0] = pow5(state[0]); + let m = self.next_mask(); + self.add_constraint(state[0] - m); + state[0] = m; + }); + + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i]; + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter_mut().for_each(|s| { + let m = self.next_mask(); + self.add_constraint(*s - m); + *s = m; + }); + }); + } +} + +impl AirProver for PoseidonAir { + fn prover_components(&self) -> Vec<&dyn ComponentProver> { + vec![&self.component] + } +} + +pub fn gen_trace( + log_size: u32, +) -> ColumnVec> { + assert!(log_size >= LOG_N_LANES); + let mut trace = (0..N_COLUMNS) + .map(|_| Col::::zeros(1 << log_size)) + .collect_vec(); + for vec_index in 0..(1 << (log_size - LOG_N_LANES)) { + // Initial state. + let mut col_index = 0; + for rep_i in 0..N_INSTANCES_PER_ROW { + let mut state: [_; N_STATE] = std::array::from_fn(|state_i| { + PackedBaseField::from_array(std::array::from_fn(|i| { + BaseField::from_u32_unchecked((vec_index * 16 + i + state_i + rep_i) as u32) + })) + }); + state.iter().copied().for_each(|s| { + trace[col_index].data[vec_index] = s; + col_index += 1; + }); + + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += PackedBaseField::broadcast(EXTERNAL_ROUND_CONSTS[round][i]); + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter().copied().for_each(|s| { + trace[col_index].data[vec_index] = s; + col_index += 1; + }); + }); + + // Partial rounds. + (0..N_PARTIAL_ROUNDS).for_each(|round| { + state[0] += PackedBaseField::broadcast(INTERNAL_ROUND_CONSTS[round]); + apply_internal_round_matrix(&mut state); + state[0] = pow5(state[0]); + trace[col_index].data[vec_index] = state[0]; + col_index += 1; + }); + + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += PackedBaseField::broadcast( + EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i], + ); + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter().copied().for_each(|s| { + trace[col_index].data[vec_index] = s; + col_index += 1; + }); + }); + } + } + let domain = CanonicCoset::new(log_size).circle_domain(); + trace + .into_iter() + .map(|eval| CircleEvaluation::::new(domain, eval)) + .collect_vec() +} + +impl ComponentTraceWriter for PoseidonComponent { + fn write_interaction_trace( + &self, + _trace: &ColumnVec<&CircleEvaluation>, + _elements: &InteractionElements, + ) -> ColumnVec> { + vec![] + } +} + +struct PoseidonEvalAtDomain<'a> { + trace_eval: &'a TreeVec>>, + vec_row: usize, + random_coeff_powers: &'a [SecureField], + row_res: PackedSecureField, + col_index: usize, + constraint_index: usize, +} +impl<'a> PoseidonEval for PoseidonEvalAtDomain<'a> { + type F = PackedBaseField; + + fn next_mask(&mut self) -> Self::F { + let res = unsafe { + *self.trace_eval[0] + .get_unchecked(self.col_index) + .data + .get_unchecked(self.vec_row) + }; + self.col_index += 1; + res + } + fn add_constraint(&mut self, constraint: Self::F) { + self.row_res += + PackedSecureField::broadcast(self.random_coeff_powers[self.constraint_index]) + * constraint; + self.constraint_index += 1; + } +} + +impl ComponentProver for PoseidonComponent { + fn evaluate_constraint_quotients_on_domain( + &self, + trace: &ComponentTrace<'_, SimdBackend>, + evaluation_accumulator: &mut DomainEvaluationAccumulator, + _interaction_elements: &InteractionElements, + ) { + assert_eq!(trace.polys[0].len(), self.n_columns()); + let eval_domain = CanonicCoset::new(self.log_column_size() + LOG_EXPAND).circle_domain(); + + // Create a new evaluation. + let span = span!(Level::INFO, "Deg8 eval").entered(); + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(self.max_constraint_log_degree_bound()) + .circle_domain() + .half_coset, + ); + let trace_eval = trace + .polys + .as_cols_ref() + .map_cols(|col| col.evaluate_with_twiddles(eval_domain, &twiddles)); + span.exit(); + + // Denoms. + let span = span!(Level::INFO, "Constraint eval denominators").entered(); + let zero_domain = CanonicCoset::new(self.log_column_size()).coset; + let mut denoms = + BaseFieldVec::from_iter(eval_domain.iter().map(|p| coset_vanishing(zero_domain, p))); + >::bit_reverse_column(&mut denoms); + let mut denom_inverses = BaseFieldVec::zeros(denoms.len()); + >::batch_inverse(&denoms, &mut denom_inverses); + span.exit(); + + let _span = span!(Level::INFO, "Constraint pointwise eval").entered(); + + let constraint_log_degree_bound = self.max_constraint_log_degree_bound(); + let n_constraints = self.n_constraints(); + let [accum] = + evaluation_accumulator.columns([(constraint_log_degree_bound, n_constraints)]); + let mut pows = accum.random_coeff_powers.clone(); + pows.reverse(); + + for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES)) { + let mut evaluator = PoseidonEvalAtDomain { + trace_eval: &trace_eval, + vec_row, + random_coeff_powers: &pows, + row_res: PackedSecureField::zero(), + col_index: 0, + constraint_index: 0, + }; + for _ in 0..N_INSTANCES_PER_ROW { + evaluator.eval(); + } + let row_res = evaluator.row_res; + + unsafe { + accum.col.set_packed( + vec_row, + accum.col.packed_at(vec_row) + row_res * denom_inverses.data[vec_row], + ) + } + assert_eq!(evaluator.constraint_index, n_constraints); + } + } +} + +#[cfg(test)] +mod tests { + use tracing::{span, Level}; + + use crate::core::backend::simd::SimdBackend; + use crate::core::channel::{Blake2sChannel, Channel}; + use crate::core::fields::m31::BaseField; + use crate::core::fields::IntoSlice; + use crate::core::prover::{prove, verify}; + use crate::core::vcs::blake2_hash::Blake2sHasher; + use crate::core::vcs::hasher::Hasher; + use crate::examples::poseidon::{gen_trace, PoseidonAir, PoseidonComponent}; + + #[test_log::test] + fn test_simd_poseidon_prove() { + // Note: To see time measurement, run test with + // RUST_LOG_SPAN_EVENTS=enter,close RUST_LOG=info RUST_BACKTRACE=1 RUSTFLAGS=" + // -C target-cpu=native -C target-feature=+avx512f -C opt-level=3" cargo test + // test_simd_poseidon_prove -- --nocapture + + // Note: 15 means 208MB of trace. + const LOG_N_ROWS: u32 = 12; + let component = PoseidonComponent { + log_n_instances: LOG_N_ROWS, + }; + let span = span!(Level::INFO, "Trace generation").entered(); + let trace = gen_trace(component.log_column_size()); + span.exit(); + + let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + let air = PoseidonAir { component }; + let proof = prove::(&air, channel, trace).unwrap(); + + let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); + verify(proof, &air, channel).unwrap(); + } +} diff --git a/crates/prover/src/examples/wide_fibonacci/simd.rs b/crates/prover/src/examples/wide_fibonacci/simd.rs index 877b03ba8..387ece5e7 100644 --- a/crates/prover/src/examples/wide_fibonacci/simd.rs +++ b/crates/prover/src/examples/wide_fibonacci/simd.rs @@ -248,7 +248,7 @@ mod tests { fn test_simd_wide_fib_prove() { // Note: To see time measurement, run test with // RUST_LOG_SPAN_EVENTS=enter,close RUST_LOG=info RUST_BACKTRACE=1 RUSTFLAGS=" - // -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo test + // -C target-cpu=native -C target-feature=+avx512f -C opt-level=3" cargo test // test_simd_wide_fib_prove -- --nocapture // Note: 17 means 128MB of trace.