Skip to content

Commit

Permalink
Simplify Air less (#778)
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware authored Aug 6, 2024
1 parent c1c4d3d commit 8344112
Show file tree
Hide file tree
Showing 13 changed files with 136 additions and 894 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use itertools::{zip_eq, Itertools};

use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use super::{Air, AirProver, ComponentTrace};
use super::{Component, ComponentProver, ComponentTrace};
use crate::core::backend::Backend;
use crate::core::circle::CirclePoint;
use crate::core::fields::qm31::SecureField;
Expand All @@ -11,21 +11,23 @@ use crate::core::poly::circle::SecureCirclePoly;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};
use crate::core::{ColumnVec, InteractionElements, LookupValues};

pub trait AirExt: Air {
fn composition_log_degree_bound(&self) -> u32 {
self.components()
pub struct Components<'a>(pub Vec<&'a dyn Component>);

impl<'a> Components<'a> {
pub fn composition_log_degree_bound(&self) -> u32 {
self.0
.iter()
.map(|component| component.max_constraint_log_degree_bound())
.max()
.unwrap()
}

fn mask_points(
pub fn mask_points(
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let mut air_points = TreeVec::default();
for component in self.components() {
for component in &self.0 {
let component_points = component.mask_points(point);
if air_points.len() < component_points.len() {
air_points.resize(component_points.len(), vec![]);
Expand All @@ -41,7 +43,7 @@ pub trait AirExt: Air {
air_points
}

fn eval_composition_polynomial_at_point(
pub fn eval_composition_polynomial_at_point(
&self,
point: CirclePoint<SecureField>,
mask_values: &Vec<TreeVec<Vec<Vec<SecureField>>>>,
Expand All @@ -50,7 +52,7 @@ pub trait AirExt: Air {
lookup_values: &LookupValues,
) -> SecureField {
let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff);
zip_eq(self.components(), mask_values).for_each(|(component, mask)| {
zip_eq(&self.0, mask_values).for_each(|(component, mask)| {
component.evaluate_constraint_quotients_at_point(
point,
mask,
Expand All @@ -62,9 +64,9 @@ pub trait AirExt: Air {
evaluation_accumulator.finalize()
}

fn column_log_sizes(&self) -> TreeVec<ColumnVec<u32>> {
pub fn column_log_sizes(&self) -> TreeVec<ColumnVec<u32>> {
let mut air_sizes = TreeVec::default();
self.components().iter().for_each(|component| {
self.0.iter().for_each(|component| {
let component_sizes = component.trace_log_degree_bounds();
if air_sizes.len() < component_sizes.len() {
air_sizes.resize(component_sizes.len(), vec![]);
Expand All @@ -77,11 +79,45 @@ pub trait AirExt: Air {
});
air_sizes
}
}

pub struct ComponentProvers<'a, B: Backend>(pub Vec<&'a dyn ComponentProver<B>>);

impl<'a, B: Backend> ComponentProvers<'a, B> {
pub fn components(&self) -> Components<'_> {
Components(self.0.iter().map(|c| *c as &dyn Component).collect_vec())
}
pub fn compute_composition_polynomial(
&self,
random_coeff: SecureField,
component_traces: &[ComponentTrace<'_, B>],
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) -> SecureCirclePoly<B> {
let total_constraints: usize = self.0.iter().map(|c| c.n_constraints()).sum();
let mut accumulator = DomainEvaluationAccumulator::new(
random_coeff,
self.components().composition_log_degree_bound(),
total_constraints,
);
zip_eq(&self.0, component_traces).for_each(|(component, trace)| {
component.evaluate_constraint_quotients_on_domain(
trace,
&mut accumulator,
interaction_elements,
lookup_values,
)
});
accumulator.finalize()
}

fn component_traces<'a, B: Backend + MerkleOps<H>, H: MerkleHasher>(
&'a self,
trees: &'a [CommitmentTreeProver<B, H>],
) -> Vec<ComponentTrace<'_, B>> {
pub fn component_traces<'b, H: MerkleHasher>(
&'b self,
trees: &'b [CommitmentTreeProver<B, H>],
) -> Vec<ComponentTrace<'b, B>>
where
B: MerkleOps<H>,
{
let mut poly_iters = trees
.iter()
.map(|tree| tree.polynomials.iter())
Expand All @@ -91,7 +127,7 @@ pub trait AirExt: Air {
.map(|tree| tree.evaluations.iter())
.collect_vec();

self.components()
self.0
.iter()
.map(|component| {
let col_sizes_per_tree = component
Expand All @@ -116,45 +152,11 @@ pub trait AirExt: Air {
})
.collect_vec()
}
}

impl<A: Air + ?Sized> AirExt for A {}

pub trait AirProverExt<B: Backend>: AirProver<B> {
fn compute_composition_polynomial(
&self,
random_coeff: SecureField,
component_traces: &[ComponentTrace<'_, B>],
interaction_elements: &InteractionElements,
lookup_values: &LookupValues,
) -> SecureCirclePoly<B> {
let total_constraints: usize = self
.prover_components()
.iter()
.map(|c| c.n_constraints())
.sum();
let mut accumulator = DomainEvaluationAccumulator::new(
random_coeff,
self.composition_log_degree_bound(),
total_constraints,
);
zip_eq(self.prover_components(), component_traces).for_each(|(component, trace)| {
component.evaluate_constraint_quotients_on_domain(
trace,
&mut accumulator,
interaction_elements,
lookup_values,
)
});
accumulator.finalize()
}

fn lookup_values(&self, component_traces: &[ComponentTrace<'_, B>]) -> LookupValues {
pub fn lookup_values(&self, component_traces: &[ComponentTrace<'_, B>]) -> LookupValues {
let mut values = LookupValues::default();
zip_eq(self.prover_components(), component_traces)
zip_eq(&self.0, component_traces)
.for_each(|(component, trace)| values.extend(component.lookup_values(trace)));
values
}
}

impl<B: Backend, A: AirProver<B>> AirProverExt<B> for A {}
8 changes: 4 additions & 4 deletions crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub use components::{ComponentProvers, Components};

use self::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use super::backend::Backend;
use super::circle::CirclePoint;
Expand All @@ -9,11 +11,9 @@ use super::poly::BitReversedOrder;
use super::{ColumnVec, InteractionElements, LookupValues};

pub mod accumulation;
mod air_ext;
mod components;
pub mod mask;

pub use air_ext::{AirExt, AirProverExt};

/// Arithmetic Intermediate Representation (AIR).
/// An Air instance is assumed to already contain all the information needed to
/// evaluate the constraints.
Expand All @@ -25,7 +25,7 @@ pub trait Air {
}

pub trait AirProver<B: Backend>: Air {
fn prover_components(&self) -> Vec<&dyn ComponentProver<B>>;
fn component_provers(&self) -> Vec<&dyn ComponentProver<B>>;
}

/// A component is a set of trace columns of various sizes along with a set of
Expand Down
4 changes: 3 additions & 1 deletion crates/prover/src/core/pcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ pub mod quotients;
mod utils;
mod verifier;

pub use self::prover::{CommitmentSchemeProof, CommitmentSchemeProver, CommitmentTreeProver};
pub use self::prover::{
CommitmentSchemeProof, CommitmentSchemeProver, CommitmentTreeProver, TreeBuilder,
};
pub use self::utils::TreeVec;
pub use self::verifier::CommitmentSchemeVerifier;

Expand Down
66 changes: 36 additions & 30 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::{span, Level};

use super::air::AirProver;
use super::air::{Component, ComponentProver, ComponentProvers, Components};
use super::backend::Backend;
use super::fields::secure_column::SECURE_EXTENSION_DEGREE;
use super::fri::FriVerificationError;
Expand All @@ -12,7 +12,6 @@ use super::poly::circle::MAX_CIRCLE_DOMAIN_LOG_SIZE;
use super::proof_of_work::ProofOfWorkVerificationError;
use super::vcs::ops::MerkleHasher;
use super::{ColumnVec, InteractionElements, LookupValues};
use crate::core::air::{Air, AirExt, AirProverExt};
use crate::core::backend::CpuBackend;
use crate::core::channel::Channel;
use crate::core::circle::CirclePoint;
Expand Down Expand Up @@ -44,7 +43,7 @@ pub struct AdditionalProofData {
}

pub fn prove<B, C, H>(
air: &impl AirProver<B>,
components: &[&dyn ComponentProver<B>],
channel: &mut C,
interaction_elements: &InteractionElements,
commitment_scheme: &mut CommitmentSchemeProver<'_, B, H>,
Expand All @@ -54,15 +53,16 @@ where
C: Channel,
H: MerkleHasher<Hash = C::Digest>,
{
let component_traces = air.component_traces(&commitment_scheme.trees);
let lookup_values = air.lookup_values(&component_traces);
let component_provers = ComponentProvers(components.to_vec());
let component_traces = component_provers.component_traces(&commitment_scheme.trees);
let lookup_values = component_provers.lookup_values(&component_traces);

// Evaluate and commit on composition polynomial.
let random_coeff = channel.draw_felt();

let span = span!(Level::INFO, "Composition").entered();
let span1 = span!(Level::INFO, "Generation").entered();
let composition_polynomial_poly = air.compute_composition_polynomial(
let composition_polynomial_poly = component_provers.compute_composition_polynomial(
random_coeff,
&component_traces,
interaction_elements,
Expand All @@ -79,25 +79,30 @@ where
let oods_point = CirclePoint::<SecureField>::get_random_point(channel);

// Get mask sample points relative to oods point.
let sample_points = air.mask_points(oods_point);
let sample_points = component_provers.components().mask_points(oods_point);

// Prove the trace and composition OODS values, and retrieve them.
let commitment_scheme_proof = commitment_scheme.prove_values(sample_points, channel);

// Evaluate composition polynomial at OODS point and check that it matches the trace OODS
// values. This is a sanity check.
// TODO(spapini): Save clone.
let (trace_oods_values, composition_oods_value) =
sampled_values_to_mask(air, &commitment_scheme_proof.sampled_values).unwrap();
let (trace_oods_values, composition_oods_value) = sampled_values_to_mask(
&component_provers.components(),
&commitment_scheme_proof.sampled_values,
)
.unwrap();

if composition_oods_value
!= air.eval_composition_polynomial_at_point(
oods_point,
&trace_oods_values,
random_coeff,
interaction_elements,
&lookup_values,
)
!= component_provers
.components()
.eval_composition_polynomial_at_point(
oods_point,
&trace_oods_values,
random_coeff,
interaction_elements,
&lookup_values,
)
{
return Err(ProvingError::ConstraintsNotSatisfied);
}
Expand All @@ -110,7 +115,7 @@ where
}

pub fn verify<C, H>(
air: &impl Air,
components: &[&dyn Component],
channel: &mut C,
interaction_elements: &InteractionElements,
commitment_scheme: &mut CommitmentSchemeVerifier<H>,
Expand All @@ -120,32 +125,33 @@ where
C: Channel,
H: MerkleHasher<Hash = C::Digest>,
{
let components = Components(components.to_vec());
let random_coeff = channel.draw_felt();

// Read composition polynomial commitment.
commitment_scheme.commit(
*proof.commitments.last().unwrap(),
&[air.composition_log_degree_bound(); SECURE_EXTENSION_DEGREE],
&[components.composition_log_degree_bound(); SECURE_EXTENSION_DEGREE],
channel,
);

// Draw OODS point.
let oods_point = CirclePoint::<SecureField>::get_random_point(channel);

// Get mask sample points relative to oods point.
let sample_points = air.mask_points(oods_point);
let sample_points = components.mask_points(oods_point);

// TODO(spapini): Save clone.
let (trace_oods_values, composition_oods_value) = sampled_values_to_mask(
air,
&proof.commitment_scheme_proof.sampled_values,
)
.map_err(|_| {
VerificationError::InvalidStructure("Unexpected sampled_values structure".to_string())
})?;
let (trace_oods_values, composition_oods_value) =
sampled_values_to_mask(&components, &proof.commitment_scheme_proof.sampled_values)
.map_err(|_| {
VerificationError::InvalidStructure(
"Unexpected sampled_values structure".to_string(),
)
})?;

if composition_oods_value
!= air.eval_composition_polynomial_at_point(
!= components.eval_composition_polynomial_at_point(
oods_point,
&trace_oods_values,
random_coeff,
Expand All @@ -163,15 +169,15 @@ where
/// Structures the tree-wise sampled values into component-wise OODS values and a composition
/// polynomial OODS value.
fn sampled_values_to_mask(
air: &impl Air,
components: &Components<'_>,
sampled_values: &TreeVec<ColumnVec<Vec<SecureField>>>,
) -> Result<(Vec<TreeVec<Vec<Vec<SecureField>>>>, SecureField), InvalidOodsSampleStructure> {
let mut sampled_values = sampled_values.as_ref();
let composition_values = sampled_values.pop().ok_or(InvalidOodsSampleStructure)?;

let mut sample_iters = sampled_values.map(|tree_value| tree_value.iter());
let trace_oods_values = air
.components()
let trace_oods_values = components
.0
.iter()
.map(|component| {
component
Expand Down
Loading

0 comments on commit 8344112

Please sign in to comment.