Skip to content

Commit

Permalink
Add lookup final boundary constraints.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Jun 23, 2024
1 parent 55ada84 commit 3578d66
Show file tree
Hide file tree
Showing 11 changed files with 224 additions and 17 deletions.
13 changes: 12 additions & 1 deletion crates/prover/src/core/air/air_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::core::pcs::{CommitmentTreeProver, TreeVec};
use crate::core::poly::circle::SecureCirclePoly;
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::ops::MerkleOps;
use crate::core::{ColumnVec, ComponentVec, InteractionElements};
use crate::core::{ColumnVec, ComponentVec, InteractionElements, LookupValues};

pub trait AirExt: Air {
fn composition_log_degree_bound(&self) -> u32 {
Expand Down Expand Up @@ -58,6 +58,7 @@ pub trait AirExt: Air {
mask_values: &ComponentVec<Vec<SecureField>>,
random_coeff: SecureField,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) -> SecureField {
let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff);
zip_eq(self.components(), &mask_values.0).for_each(|(component, mask)| {
Expand All @@ -66,6 +67,7 @@ pub trait AirExt: Air {
mask,
&mut evaluation_accumulator,
interaction_elements,
lookup_values,
)
});
evaluation_accumulator.finalize()
Expand Down Expand Up @@ -130,6 +132,7 @@ pub trait AirProverExt<B: Backend>: AirProver<B> {
random_coeff: SecureField,
component_traces: &[ComponentTrace<'_, B>],
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) -> SecureCirclePoly<B> {
let total_constraints: usize = self
.prover_components()
Expand All @@ -146,10 +149,18 @@ pub trait AirProverExt<B: Backend>: AirProver<B> {
trace,
&mut accumulator,
interaction_elements,
lookup_values,
)
});
accumulator.finalize()
}

fn lookup_values(&self, component_traces: &[ComponentTrace<'_, B>]) -> LookupValues {
let mut values = LookupValues::default();
zip_eq(self.prover_components(), component_traces)
.for_each(|(component, trace)| values.extend(component.lookup_values(trace)));
values
}
}

impl<B: Backend, A: AirProver<B>> AirProverExt<B> for A {}
9 changes: 8 additions & 1 deletion crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::fields::qm31::SecureField;
use super::pcs::TreeVec;
use super::poly::circle::{CircleEvaluation, CirclePoly};
use super::poly::BitReversedOrder;
use super::{ColumnVec, InteractionElements};
use super::{ColumnVec, InteractionElements, LookupValues};

pub mod accumulation;
mod air_ext;
Expand Down Expand Up @@ -74,6 +74,7 @@ pub trait Component {
mask: &ColumnVec<Vec<SecureField>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
);
}

Expand All @@ -93,7 +94,13 @@ pub trait ComponentProver<B: Backend>: Component {
trace: &ComponentTrace<'_, B>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<B>,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
);

/// Returns the values needed to evaluate the components lookup boundary constraints.
fn lookup_values(&self, _trace: &ComponentTrace<'_, B>) -> LookupValues {
LookupValues::default()
}
}

/// A component trace is a set of polynomials for each column on that component.
Expand Down
11 changes: 11 additions & 0 deletions crates/prover/src/core/fields/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,17 @@ macro_rules! impl_extension_field {
}
}

impl TryInto<M31> for $field_name {
type Error = ();

fn try_into(self) -> Result<M31, Self::Error> {
if self.1 != <$extended_field_name>::zero() {
return Err(());
}
self.0.try_into().map_err(|_| ())
}
}

