From a69e8ca0e122ae83353464e56cd5085fc2149cd0 Mon Sep 17 00:00:00 2001
From: Shahar Papini <43779613+spapinistarkware@users.noreply.github.com>
Date: Sun, 14 Jul 2024 15:18:41 +0300
Subject: [PATCH] Eval framework asserts (#714)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This change is [](https://reviewable.io/reviews/starkware-libs/stwo/714)
---
.../prover/src/constraint_framework/assert.rs | 80 +++++++++++++++++++
crates/prover/src/constraint_framework/mod.rs | 2 +
crates/prover/src/examples/poseidon/mod.rs | 21 ++++-
3 files changed, 102 insertions(+), 1 deletion(-)
create mode 100644 crates/prover/src/constraint_framework/assert.rs
diff --git a/crates/prover/src/constraint_framework/assert.rs b/crates/prover/src/constraint_framework/assert.rs
new file mode 100644
index 000000000..c6aa88784
--- /dev/null
+++ b/crates/prover/src/constraint_framework/assert.rs
@@ -0,0 +1,80 @@
+use num_traits::{One, Zero};
+
+use super::EvalAtRow;
+use crate::core::backend::{Backend, Column};
+use crate::core::fields::m31::BaseField;
+use crate::core::fields::qm31::SecureField;
+use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
+use crate::core::pcs::TreeVec;
+use crate::core::poly::circle::{CanonicCoset, CirclePoly};
+
+/// Evaluates expressions at a trace domain row, and asserts constraints. Mainly used for testing.
+pub struct AssertEvaluator<'a> {
+ pub trace: &'a TreeVec>>,
+ pub col_index: TreeVec,
+ pub row: usize,
+}
+impl<'a> AssertEvaluator<'a> {
+ pub fn new(trace: &'a TreeVec>>, row: usize) -> Self {
+ Self {
+ trace,
+ col_index: TreeVec::new(vec![0; trace.len()]),
+ row,
+ }
+ }
+}
+impl<'a> EvalAtRow for AssertEvaluator<'a> {
+ type F = BaseField;
+ type EF = SecureField;
+
+ fn next_interaction_mask(
+ &mut self,
+ interaction: usize,
+ offsets: [isize; N],
+ ) -> [Self::F; N] {
+ let col_index = self.col_index[interaction];
+ self.col_index[interaction] += 1;
+ offsets.map(|off| {
+ // The mask row might wrap around the column size.
+ let col_size = self.trace[interaction][col_index].len() as isize;
+ self.trace[interaction][col_index]
+ [(self.row as isize + off).rem_euclid(col_size) as usize]
+ })
+ }
+
+ fn add_constraint(&mut self, constraint: G)
+ where
+ Self::EF: std::ops::Mul,
+ {
+ // Cast to SecureField.
+ let res = SecureField::one() * constraint;
+ // The constraint should be zero at the given row, since we are evaluating on the trace
+ // domain.
+ assert_eq!(res, SecureField::zero(), "row: {}", self.row);
+ }
+
+ fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
+ SecureField::from_m31_array(values)
+ }
+}
+
+pub fn assert_constraints(
+ trace_polys: &TreeVec>>,
+ trace_domain: CanonicCoset,
+ assert_func: impl Fn(AssertEvaluator<'_>),
+) {
+ let traces = trace_polys.as_ref().map(|tree| {
+ tree.iter()
+ .map(|poly| {
+ poly.evaluate(trace_domain.circle_domain())
+ .bit_reverse()
+ .values
+ .to_cpu()
+ })
+ .collect()
+ });
+ for row in 0..trace_domain.size() {
+ let eval = AssertEvaluator::new(&traces, row);
+ assert_func(eval);
+ }
+}
diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs
index 02cf52145..f8b6197dd 100644
--- a/crates/prover/src/constraint_framework/mod.rs
+++ b/crates/prover/src/constraint_framework/mod.rs
@@ -1,10 +1,12 @@
/// ! This module contains helpers to express and use constraints for components.
+mod assert;
mod point;
mod simd_domain;
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Sub};
+pub use assert::{assert_constraints, AssertEvaluator};
use num_traits::{One, Zero};
pub use point::PointEvaluator;
pub use simd_domain::SimdDomainEvaluator;
diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs
index aa5497239..3b71a1937 100644
--- a/crates/prover/src/examples/poseidon/mod.rs
+++ b/crates/prover/src/examples/poseidon/mod.rs
@@ -438,11 +438,14 @@ mod tests {
use num_traits::One;
use tracing::{span, Level};
- use super::N_LOG_INSTANCES_PER_ROW;
+ use super::{PoseidonEval, N_LOG_INSTANCES_PER_ROW};
+ use crate::constraint_framework::assert_constraints;
use crate::core::backend::simd::SimdBackend;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::fields::m31::BaseField;
use crate::core::fields::IntoSlice;
+ use crate::core::pcs::TreeVec;
+ use crate::core::poly::circle::CanonicCoset;
use crate::core::prover::{commit_and_prove, commit_and_verify};
use crate::core::vcs::blake2_hash::Blake2sHasher;
use crate::core::vcs::hasher::Hasher;
@@ -488,6 +491,22 @@ mod tests {
assert_eq!(state, expected_state);
}
+ #[test]
+ fn test_poseidon_constraints() {
+ const LOG_N_ROWS: u32 = 8;
+ let component = PoseidonComponent {
+ log_n_rows: LOG_N_ROWS,
+ };
+ let trace = gen_trace(component.log_column_size());
+ let trace_polys = TreeVec::new(vec![trace
+ .into_iter()
+ .map(|c| c.interpolate())
+ .collect_vec()]);
+ assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |eval| {
+ PoseidonEval { eval }.eval();
+ });
+ }
+
#[test_log::test]
fn test_simd_poseidon_prove() {
// Note: To see time measurement, run test with