Skip to content

Commit

Permalink
Refactor constraint evaluation. (#696)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Jul 9, 2024
1 parent 4fe7142 commit defcfe2
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 90 deletions.
20 changes: 19 additions & 1 deletion crates/prover/src/core/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ use std::ops::Add;

use num_traits::{One, Zero};

use super::circle::CirclePoint;
use super::constraints::point_vanishing;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::fields::ExtensionOf;
use super::fields::{ExtensionOf, FieldExpOps};
use super::poly::circle::CircleDomain;

pub trait IteratorMutExt<'a, T: 'a>: Iterator<Item = &'a mut T> {
fn assign(self, other: impl IntoIterator<Item = T>)
Expand Down Expand Up @@ -150,6 +153,21 @@ where
res - z
}

pub fn point_vanish_denominator_inverses(
domain: CircleDomain,
vanish_point: CirclePoint<BaseField>,
) -> Vec<BaseField> {
let mut denoms = vec![];
for point in domain.iter() {
// TODO(AlonH): Use `point_vanishing_fraction` instead of `point_vanishing` everywhere.
denoms.push(point_vanishing(vanish_point, point));
}
bit_reverse(&mut denoms);
let mut denom_inverses = vec![BaseField::zero(); 1 << (domain.log_size())];
BaseField::batch_inverse(&denoms, &mut denom_inverses);
denom_inverses
}

#[cfg(test)]
mod tests {
use itertools::Itertools;
Expand Down
131 changes: 42 additions & 89 deletions crates/prover/src/examples/wide_fibonacci/constraint_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::core::air::{AirProver, Component, ComponentProver, ComponentTrace};
use crate::core::backend::CpuBackend;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::circle::Coset;
use crate::core::constraints::{coset_vanishing, point_excluder, point_vanishing};
use crate::core::constraints::{coset_vanishing, point_excluder};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
Expand All @@ -22,7 +22,8 @@ use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{BASE_TRACE, INTERACTION_TRACE};
use crate::core::utils::{
bit_reverse, previous_bit_reversed_circle_domain_index, shifted_secure_combination,
bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index,
shifted_secure_combination,
};
use crate::core::{ColumnVec, InteractionElements, LookupValues};
use crate::examples::wide_fibonacci::component::LOG_N_COLUMNS;
Expand Down Expand Up @@ -72,56 +73,35 @@ impl WideFibComponent {
accum: &mut ColumnAccumulator<'_, CpuBackend>,
lookup_values: &LookupValues,
) {
let max_constraint_degree = self.max_constraint_log_degree_bound();
let mut first_point_denoms = vec![];
let mut last_point_denoms = vec![];
for point in trace_eval_domain.iter() {
// TODO(AlonH): Use `point_vanishing_fraction` instead of `point_vanishing` everywhere.
first_point_denoms.push(point_vanishing(zero_domain.at(0), point));
last_point_denoms.push(point_vanishing(
zero_domain.at(zero_domain.size() - 1),
point,
));
}
bit_reverse(&mut first_point_denoms);
bit_reverse(&mut last_point_denoms);
let mut first_point_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
let mut last_point_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&first_point_denoms, &mut first_point_denom_inverses);
BaseField::batch_inverse(&last_point_denoms, &mut last_point_denom_inverses);
let mut first_point_numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];
let mut last_point_numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];
let first_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let last_point_denom_inverses = point_vanish_denominator_inverses(
trace_eval_domain,
zero_domain.at(zero_domain.size() - 1),
);
let (lookup_value_0, lookup_value_1, lookup_value_n_minus_2, lookup_value_n_minus_1) = (
lookup_values[LOOKUP_VALUE_0_ID],
lookup_values[LOOKUP_VALUE_1_ID],
lookup_values[LOOKUP_VALUE_N_MINUS_2_ID],
lookup_values[LOOKUP_VALUE_N_MINUS_1_ID],
);

#[allow(clippy::needless_range_loop)]
for i in 0..trace_eval_domain.size() {
first_point_numerators[i] = accum.random_coeff_powers[self.n_columns() + 4]
for (i, (first_point_denom_inverse, last_point_denom_inverse)) in
zip_eq(first_point_denom_inverses, last_point_denom_inverses).enumerate()
{
let first_point_numerator = accum.random_coeff_powers[self.n_columns() + 4]
* (trace_evals[BASE_TRACE][0][i] - lookup_value_0)
+ accum.random_coeff_powers[self.n_columns() + 3]
* (trace_evals[BASE_TRACE][1][i] - lookup_value_1);
last_point_numerators[i] = accum.random_coeff_powers[self.n_columns() + 2]
let last_point_numerator = accum.random_coeff_powers[self.n_columns() + 2]
* (trace_evals[BASE_TRACE][self.n_columns() - 2][i] - lookup_value_n_minus_2)
+ accum.random_coeff_powers[self.n_columns() + 1]
* (trace_evals[BASE_TRACE][self.n_columns() - 1][i] - lookup_value_n_minus_1);
}
for (i, (num, denom_inverse)) in first_point_numerators
.iter()
.zip(first_point_denom_inverses.iter())
.enumerate()
{
accum.accumulate(i, *num * *denom_inverse);
}
for (i, (num, denom_inverse)) in last_point_numerators
.iter()
.zip(last_point_denom_inverses.iter())
.enumerate()
{
accum.accumulate(i, *num * *denom_inverse);
accum.accumulate(
i,
first_point_numerator * first_point_denom_inverse
+ last_point_numerator * last_point_denom_inverse,
);
}
}

