diff --git a/crates/prover/src/constraint_framework/assert.rs b/crates/prover/src/constraint_framework/assert.rs new file mode 100644 index 000000000..c6aa88784 --- /dev/null +++ b/crates/prover/src/constraint_framework/assert.rs @@ -0,0 +1,80 @@ +use num_traits::{One, Zero}; + +use super::EvalAtRow; +use crate::core::backend::{Backend, Column}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::pcs::TreeVec; +use crate::core::poly::circle::{CanonicCoset, CirclePoly}; + +/// Evaluates expressions at a trace domain row, and asserts constraints. Mainly used for testing. +pub struct AssertEvaluator<'a> { + pub trace: &'a TreeVec>>, + pub col_index: TreeVec, + pub row: usize, +} +impl<'a> AssertEvaluator<'a> { + pub fn new(trace: &'a TreeVec>>, row: usize) -> Self { + Self { + trace, + col_index: TreeVec::new(vec![0; trace.len()]), + row, + } + } +} +impl<'a> EvalAtRow for AssertEvaluator<'a> { + type F = BaseField; + type EF = SecureField; + + fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N] { + let col_index = self.col_index[interaction]; + self.col_index[interaction] += 1; + offsets.map(|off| { + // The mask row might wrap around the column size. + let col_size = self.trace[interaction][col_index].len() as isize; + self.trace[interaction][col_index] + [(self.row as isize + off).rem_euclid(col_size) as usize] + }) + } + + fn add_constraint(&mut self, constraint: G) + where + Self::EF: std::ops::Mul, + { + // Cast to SecureField. + let res = SecureField::one() * constraint; + // The constraint should be zero at the given row, since we are evaluating on the trace + // domain. + assert_eq!(res, SecureField::zero(), "row: {}", self.row); + } + + fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { + SecureField::from_m31_array(values) + } +} + +pub fn assert_constraints( + trace_polys: &TreeVec>>, + trace_domain: CanonicCoset, + assert_func: impl Fn(AssertEvaluator<'_>), +) { + let traces = trace_polys.as_ref().map(|tree| { + tree.iter() + .map(|poly| { + poly.evaluate(trace_domain.circle_domain()) + .bit_reverse() + .values + .to_cpu() + }) + .collect() + }); + for row in 0..trace_domain.size() { + let eval = AssertEvaluator::new(&traces, row); + assert_func(eval); + } +} diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 02cf52145..f8b6197dd 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -1,10 +1,12 @@ /// ! This module contains helpers to express and use constraints for components. +mod assert; mod point; mod simd_domain; use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Sub}; +pub use assert::{assert_constraints, AssertEvaluator}; use num_traits::{One, Zero}; pub use point::PointEvaluator; pub use simd_domain::SimdDomainEvaluator; diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index aa5497239..3b71a1937 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -438,11 +438,14 @@ mod tests { use num_traits::One; use tracing::{span, Level}; - use super::N_LOG_INSTANCES_PER_ROW; + use super::{PoseidonEval, N_LOG_INSTANCES_PER_ROW}; + use crate::constraint_framework::assert_constraints; use crate::core::backend::simd::SimdBackend; use crate::core::channel::{Blake2sChannel, Channel}; use crate::core::fields::m31::BaseField; use crate::core::fields::IntoSlice; + use crate::core::pcs::TreeVec; + use crate::core::poly::circle::CanonicCoset; use crate::core::prover::{commit_and_prove, commit_and_verify}; use crate::core::vcs::blake2_hash::Blake2sHasher; use crate::core::vcs::hasher::Hasher; @@ -488,6 +491,22 @@ mod tests { assert_eq!(state, expected_state); } + #[test] + fn test_poseidon_constraints() { + const LOG_N_ROWS: u32 = 8; + let component = PoseidonComponent { + log_n_rows: LOG_N_ROWS, + }; + let trace = gen_trace(component.log_column_size()); + let trace_polys = TreeVec::new(vec![trace + .into_iter() + .map(|c| c.interpolate()) + .collect_vec()]); + assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |eval| { + PoseidonEval { eval }.eval(); + }); + } + #[test_log::test] fn test_simd_poseidon_prove() { // Note: To see time measurement, run test with