Skip to content

Commit

Permalink
Add wide fib lookup constraint. (#637)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 authored Jun 20, 2024
1 parent c23c6c7 commit f67fadc
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 38 deletions.
8 changes: 8 additions & 0 deletions crates/prover/src/core/air/air_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ pub trait AirExt: Air {
.unwrap()
}

fn n_interaction_phases(&self) -> u32 {
self.components()
.iter()
.map(|component| component.n_interaction_phases())
.max()
.unwrap()
}

fn trace_commitment_domains(&self) -> Vec<CanonicCoset> {
self.column_log_sizes()
.iter()
Expand Down
3 changes: 3 additions & 0 deletions crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ pub trait Component {

fn max_constraint_log_degree_bound(&self) -> u32;

/// Returns the number of interaction phases done by the component.
fn n_interaction_phases(&self) -> u32;

/// Returns the degree bounds of each trace column.
fn trace_log_degree_bounds(&self) -> Vec<u32>;

Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/core/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ pub fn pair_vanishing<F: ExtensionOf<BaseField>>(
/// Evaluates a vanishing polynomial of the vanish_point at a point.
/// Note that this function has a pole on the antipode of the vanish_point.
pub fn point_vanishing<F: ExtensionOf<BaseField>, EF: ExtensionOf<F>>(
vanish_point: CirclePoint<EF>,
p: CirclePoint<F>,
vanish_point: CirclePoint<F>,
p: CirclePoint<EF>,
) -> EF {
let h = p.into_ef() - vanish_point;
let h = p - vanish_point.into_ef();
h.y / (EF::one() + h.x)
}

Expand Down
15 changes: 9 additions & 6 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ pub fn evaluate_and_commit_on_trace<B: Backend + MerkleOps<MerkleHasher>>(
.map(|poly| poly.interpolate_with_twiddles(twiddles))
})
.collect_vec();
let n_interaction_traces = interaction_trace_polys.len();
if n_interaction_traces > 0 {
if !interaction_trace_polys.is_empty() {
commitment_scheme.commit(interaction_trace_polys, channel, twiddles);
}

Expand Down Expand Up @@ -119,7 +118,7 @@ pub fn generate_proof<B: Backend + MerkleOps<MerkleHasher>>(
// TODO(spapini): Change when we support multiple interactions.
// First tree - trace.
let mut sample_points = TreeVec::new(vec![sample_points.flatten()]);
if commitment_scheme.trees.len() > 2 {
if air.n_interaction_phases() == 2 {
// Second tree - interaction trace.
sample_points.push(vec![
vec![oods_point];
Expand Down Expand Up @@ -210,7 +209,7 @@ pub fn verify(
commitment_scheme.commit(proof.commitments[0], air.column_log_sizes(), channel);
let interaction_elements = air.interaction_elements(channel);

if proof.commitments.len() > 2 {
if air.n_interaction_phases() == 2 {
commitment_scheme.commit(
proof.commitments[1],
air.column_log_sizes()[..1].to_vec(),
Expand All @@ -236,7 +235,7 @@ pub fn verify(
// TODO(spapini): Change when we support multiple interactions.
// First tree - trace.
let mut sample_points = TreeVec::new(vec![trace_sample_points.flatten()]);
if proof.commitments.len() > 2 {
if air.n_interaction_phases() == 2 {
// Second tree - interaction trace.
// TODO(AlonH): Get the number of interaction traces from the air.
sample_points.push(vec![vec![oods_point]; 1]);
Expand Down Expand Up @@ -288,7 +287,7 @@ fn sampled_values_to_mask(
)
});

if sampled_values.len() > 2 {
if air.n_interaction_phases() == 2 {
let interaction_values = &mut sampled_values
.get(1)
.ok_or(InvalidOodsSampleStructure)?
Expand Down Expand Up @@ -429,6 +428,10 @@ mod tests {
self.max_constraint_log_degree_bound
}

fn n_interaction_phases(&self) -> u32 {
1
}

fn trace_log_degree_bounds(&self) -> Vec<u32> {
vec![self.log_size]
}
Expand Down
11 changes: 6 additions & 5 deletions crates/prover/src/core/utils.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::iter::Peekable;

use num_traits::{One, Zero};
use num_traits::One;

use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::fields::ExtensionOf;

pub trait IteratorMutExt<'a, T: 'a>: Iterator<Item = &'a mut T> {
fn assign(self, other: impl IntoIterator<Item = T>)
Expand Down Expand Up @@ -89,14 +90,14 @@ pub fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec<SecureF

/// Securely combines the given values using the given random alpha and z.
/// Alpha and z should be secure field elements for soundness.
pub fn shifted_secure_combination(
values: &[BaseField],
pub fn shifted_secure_combination<F: ExtensionOf<BaseField>>(
values: &[F],
alpha: BaseField,
z: BaseField,
) -> BaseField {
) -> F {
let res = values
.iter()
.fold(BaseField::zero(), |acc, &value| acc * alpha + value);
.fold(F::zero(), |acc, &value| acc * alpha + value);
res - z
}

Expand Down
4 changes: 4 additions & 0 deletions crates/prover/src/examples/fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ impl Component for FibonacciComponent {
self.log_size + 1
}

fn n_interaction_phases(&self) -> u32 {
1
}

fn trace_log_degree_bounds(&self) -> Vec<u32> {
vec![self.log_size]
}
Expand Down
26 changes: 21 additions & 5 deletions crates/prover/src/examples/wide_fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,21 @@ use crate::core::air::mask::fixed_mask_points;
use crate::core::air::{Air, Component, ComponentTraceWriter};
use crate::core::backend::CpuBackend;
use crate::core::circle::CirclePoint;
use crate::core::constraints::coset_vanishing;
use crate::core::constraints::{coset_vanishing, point_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::shifted_secure_combination;
use crate::core::{ColumnVec, InteractionElements};
use crate::examples::wide_fibonacci::trace_gen::write_lookup_column;

pub const LOG_N_COLUMNS: usize = 8;
pub const N_COLUMNS: usize = 1 << LOG_N_COLUMNS;

const ALPHA_ID: &str = "wide_fibonacci_alpha";
const Z_ID: &str = "wide_fibonacci_z";
pub const ALPHA_ID: &str = "wide_fibonacci_alpha";
pub const Z_ID: &str = "wide_fibonacci_z";

/// Component that computes 2^`self.log_n_instances` instances of fibonacci sequences of size
/// 2^`self.log_fibonacci_size`. The numbers are computes over [N_COLUMNS] trace columns. The
Expand Down Expand Up @@ -57,13 +58,17 @@ impl Air for WideFibAir {

impl Component for WideFibComponent {
fn n_constraints(&self) -> usize {
self.n_columns() - 2
self.n_columns() - 1
}

fn max_constraint_log_degree_bound(&self) -> u32 {
self.log_column_size() + 1
}

fn n_interaction_phases(&self) -> u32 {
2
}

fn trace_log_degree_bounds(&self) -> Vec<u32> {
vec![self.log_column_size(); self.n_columns()]
}
Expand All @@ -84,9 +89,20 @@ impl Component for WideFibComponent {
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
interaction_elements: &InteractionElements,
) {
let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset;
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);
let lookup_numerator = (mask[self.n_columns()][0]
* shifted_secure_combination(
&[mask[self.n_columns() - 2][0], mask[self.n_columns() - 1][0]],
alpha,
z,
))
- shifted_secure_combination(&[mask[0][0], mask[1][0]], alpha, z);
let lookup_denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(lookup_numerator / lookup_denom);

let denom = coset_vanishing(constraint_zero_domain, point);
let denom_inverse = denom.inverse();
for i in 0..self.n_columns() - 2 {
Expand Down
44 changes: 37 additions & 7 deletions crates/prover/src/examples/wide_fibonacci/constraint_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::BTreeMap;
use itertools::{zip_eq, Itertools};
use num_traits::Zero;

use super::component::{Input, WideFibAir, WideFibComponent};
use super::component::{Input, WideFibAir, WideFibComponent, ALPHA_ID, Z_ID};
use super::trace_gen::write_trace_row;
use crate::core::air::accumulation::DomainEvaluationAccumulator;
use crate::core::air::{
Expand All @@ -12,13 +12,13 @@ use crate::core::air::{
};
use crate::core::backend::CpuBackend;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::constraints::coset_vanishing;
use crate::core::constraints::{coset_vanishing, point_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse;
use crate::core::utils::{bit_reverse, shifted_secure_combination};
use crate::core::{ColumnVec, ComponentVec, InteractionElements};
use crate::examples::wide_fibonacci::component::LOG_N_COLUMNS;

Expand Down Expand Up @@ -59,34 +59,64 @@ impl ComponentProver<CpuBackend> for WideFibComponent {
&self,
trace: &ComponentTrace<'_, CpuBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<CpuBackend>,
_interaction_elements: &InteractionElements,
interaction_elements: &InteractionElements,
) {
let max_constraint_degree = self.max_constraint_log_degree_bound();
let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain();
let trace_evals = &trace.evals;
let zero_domain = CanonicCoset::new(self.log_column_size()).coset;
let mut denoms = vec![];
let mut lookup_denoms = vec![];
for point in trace_eval_domain.iter() {
denoms.push(coset_vanishing(zero_domain, point));
lookup_denoms.push(point_vanishing(zero_domain.at(0), point));
}
bit_reverse(&mut denoms);
let mut denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&denoms, &mut denom_inverses);
bit_reverse(&mut lookup_denoms);
let mut lookup_denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)];
BaseField::batch_inverse(&lookup_denoms, &mut lookup_denom_inverses);
let mut numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];
let mut lookup_numerators = vec![SecureField::zero(); 1 << (max_constraint_degree)];
let [mut accum] =
evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]);
let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]);

#[allow(clippy::needless_range_loop)]
for i in 0..trace_eval_domain.size() {
// Step constraints.
for j in 0..self.n_columns() - 2 {
numerators[i] += accum.random_coeff_powers[self.n_columns() - 3 - j]
* (trace_evals[0][j][i].square() + trace_evals[0][j + 1][i].square()
- trace_evals[0][j + 2][i]);
}

// Lookup constraints.
lookup_numerators[i] = accum.random_coeff_powers[self.n_columns() - 2]
* ((trace_evals[1][0][i]
* shifted_secure_combination(
&[
trace_evals[0][self.n_columns() - 2][i],
trace_evals[0][self.n_columns() - 1][i],
],
alpha,
z,
))
- shifted_secure_combination(
&[trace_evals[0][0][i], trace_evals[0][1][i]],
alpha,
z,
));
}
for (i, (num, denom_inverse)) in numerators.iter().zip(denom_inverses.iter()).enumerate() {
accum.accumulate(i, *num * *denom_inverse);
}
for (i, (num, denom)) in numerators.iter().zip(denom_inverses.iter()).enumerate() {
accum.accumulate(i, *num * *denom);
for (i, (num, denom_inverse)) in lookup_numerators
.iter()
.zip(lookup_denom_inverses.iter())
.enumerate()
{
accum.accumulate(i, *num * *denom_inverse);
}
}
}
Expand Down
Loading

0 comments on commit f67fadc

Please sign in to comment.