Expand All @@ -140,19 +120,16 @@ impl WideFibComponent {
bit_reverse(&mut denoms);
let mut denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&denoms, &mut denom_inverses);
let mut numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];

#[allow(clippy::needless_range_loop)]
for i in 0..trace_eval_domain.size() {
for (i, denom_inverse) in denom_inverses.iter().enumerate() {
let mut numerator = SecureField::zero();
for j in 0..self.n_columns() - 2 {
numerators[i] += accum.random_coeff_powers[self.n_columns() - 3 - j]
numerator += accum.random_coeff_powers[self.n_columns() - 3 - j]
* (trace_evals[BASE_TRACE][j][i].square()
+ trace_evals[BASE_TRACE][j + 1][i].square()
- trace_evals[BASE_TRACE][j + 2][i]);
}
}
for (i, (num, denom_inverse)) in numerators.iter().zip(denom_inverses.iter()).enumerate() {
accum.accumulate(i, *num * *denom_inverse);
accum.accumulate(i, numerator * *denom_inverse)
}
}

Expand All @@ -165,24 +142,12 @@ impl WideFibComponent {
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
let max_constraint_degree = self.max_constraint_log_degree_bound();
let mut first_point_denoms = vec![];
let mut last_point_denoms = vec![];
for point in trace_eval_domain.iter() {
first_point_denoms.push(point_vanishing(zero_domain.at(0), point));
last_point_denoms.push(point_vanishing(
zero_domain.at(zero_domain.size() - 1),
point,
));
}
bit_reverse(&mut first_point_denoms);
bit_reverse(&mut last_point_denoms);
let mut first_point_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
let mut last_point_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&first_point_denoms, &mut first_point_denom_inverses);
BaseField::batch_inverse(&last_point_denoms, &mut last_point_denom_inverses);
let mut first_point_numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];
let mut last_point_numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];
let first_point_denom_inverses =
point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0));
let last_point_denom_inverses = point_vanish_denominator_inverses(
trace_eval_domain,
zero_domain.at(zero_domain.size() - 1),
);
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);
let (lookup_value_0, lookup_value_1, lookup_value_n_minus_2, lookup_value_n_minus_1) = (
lookup_values[LOOKUP_VALUE_0_ID],
Expand All @@ -191,12 +156,13 @@ impl WideFibComponent {
lookup_values[LOOKUP_VALUE_N_MINUS_1_ID],
);

#[allow(clippy::needless_range_loop)]
for i in 0..trace_eval_domain.size() {
for (i, (first_point_denom_inverse, last_point_denom_inverse)) in
zip_eq(first_point_denom_inverses, last_point_denom_inverses).enumerate()
{
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
first_point_numerators[i] = accum.random_coeff_powers[self.n_columns() - 1]
let first_point_numerator = accum.random_coeff_powers[self.n_columns() - 1]
* ((value
* shifted_secure_combination(
&[
Expand All @@ -211,28 +177,19 @@ impl WideFibComponent {
alpha,
z,
));
last_point_numerators[i] = accum.random_coeff_powers[self.n_columns() - 2]
let last_point_numerator = accum.random_coeff_powers[self.n_columns() - 2]
* ((value
* shifted_secure_combination(
&[lookup_value_n_minus_2, lookup_value_n_minus_1],
alpha,
z,
))
- shifted_secure_combination(&[lookup_value_0, lookup_value_1], alpha, z));
}
for (i, (num, denom_inverse)) in first_point_numerators
.iter()
.zip(first_point_denom_inverses.iter())
.enumerate()
{
accum.accumulate(i, *num * *denom_inverse);
}
for (i, (num, denom_inverse)) in last_point_numerators
.iter()
.zip(last_point_denom_inverses.iter())
.enumerate()
{
accum.accumulate(i, *num * *denom_inverse);
accum.accumulate(
i,
first_point_numerator * first_point_denom_inverse
+ last_point_numerator * last_point_denom_inverse,
);
}
}

Expand All @@ -255,11 +212,9 @@ impl WideFibComponent {
bit_reverse(&mut denoms);
let mut denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&denoms, &mut denom_inverses);
let mut numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);

#[allow(clippy::needless_range_loop)]
for i in 0..trace_eval_domain.size() {
for (i, denom_inverse) in denom_inverses.iter().enumerate() {
let value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][i]
}));
Expand All @@ -271,7 +226,7 @@ impl WideFibComponent {
let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| {
trace_evals[INTERACTION_TRACE][j][prev_index]
}));
numerators[i] = accum.random_coeff_powers[self.n_columns()]
let numerator = accum.random_coeff_powers[self.n_columns()]
* ((value
* shifted_secure_combination(
&[
Expand All @@ -287,9 +242,7 @@ impl WideFibComponent {
alpha,
z,
)));
}
for (i, (num, denom_inverse)) in numerators.iter().zip(denom_inverses.iter()).enumerate() {
accum.accumulate(i, *num * *denom_inverse);
accum.accumulate(i, numerator * *denom_inverse);
}
}
}
Expand Down

0 comments on commit defcfe2

Please sign in to comment.