Skip to content

Commit

Permalink
Eval framework asserts (#714)
Browse files Browse the repository at this point in the history
<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/starkware-libs/stwo/714)
<!-- Reviewable:end -->
  • Loading branch information
spapinistarkware committed Jul 14, 2024
1 parent 4839cd8 commit a69e8ca
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 1 deletion.
80 changes: 80 additions & 0 deletions crates/prover/src/constraint_framework/assert.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use num_traits::{One, Zero};

use super::EvalAtRow;
use crate::core::backend::{Backend, Column};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CirclePoly};

/// Evaluates expressions at a trace domain row, and asserts constraints. Mainly used for testing.
pub struct AssertEvaluator<'a> {
pub trace: &'a TreeVec<Vec<Vec<BaseField>>>,
pub col_index: TreeVec<usize>,
pub row: usize,
}
impl<'a> AssertEvaluator<'a> {
pub fn new(trace: &'a TreeVec<Vec<Vec<BaseField>>>, row: usize) -> Self {
Self {
trace,
col_index: TreeVec::new(vec![0; trace.len()]),
row,
}
}
}
impl<'a> EvalAtRow for AssertEvaluator<'a> {
type F = BaseField;
type EF = SecureField;

fn next_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
offsets: [isize; N],
) -> [Self::F; N] {
let col_index = self.col_index[interaction];
self.col_index[interaction] += 1;
offsets.map(|off| {
// The mask row might wrap around the column size.
let col_size = self.trace[interaction][col_index].len() as isize;
self.trace[interaction][col_index]
[(self.row as isize + off).rem_euclid(col_size) as usize]
})
}

fn add_constraint<G>(&mut self, constraint: G)
where
Self::EF: std::ops::Mul<G, Output = Self::EF>,
{
// Cast to SecureField.
let res = SecureField::one() * constraint;
// The constraint should be zero at the given row, since we are evaluating on the trace
// domain.
assert_eq!(res, SecureField::zero(), "row: {}", self.row);
}

fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
SecureField::from_m31_array(values)
}
}

pub fn assert_constraints<B: Backend>(
trace_polys: &TreeVec<Vec<CirclePoly<B>>>,
trace_domain: CanonicCoset,
assert_func: impl Fn(AssertEvaluator<'_>),
) {
let traces = trace_polys.as_ref().map(|tree| {
tree.iter()
.map(|poly| {
poly.evaluate(trace_domain.circle_domain())
.bit_reverse()
.values
.to_cpu()
})
.collect()
});
for row in 0..trace_domain.size() {
let eval = AssertEvaluator::new(&traces, row);
assert_func(eval);
}
}
2 changes: 2 additions & 0 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
/// ! This module contains helpers to express and use constraints for components.
mod assert;
mod point;
mod simd_domain;

use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Sub};

pub use assert::{assert_constraints, AssertEvaluator};
use num_traits::{One, Zero};
pub use point::PointEvaluator;
pub use simd_domain::SimdDomainEvaluator;
Expand Down
21 changes: 20 additions & 1 deletion crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,11 +438,14 @@ mod tests {
use num_traits::One;
use tracing::{span, Level};

use super::N_LOG_INSTANCES_PER_ROW;
use super::{PoseidonEval, N_LOG_INSTANCES_PER_ROW};
use crate::constraint_framework::assert_constraints;
use crate::core::backend::simd::SimdBackend;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::fields::m31::BaseField;
use crate::core::fields::IntoSlice;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CanonicCoset;
use crate::core::prover::{commit_and_prove, commit_and_verify};
use crate::core::vcs::blake2_hash::Blake2sHasher;
use crate::core::vcs::hasher::Hasher;
Expand Down Expand Up @@ -488,6 +491,22 @@ mod tests {
assert_eq!(state, expected_state);
}

#[test]
fn test_poseidon_constraints() {
const LOG_N_ROWS: u32 = 8;
let component = PoseidonComponent {
log_n_rows: LOG_N_ROWS,
};
let trace = gen_trace(component.log_column_size());
let trace_polys = TreeVec::new(vec![trace
.into_iter()
.map(|c| c.interpolate())
.collect_vec()]);
assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |eval| {
PoseidonEval { eval }.eval();
});
}

#[test_log::test]
fn test_simd_poseidon_prove() {
// Note: To see time measurement, run test with
Expand Down

0 comments on commit a69e8ca

Please sign in to comment.