diff --git a/crates/prover/benches/poseidon.rs b/crates/prover/benches/poseidon.rs index 7802682a2..6e3d64aca 100644 --- a/crates/prover/benches/poseidon.rs +++ b/crates/prover/benches/poseidon.rs @@ -15,7 +15,7 @@ pub fn simd_poseidon(c: &mut Criterion) { group.bench_function(format!("poseidon2 2^{} instances", LOG_N_ROWS + 3), |b| { b.iter(|| { let component = PoseidonComponent { - log_n_instances: LOG_N_ROWS, + log_n_rows: LOG_N_ROWS, }; let trace = gen_trace(component.log_column_size()); let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 8ca801aec..02cf52145 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -1,8 +1,13 @@ /// ! This module contains helpers to express and use constraints for components. +mod point; +mod simd_domain; + use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Sub}; use num_traits::{One, Zero}; +pub use point::PointEvaluator; +pub use simd_domain::SimdDomainEvaluator; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; diff --git a/crates/prover/src/constraint_framework/point.rs b/crates/prover/src/constraint_framework/point.rs new file mode 100644 index 000000000..88b48bc76 --- /dev/null +++ b/crates/prover/src/constraint_framework/point.rs @@ -0,0 +1,57 @@ +use std::ops::Mul; + +use super::EvalAtRow; +use crate::core::air::accumulation::PointEvaluationAccumulator; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::pcs::TreeVec; +use crate::core::ColumnVec; + +/// Evaluates expressions at an out of domain point. +pub struct PointEvaluator<'a> { + pub mask: TreeVec<&'a ColumnVec>>, + pub evaluation_accumulator: &'a mut PointEvaluationAccumulator, + pub col_index: Vec, + pub denom_inverse: SecureField, +} +impl<'a> PointEvaluator<'a> { + pub fn new( + mask: TreeVec<&'a ColumnVec>>, + evaluation_accumulator: &'a mut PointEvaluationAccumulator, + denom_inverse: SecureField, + ) -> Self { + let col_index = vec![0; mask.len()]; + Self { + mask, + evaluation_accumulator, + col_index, + denom_inverse, + } + } +} +impl<'a> EvalAtRow for PointEvaluator<'a> { + type F = SecureField; + type EF = SecureField; + + fn next_interaction_mask( + &mut self, + interaction: usize, + _offsets: [isize; N], + ) -> [Self::F; N] { + let col_index = self.col_index[interaction]; + self.col_index[interaction] += 1; + let mask = self.mask[interaction][col_index].clone(); + assert_eq!(mask.len(), N); + mask.try_into().unwrap() + } + fn add_constraint(&mut self, constraint: G) + where + Self::EF: Mul, + { + self.evaluation_accumulator + .accumulate(self.denom_inverse * constraint); + } + fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { + SecureField::from_partial_evals(values) + } +} diff --git a/crates/prover/src/constraint_framework/simd_domain.rs b/crates/prover/src/constraint_framework/simd_domain.rs new file mode 100644 index 000000000..b22f052b9 --- /dev/null +++ b/crates/prover/src/constraint_framework/simd_domain.rs @@ -0,0 +1,95 @@ +use std::ops::Mul; + +use num_traits::Zero; + +use super::EvalAtRow; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::pcs::TreeVec; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::offset_bit_reversed_circle_domain_index; + +/// Evaluates constraints at an evaluation domain points. +pub struct SimdDomainEvaluator<'a> { + pub trace_eval: + &'a TreeVec>>, + pub column_index_per_interaction: Vec, + /// The row index of the simd-vector row to evaluate the constraints at. + pub vec_row: usize, + pub random_coeff_powers: &'a [SecureField], + pub row_res: PackedSecureField, + pub constraint_index: usize, + pub domain_log_size: u32, + pub eval_domain_log_size: u32, +} +impl<'a> SimdDomainEvaluator<'a> { + pub fn new( + trace_eval: &'a TreeVec>>, + vec_row: usize, + random_coeff_powers: &'a [SecureField], + domain_log_size: u32, + eval_log_size: u32, + ) -> Self { + Self { + trace_eval, + column_index_per_interaction: vec![0; trace_eval.len()], + vec_row, + random_coeff_powers, + row_res: PackedSecureField::zero(), + constraint_index: 0, + domain_log_size, + eval_domain_log_size: eval_log_size, + } + } +} +impl<'a> EvalAtRow for SimdDomainEvaluator<'a> { + type F = PackedBaseField; + type EF = PackedSecureField; + + // TODO(spapini): Remove all boundary checks. + fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N] { + let col_index = self.column_index_per_interaction[interaction]; + self.column_index_per_interaction[interaction] += 1; + offsets.map(|off| { + // If the offset is 0, we can just return the value directly from this row. + if off == 0 { + return self.trace_eval[interaction][col_index].data[self.vec_row]; + } + // Otherwise, we need to look up the value at the offset. + // Since the domain is bit-reversed circle domain ordered, we need to look up the value + // at the bit-reversed natural order index at an offset. + PackedBaseField::from_array(std::array::from_fn(|i| { + let row_index = offset_bit_reversed_circle_domain_index( + (self.vec_row << LOG_N_LANES) + i, + self.domain_log_size, + self.eval_domain_log_size, + off, + ); + self.trace_eval[interaction][col_index].at(row_index) + })) + }) + } + fn add_constraint(&mut self, constraint: G) + where + Self::EF: Mul, + { + self.row_res += + PackedSecureField::broadcast(self.random_coeff_powers[self.constraint_index]) + * constraint; + self.constraint_index += 1; + } + + fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { + PackedSecureField::from_packed_m31s(values) + } +} diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 354581dfd..aa5497239 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -3,15 +3,14 @@ use std::ops::{Add, AddAssign, Mul, Sub}; use itertools::Itertools; -use num_traits::Zero; use tracing::{span, Level}; +use crate::constraint_framework::{EvalAtRow, PointEvaluator, SimdDomainEvaluator}; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::mask::fixed_mask_points; use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace}; use crate::core::backend::simd::column::BaseFieldVec; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; -use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Col, Column, ColumnOps}; use crate::core::channel::Blake2sChannel; @@ -44,12 +43,12 @@ const INTERNAL_ROUND_CONSTS: [BaseField; N_PARTIAL_ROUNDS] = #[derive(Clone)] pub struct PoseidonComponent { - pub log_n_instances: u32, + pub log_n_rows: u32, } impl PoseidonComponent { pub fn log_column_size(&self) -> u32 { - self.log_n_instances + self.log_n_rows } pub fn n_columns(&self) -> usize { @@ -130,16 +129,13 @@ impl Component for PoseidonComponent { let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset; let denom = coset_vanishing(constraint_zero_domain, point); let denom_inverse = denom.inverse(); - let mut eval = PoseidonEvalAtPoint { - mask: &mask[0], - evaluation_accumulator, - col_index: 0, - denom_inverse, + let mut poseidon_eval = PoseidonEval { + eval: PointEvaluator::new(mask.as_ref(), evaluation_accumulator, denom_inverse), }; for _ in 0..N_INSTANCES_PER_ROW { - eval.eval(); + poseidon_eval.eval(); } - assert_eq!(eval.col_index, N_COLUMNS); + assert_eq!(poseidon_eval.eval.col_index[0], N_COLUMNS); } } @@ -206,46 +202,19 @@ where }); } -struct PoseidonEvalAtPoint<'a> { - mask: &'a ColumnVec>, - evaluation_accumulator: &'a mut PointEvaluationAccumulator, - col_index: usize, - denom_inverse: SecureField, -} -impl<'a> PoseidonEval for PoseidonEvalAtPoint<'a> { - type F = SecureField; - - fn next_mask(&mut self) -> Self::F { - let res = self.mask[self.col_index][0]; - self.col_index += 1; - res - } - fn add_constraint(&mut self, constraint: Self::F) { - self.evaluation_accumulator - .accumulate(constraint * self.denom_inverse); - } -} - fn pow5(x: F) -> F { let x2 = x * x; let x4 = x2 * x2; x4 * x } -trait PoseidonEval { - type F: FieldExpOps - + Copy - + AddAssign - + Add - + Sub - + Mul - + AddAssign; - - fn next_mask(&mut self) -> Self::F; - fn add_constraint(&mut self, constraint: Self::F); +struct PoseidonEval { + eval: E, +} +impl PoseidonEval { fn eval(&mut self) { - let mut state: [_; N_STATE] = std::array::from_fn(|_| self.next_mask()); + let mut state: [_; N_STATE] = std::array::from_fn(|_| self.eval.next_trace_mask()); // 4 full rounds. (0..N_HALF_FULL_ROUNDS).for_each(|round| { @@ -255,8 +224,8 @@ trait PoseidonEval { apply_external_round_matrix(&mut state); state = std::array::from_fn(|i| pow5(state[i])); state.iter_mut().for_each(|s| { - let m = self.next_mask(); - self.add_constraint(*s - m); + let m = self.eval.next_trace_mask(); + self.eval.add_constraint(*s - m); *s = m; }); }); @@ -266,8 +235,8 @@ trait PoseidonEval { state[0] += INTERNAL_ROUND_CONSTS[round]; apply_internal_round_matrix(&mut state); state[0] = pow5(state[0]); - let m = self.next_mask(); - self.add_constraint(state[0] - m); + let m = self.eval.next_trace_mask(); + self.eval.add_constraint(state[0] - m); state[0] = m; }); @@ -279,8 +248,8 @@ trait PoseidonEval { apply_external_round_matrix(&mut state); state = std::array::from_fn(|i| pow5(state[i])); state.iter_mut().for_each(|s| { - let m = self.next_mask(); - self.add_constraint(*s - m); + let m = self.eval.next_trace_mask(); + self.eval.add_constraint(*s - m); *s = m; }); }); @@ -387,35 +356,6 @@ impl ComponentTraceGenerator for PoseidonComponent { } } -struct PoseidonEvalAtDomain<'a> { - trace_eval: &'a TreeVec>>, - vec_row: usize, - random_coeff_powers: &'a [SecureField], - row_res: PackedSecureField, - col_index: usize, - constraint_index: usize, -} -impl<'a> PoseidonEval for PoseidonEvalAtDomain<'a> { - type F = PackedBaseField; - - fn next_mask(&mut self) -> Self::F { - let res = unsafe { - *self.trace_eval[0] - .get_unchecked(self.col_index) - .data - .get_unchecked(self.vec_row) - }; - self.col_index += 1; - res - } - fn add_constraint(&mut self, constraint: Self::F) { - self.row_res += - PackedSecureField::broadcast(self.random_coeff_powers[self.constraint_index]) - * constraint; - self.constraint_index += 1; - } -} - impl ComponentProver for PoseidonComponent { fn evaluate_constraint_quotients_on_domain( &self, @@ -438,6 +378,7 @@ impl ComponentProver for PoseidonComponent { .polys .as_cols_ref() .map_cols(|col| col.evaluate_with_twiddles(eval_domain, &twiddles)); + let trace_eval_ref = trace_eval.as_ref().map(|t| t.iter().collect_vec()); span.exit(); // Denoms. @@ -460,18 +401,19 @@ impl ComponentProver for PoseidonComponent { pows.reverse(); for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES)) { - let mut evaluator = PoseidonEvalAtDomain { - trace_eval: &trace_eval, - vec_row, - random_coeff_powers: &pows, - row_res: PackedSecureField::zero(), - col_index: 0, - constraint_index: 0, + let mut evaluator = PoseidonEval { + eval: SimdDomainEvaluator::new( + &trace_eval_ref, + vec_row, + &pows, + self.log_n_rows, + self.log_n_rows + LOG_EXPAND, + ), }; for _ in 0..N_INSTANCES_PER_ROW { evaluator.eval(); } - let row_res = evaluator.row_res; + let row_res = evaluator.eval.row_res; unsafe { accum.col.set_packed( @@ -479,7 +421,7 @@ impl ComponentProver for PoseidonComponent { accum.col.packed_at(vec_row) + row_res * denom_inverses.data[vec_row], ) } - assert_eq!(evaluator.constraint_index, n_constraints); + assert_eq!(evaluator.eval.constraint_index, n_constraints); } } @@ -559,9 +501,7 @@ mod tests { .parse::() .unwrap(); let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32; - let component = PoseidonComponent { - log_n_instances: log_n_rows, - }; + let component = PoseidonComponent { log_n_rows }; let span = span!(Level::INFO, "Trace generation").entered(); let trace = gen_trace(component.log_column_size()); span.exit();