Skip to content

Commit

Permalink
Add eq evals constraints (#741)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson authored Aug 21, 2024
1 parent d1e6267 commit 586ecf4
Show file tree
Hide file tree
Showing 8 changed files with 300 additions and 27 deletions.
24 changes: 24 additions & 0 deletions crates/prover/src/constraint_framework/constant_columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,34 @@ use crate::core::backend::{Backend, Col, Column};
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};

/// Generates a column with a single one at the first position, and zeros elsewhere.
pub fn gen_is_first<B: Backend>(log_size: u32) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
let mut col = Col::<B, BaseField>::zeros(1 << log_size);
col.set(0, BaseField::one());
CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col)
}

/// Generates a column with `1` at every `2^log_step` positions, `0` elsewhere, shifted by offset.
// TODO(andrew): Consider optimizing. Is a quotients of two coset_vanishing (use succinct rep for
// verifier).
pub fn gen_is_step_with_offset<B: Backend>(
log_size: u32,
log_step: u32,
offset: usize,
) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
let mut col = Col::<B, BaseField>::zeros(1 << log_size);

let size = 1 << log_size;
let step = 1 << log_step;
let step_offset = offset % step;

for i in (step_offset..size).step_by(step) {
let circle_domain_index = coset_index_to_circle_domain_index(i, log_size);
let circle_domain_index_bit_rev = bit_reverse_index(circle_domain_index, log_size);
col.set(circle_domain_index_bit_rev, BaseField::one());
}

CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col)
}
29 changes: 26 additions & 3 deletions crates/prover/src/core/backend/simd/column.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::mem;
use std::iter::zip;
use std::{array, mem};

use bytemuck::allocation::cast_vec;
use bytemuck::{cast_slice, cast_slice_mut, Zeroable};
Expand Down Expand Up @@ -215,6 +216,29 @@ pub struct SecureColumn {
pub length: usize,
}

impl SecureColumn {
// Separates a single column of `PackedSecureField` elements into `SECURE_EXTENSION_DEGREE` many
// `PackedBaseField` coordinate columns.
pub fn into_secure_column_by_coords(self) -> SecureColumnByCoords<SimdBackend> {
if self.len() < N_LANES {
return self.to_cpu().into_iter().collect();
}

let length = self.length;
let packed_length = self.data.len();
let mut columns = array::from_fn(|_| Vec::with_capacity(packed_length));

for v in self.data {
let packed_coords = v.into_packed_m31s();
zip(&mut columns, packed_coords).for_each(|(col, packed_coord)| col.push(packed_coord));
}

SecureColumnByCoords {
columns: columns.map(|col| BaseColumn { data: col, length }),
}
}
}