impl AddAssign<M31> for $field_name {
fn add_assign(&mut self, rhs: M31) {
*self = *self + rhs;
Expand Down
26 changes: 25 additions & 1 deletion crates/prover/src/core/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::collections::BTreeMap;
use std::ops::{Deref, DerefMut, Index};

use fields::m31::BaseField;

use self::fields::qm31::SecureField;

pub mod air;
Expand Down Expand Up @@ -61,7 +63,7 @@ impl<T> DerefMut for ComponentVec<T> {
}
}

#[derive(Default)]
#[derive(Default, Debug)]
pub struct InteractionElements(BTreeMap<String, SecureField>);

impl InteractionElements {
Expand All @@ -82,3 +84,25 @@ impl Index<&str> for InteractionElements {
&self.0[index]
}
}

#[derive(Default, Debug)]
pub struct LookupValues(BTreeMap<String, BaseField>);

impl LookupValues {
pub fn new(values: BTreeMap<String, BaseField>) -> Self {
Self(values)
}

pub fn extend(&mut self, other: Self) {
self.0.extend(other.0);
}
}

impl Index<&str> for LookupValues {
type Output = BaseField;

fn index(&self, index: &str) -> &Self::Output {
// TODO(AlonH): Return an error if the key is not found.
&self.0[index]
}
}
15 changes: 12 additions & 3 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::pcs::{CommitmentSchemeProof, TreeVec};
use super::poly::circle::{CanonicCoset, SecureCirclePoly, MAX_CIRCLE_DOMAIN_LOG_SIZE};
use super::poly::twiddles::TwiddleTree;
use super::proof_of_work::ProofOfWorkVerificationError;
use super::{ColumnVec, InteractionElements};
use super::{ColumnVec, InteractionElements, LookupValues};
use crate::core::air::{Air, AirExt, AirProverExt};
use crate::core::backend::CpuBackend;
use crate::core::channel::{Blake2sChannel, Channel as ChannelTrait};
Expand Down Expand Up @@ -39,6 +39,7 @@ pub const N_QUERIES: usize = 3;
#[derive(Debug)]
pub struct StarkProof {
pub commitments: TreeVec<<ChannelHasher as Hasher>::Hash>,
pub lookup_values: LookupValues,
pub commitment_scheme_proof: CommitmentSchemeProof,
}

Expand Down Expand Up @@ -96,10 +97,13 @@ pub fn generate_proof<B: Backend + MerkleOps<MerkleHasher>>(
let random_coeff = channel.draw_felt();

let span = span!(Level::INFO, "Composition generation").entered();
let component_traces = air.component_traces(&commitment_scheme.trees);
let lookup_values = air.lookup_values(&component_traces);
let composition_polynomial_poly = air.compute_composition_polynomial(
random_coeff,
&air.component_traces(&commitment_scheme.trees),
&component_traces,
interaction_elements,
&lookup_values,
);
span.exit();

Expand Down Expand Up @@ -128,13 +132,15 @@ pub fn generate_proof<B: Backend + MerkleOps<MerkleHasher>>(
&trace_oods_values,
random_coeff,
interaction_elements,
&lookup_values,
)
{
return Err(ProvingError::ConstraintsNotSatisfied);
}

Ok(StarkProof {
commitments: commitment_scheme.roots(),
lookup_values,
commitment_scheme_proof,
})
}
Expand Down Expand Up @@ -229,6 +235,7 @@ pub fn verify(
&trace_oods_values,
random_coeff,
&interaction_elements,
&proof.lookup_values,
)
{
return Err(VerificationError::OodsNotMatching);
Expand Down Expand Up @@ -353,7 +360,7 @@ mod tests {
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{prove, ProvingError};
use crate::core::test_utils::test_channel;
use crate::core::{ColumnVec, InteractionElements};
use crate::core::{ColumnVec, InteractionElements, LookupValues};
use crate::qm31;

struct TestAir<C: ComponentProver<CpuBackend>> {
Expand Down Expand Up @@ -431,6 +438,7 @@ mod tests {
_mask: &crate::core::ColumnVec<Vec<SecureField>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
) {
evaluation_accumulator.accumulate(qm31!(0, 0, 0, 1))
}
Expand All @@ -452,6 +460,7 @@ mod tests {
_trace: &ComponentTrace<'_, CpuBackend>,
_evaluation_accumulator: &mut DomainEvaluationAccumulator<CpuBackend>,
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
) {
// Does nothing.
}
Expand Down
4 changes: 3 additions & 1 deletion crates/prover/src/examples/fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse_index;
use crate::core::{ColumnVec, InteractionElements};
use crate::core::{ColumnVec, InteractionElements, LookupValues};

pub struct FibonacciComponent {
pub log_size: u32,
Expand Down Expand Up @@ -113,6 +113,7 @@ impl Component for FibonacciComponent {
mask: &ColumnVec<Vec<SecureField>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
) {
evaluation_accumulator.accumulate(
self.step_constraint_eval_quotient_by_mask(point, &mask[0][..].try_into().unwrap()),
Expand Down Expand Up @@ -142,6 +143,7 @@ impl ComponentProver<CpuBackend> for FibonacciComponent {
trace: &ComponentTrace<'_, CpuBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<CpuBackend>,
_interaction_elements: &InteractionElements,
_lookup_values: &LookupValues,
) {
let poly = &trace.polys[0][0];
let trace_domain = CanonicCoset::new(self.log_size);
Expand Down
9 changes: 5 additions & 4 deletions crates/prover/src/examples/fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ impl MultiFibonacci {
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;
use std::collections::BTreeMap;
use std::iter::zip;

use itertools::Itertools;
Expand All @@ -129,7 +128,7 @@ mod tests {
use crate::core::prover::VerificationError;
use crate::core::queries::Queries;
use crate::core::utils::bit_reverse;
use crate::core::InteractionElements;
use crate::core::{InteractionElements, LookupValues};
use crate::{m31, qm31};

pub fn generate_test_queries(n_queries: usize, trace_length: usize) -> Vec<usize> {
Expand Down Expand Up @@ -159,7 +158,8 @@ mod tests {
let composition_polynomial_poly = fib.air.compute_composition_polynomial(
random_coeff,
&component_traces,
&InteractionElements::new(BTreeMap::new()),
&InteractionElements::default(),
&LookupValues::default(),
);

// Evaluate this polynomial at another point out of the evaluation domain and compare to
Expand All @@ -181,7 +181,8 @@ mod tests {
point,
&mask_values,
&mut evaluation_accumulator,
&InteractionElements::new(BTreeMap::new()),
&InteractionElements::default(),
&LookupValues::default(),
);
let oods_value = evaluation_accumulator.finalize();

Expand Down
38 changes: 36 additions & 2 deletions crates/prover/src/examples/wide_fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@ use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, SecureCirclePoly};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::shifted_secure_combination;
use crate::core::{ColumnVec, InteractionElements};
use crate::core::{ColumnVec, InteractionElements, LookupValues};
use crate::examples::wide_fibonacci::trace_gen::write_lookup_column;

pub const LOG_N_COLUMNS: usize = 8;
pub const N_COLUMNS: usize = 1 << LOG_N_COLUMNS;

pub const ALPHA_ID: &str = "wide_fibonacci_alpha";
pub const Z_ID: &str = "wide_fibonacci_z";
pub const LOOKUP_VALUE_0_ID: &str = "wide_fibonacci_0";
pub const LOOKUP_VALUE_1_ID: &str = "wide_fibonacci_1";
pub const LOOKUP_VALUE_N_MINUS_2_ID: &str = "wide_fibonacci_n-2";
pub const LOOKUP_VALUE_N_MINUS_1_ID: &str = "wide_fibonacci_n-1";

/// Component that computes 2^`self.log_n_instances` instances of fibonacci sequences of size
/// 2^`self.log_fibonacci_size`. The numbers are computes over [N_COLUMNS] trace columns. The
Expand All @@ -48,6 +52,28 @@ impl WideFibComponent {
N_COLUMNS
}

fn evaluate_trace_boundary_constraints_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
constraint_zero_domain: Coset,
lookup_values: &LookupValues,
) {
let numerator = mask[0][0] - lookup_values[LOOKUP_VALUE_0_ID];
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
let numerator = mask[1][0] - lookup_values[LOOKUP_VALUE_1_ID];
evaluation_accumulator.accumulate(numerator / denom);

let numerator = mask[self.n_columns() - 2][0] - lookup_values[LOOKUP_VALUE_N_MINUS_2_ID];
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
let numerator = mask[self.n_columns() - 1][0] - lookup_values[LOOKUP_VALUE_N_MINUS_1_ID];
let denom = point_vanishing(constraint_zero_domain.at(0), point);
evaluation_accumulator.accumulate(numerator / denom);
}

fn evaluate_trace_step_constraints_at_point(
&self,
point: CirclePoint<SecureField>,
Expand Down Expand Up @@ -129,7 +155,7 @@ impl Air for WideFibAir {

impl Component for WideFibComponent {
fn n_constraints(&self) -> usize {
self.n_columns()
self.n_columns() + 4
}

fn max_constraint_log_degree_bound(&self) -> u32 {
Expand Down Expand Up @@ -168,8 +194,16 @@ impl Component for WideFibComponent {
mask: &ColumnVec<Vec<SecureField>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) {
let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset;
self.evaluate_trace_boundary_constraints_at_point(
point,
mask,
evaluation_accumulator,
constraint_zero_domain,
lookup_values,
);
self.evaluate_lookup_step_constraints_at_point(
point,
mask,
Expand Down
Loading

0 comments on commit 3578d66

Please sign in to comment.