impl Column<SecureField> for SecureColumn {
fn zeros(length: usize) -> Self {
Self {
Expand Down Expand Up @@ -276,9 +300,8 @@ impl FromIterator<SecureField> for SecureColumn {

impl FromIterator<PackedSecureField> for SecureColumn {
fn from_iter<I: IntoIterator<Item = PackedSecureField>>(iter: I) -> Self {
let data = (&mut iter.into_iter()).collect_vec();
let data = iter.into_iter().collect_vec();
let length = data.len() * N_LANES;

Self { data, length }
}
}
Expand Down
17 changes: 11 additions & 6 deletions crates/prover/src/core/fields/secure_column.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::array;
use std::iter::zip;

use super::m31::BaseField;
use super::qm31::SecureField;
use super::{ExtensionOf, FieldOps};
Expand Down Expand Up @@ -89,13 +92,15 @@ impl<'a> IntoIterator for &'a SecureColumnByCoords<CpuBackend> {
}
impl FromIterator<SecureField> for SecureColumnByCoords<CpuBackend> {
fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
let mut columns = std::array::from_fn(|_| vec![]);
for value in iter.into_iter() {
let vals = value.to_m31_array();
for j in 0..SECURE_EXTENSION_DEGREE {
columns[j].push(vals[j]);
}
let values = iter.into_iter();
let (lower_bound, _) = values.size_hint();
let mut columns = array::from_fn(|_| Vec::with_capacity(lower_bound));

for value in values {
let coords = value.to_m31_array();
zip(&mut columns, coords).for_each(|(col, coord)| col.push(coord));
}

SecureColumnByCoords { columns }
}
}
Expand Down
12 changes: 8 additions & 4 deletions crates/prover/src/core/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,15 @@ pub(crate) fn coset_order_to_circle_domain_order<F: Field>(values: &[F]) -> Vec<
circle_domain_order
}

pub fn coset_order_to_circle_domain_order_index(index: usize, log_size: u32) -> usize {
if index & 1 == 0 {
index / 2
/// Converts an index within a [`Coset`] to the corresponding index in a [`CircleDomain`].
///
/// [`CircleDomain`]: crate::core::poly::circle::CircleDomain
/// [`Coset`]: crate::core::circle::Coset
pub fn coset_index_to_circle_domain_index(coset_index: usize, log_domain_size: u32) -> usize {
if coset_index % 2 == 0 {
coset_index / 2
} else {
(1 << log_size) - (index + 1) / 2
((2 << log_domain_size) - coset_index) / 2
}
}

Expand Down
219 changes: 219 additions & 0 deletions crates/prover/src/examples/xor/eq_eval_constraints.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
use std::array;

use num_traits::{One, Zero};

use crate::constraint_framework::EvalAtRow;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::eq;

/// Evaluates EqEvals constraints on a column.
///
/// Returns the evaluation at offset 0 on the column.
///
/// Given a column `c(P)` defined on a circle domain D, and an MLE evaluation point `(r0, r1, ...)`
/// evaluates constraints that guarantee: `c(D[b0,b1,...]) = eq((b0,b1,...), (r0,r1,...))`
///
/// See <https://eprint.iacr.org/2023/1284.pdf> (Section 5.1).
pub fn eval_eq_constraints<
E: EvalAtRow,
const N_VARIABLES: usize,
const EQ_EVALS_TRACE: usize,
const SELECTOR_TRACE: usize,
>(
eval: &mut E,
mle_eval_point: MleEvalPoint<N_VARIABLES>,
) -> E::EF {
let [curr, next_next] = eval.next_extension_interaction_mask(EQ_EVALS_TRACE, [0, 2]);
let [is_first, is_second] = eval.next_interaction_mask(SELECTOR_TRACE, [0, -1]);

// Check the initial value on half_coset0 and final value on half_coset1.
// Combining these constraints is safe because `is_first` and `is_second` are never
// non-zero at the same time on the trace.
let half_coset0_initial_check = (curr - mle_eval_point.eq_0_p) * is_first;
let half_coset1_final_check = (curr - mle_eval_point.eq_1_p) * is_second;
eval.add_constraint(half_coset0_initial_check + half_coset1_final_check);

// Check all variables except the last (last variable is handled by the constraint above).
#[allow(clippy::needless_range_loop)]
for variable_i in 0..N_VARIABLES.saturating_sub(1) {
let half_coset0_next = next_next;
let half_coset1_prev = next_next;
let [half_coset0_step, half_coset1_step] =
eval.next_interaction_mask(SELECTOR_TRACE, [0, -1]);
let carry_quotient = mle_eval_point.eq_carry_quotients[variable_i];
// Safe to combine these constraints as `is_step.half_coset0` and `is_step.half_coset1`
// are never non-zero at the same time on the trace.
let half_coset0_check = (curr - half_coset0_next * carry_quotient) * half_coset0_step;
let half_coset1_check = (curr * carry_quotient - half_coset1_prev) * half_coset1_step;
eval.add_constraint(half_coset0_check + half_coset1_check);
}

curr
}

#[derive(Debug, Clone, Copy)]
pub struct MleEvalPoint<const N_VARIABLES: usize> {
// Equals `eq({0}^|p|, p)`.
eq_0_p: SecureField,
// Equals `eq({1}^|p|, p)`.
eq_1_p: SecureField,
// Index `i` stores `eq(({1}^|i|, 0), p[0..i+1]) / eq(({0}^|i|, 1), p[0..i+1])`.
eq_carry_quotients: [SecureField; N_VARIABLES],
// Point `p`.
_p: [SecureField; N_VARIABLES],
}

impl<const N_VARIABLES: usize> MleEvalPoint<N_VARIABLES> {
/// Creates new metadata from point `p`.
pub fn new(p: [SecureField; N_VARIABLES]) -> Self {
let zero = SecureField::zero();
let one = SecureField::one();

Self {
eq_0_p: eq(&[zero; N_VARIABLES], &p),
eq_1_p: eq(&[one; N_VARIABLES], &p),
eq_carry_quotients: array::from_fn(|i| {
let mut numer_assignment = vec![one; i + 1];
numer_assignment[i] = zero;
let mut denom_assignment = vec![zero; i + 1];
denom_assignment[i] = one;
eq(&numer_assignment, &p[..i + 1]) / eq(&denom_assignment, &p[..i + 1])
}),
_p: p,
}
}
}

#[cfg(test)]
pub mod tests {
use std::array;

use num_traits::One;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use test_log::test;

use super::MleEvalPoint;
use crate::constraint_framework::assert_constraints;
use crate::constraint_framework::constant_columns::{gen_is_first, gen_is_step_with_offset};
use crate::core::backend::simd::SimdBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;
use crate::examples::xor::eq_eval_constraints::eval_eq_constraints;

const EVALS_TRACE: usize = 0;
const CONST_TRACE: usize = 1;

#[test]
#[ignore = "SimdBackend `MIN_FFT_LOG_SIZE` is 5"]
fn eq_constraints_with_4_variables() {
const N_VARIABLES: usize = 4;
let mut rng = SmallRng::seed_from_u64(0);
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let base_trace = gen_evals_trace(&eval_point);
let constants_trace = gen_constants_trace(N_VARIABLES);
let traces = TreeVec::new(vec![base_trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(eval_point.len() as u32);
let mle_eval_point = MleEvalPoint::new(eval_point);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
eval_eq_constraints::<_, N_VARIABLES, EVALS_TRACE, CONST_TRACE>(
&mut eval,
mle_eval_point,
);
});
}

#[test]
fn eq_constraints_with_5_variables() {
const N_VARIABLES: usize = 5;
let mut rng = SmallRng::seed_from_u64(0);
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let base_trace = gen_evals_trace(&eval_point);
let constants_trace = gen_constants_trace(N_VARIABLES);
let traces = TreeVec::new(vec![base_trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(eval_point.len() as u32);
let mle_eval_point = MleEvalPoint::new(eval_point);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
eval_eq_constraints::<_, N_VARIABLES, EVALS_TRACE, CONST_TRACE>(
&mut eval,
mle_eval_point,
);
});
}

#[test]
fn eq_constraints_with_8_variables() {
const N_VARIABLES: usize = 8;
let mut rng = SmallRng::seed_from_u64(0);
let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen());
let base_trace = gen_evals_trace(&eval_point);
let constants_trace = gen_constants_trace(N_VARIABLES);
let traces = TreeVec::new(vec![base_trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(eval_point.len() as u32);
let mle_eval_point = MleEvalPoint::new(eval_point);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
eval_eq_constraints::<_, N_VARIABLES, EVALS_TRACE, CONST_TRACE>(
&mut eval,
mle_eval_point,
);
});
}

/// Generates a trace.
///
/// Trace structure:
///
/// ```text
/// -----------------------------
/// | eq evals |
/// -----------------------------
/// | c0 | c1 | c2 | c3 |
/// -----------------------------
/// ```
pub fn gen_evals_trace(
eval_point: &[SecureField],
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
// TODO(andrew): Consider storing eq evals as a SecureColumn.
let eq_evals = SimdBackend::gen_eq_evals(eval_point, SecureField::one()).into_evals();
let eq_evals_coordinate_columns = eq_evals.into_secure_column_by_coords().columns;

let n_variables = eval_point.len();
let domain = CanonicCoset::new(n_variables as u32).circle_domain();
eq_evals_coordinate_columns
.map(|col| CircleEvaluation::new(domain, col))
.into()
}

pub fn gen_constants_trace(
n_variables: usize,
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let log_size = n_variables as u32;
let mut constants_trace = Vec::new();
constants_trace.push(gen_is_first(log_size));

// TODO(andrew): Note the last selector column is not needed. The column for `is_first`
// with an offset for each half coset midpoint can be used instead.
for variable_i in 1..n_variables as u32 {
let half_coset_log_step = variable_i;
let half_coset_offset = (1 << (half_coset_log_step - 1)) - 1;

let log_step = half_coset_log_step + 1;
let offset = half_coset_offset * 2;

constants_trace.push(gen_is_step_with_offset(log_size, log_step, offset))
}

constants_trace
}
}
1 change: 1 addition & 0 deletions crates/prover/src/examples/xor/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod eq_eval_constraints;
pub mod prefix_sum_constraints;
9 changes: 2 additions & 7 deletions crates/prover/src/examples/xor/prefix_sum_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,15 @@ impl<E: EvalAtRow> PrefixSumMask<E> {
#[cfg(test)]
mod tests {
use itertools::Itertools;
use num_traits::One;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use test_log::test;

use super::inclusive_prefix_sum_check;
use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::{assert_constraints, EvalAtRow};
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
Expand Down Expand Up @@ -117,10 +116,6 @@ mod tests {
fn gen_constants_trace(
log_size: u32,
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let trace_domain = CanonicCoset::new(log_size).circle_domain();
// Column is `1` at the first trace point and `0` on all other trace points.
let mut is_first = Col::<SimdBackend, BaseField>::zeros(1 << log_size);
is_first.as_mut_slice()[0] = BaseField::one();
vec![CircleEvaluation::new(trace_domain, is_first)]
vec![gen_is_first(log_size)]
}
}
Loading

0 comments on commit 586ecf4

Please sign in to comment.