From e3858fbc135ab33ba5943cc9fd7e07f2836f3e77 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Wed, 28 Aug 2024 14:08:17 +0100 Subject: [PATCH] Pass entire mask to components (#801) --- .../src/constraint_framework/component.rs | 112 ++++++++++--- crates/prover/src/constraint_framework/mod.rs | 2 +- .../prover/src/constraint_framework/point.rs | 4 +- crates/prover/src/core/air/components.rs | 72 ++------- crates/prover/src/core/air/mod.rs | 20 +-- crates/prover/src/core/pcs/mod.rs | 2 +- crates/prover/src/core/pcs/prover.rs | 21 ++- crates/prover/src/core/pcs/utils.rs | 31 ++++ crates/prover/src/core/prover/mod.rs | 92 ++++------- crates/prover/src/examples/blake/air.rs | 93 +++++++---- crates/prover/src/examples/blake/round/mod.rs | 38 ++--- .../examples/blake/scheduler/constraints.rs | 95 ++++++----- .../src/examples/blake/scheduler/mod.rs | 52 +++--- .../src/examples/blake/xor_table/mod.rs | 27 ++-- crates/prover/src/examples/plonk/mod.rs | 37 +++-- crates/prover/src/examples/poseidon/mod.rs | 150 +++++++++--------- .../wide_fibonacci/constraint_eval.rs | 6 +- .../prover/src/examples/wide_fibonacci/mod.rs | 4 +- .../src/examples/wide_fibonacci/simd.rs | 6 +- crates/prover/src/lib.rs | 3 +- crates/prover/src/trace_generation/prove.rs | 8 +- 21 files changed, 472 insertions(+), 403 deletions(-) diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 5e20b3db7..2b92a3648 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -1,11 +1,13 @@ use std::borrow::Cow; +use std::iter::zip; +use std::ops::Deref; use itertools::Itertools; use tracing::{span, Level}; use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator}; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use crate::core::air::{Component, ComponentProver, ComponentTrace}; +use crate::core::air::{Component, ComponentProver, Trace}; use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS}; @@ -15,36 +17,87 @@ use crate::core::constraints::coset_vanishing; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; -use crate::core::pcs::TreeVec; +use crate::core::pcs::{TreeSubspan, TreeVec}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; use crate::core::{utils, ColumnVec, InteractionElements, LookupValues}; +// TODO(andrew): Docs. +// TODO(andrew): Consider better location for this. +#[derive(Debug, Default)] +pub struct TraceLocationAllocator { + /// Mapping of tree index to next available column offset. + next_tree_offsets: TreeVec, +} + +impl TraceLocationAllocator { + fn next_for_structure(&mut self, structure: &TreeVec>) -> TreeVec { + if structure.len() > self.next_tree_offsets.len() { + self.next_tree_offsets.resize(structure.len(), 0); + } + + TreeVec::new( + zip(&mut *self.next_tree_offsets, &**structure) + .enumerate() + .map(|(tree_index, (offset, cols))| { + let col_start = *offset; + let col_end = col_start + cols.len(); + *offset = col_end; + TreeSubspan { + tree_index, + col_start, + col_end, + } + }) + .collect(), + ) + } +} + /// A component defined solely in means of the constraints framework. -/// Implementing this trait introduces implementations for [Component] and [ComponentProver] for the -/// SIMD backend. +/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for +/// the SIMD backend. /// Note that the constraint framework only support components with columns of the same size. -pub trait FrameworkComponent { +pub trait FrameworkEval { fn log_size(&self) -> u32; + fn max_constraint_log_degree_bound(&self) -> u32; + fn evaluate(&self, eval: E) -> E; } -impl Component for C { +pub struct FrameworkComponent { + eval: C, + trace_locations: TreeVec, +} + +impl FrameworkComponent { + pub fn new(provider: &mut TraceLocationAllocator, eval: E) -> Self { + let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets; + let trace_locations = provider.next_for_structure(&eval_tree_structure); + Self { + eval, + trace_locations, + } + } +} + +impl Component for FrameworkComponent { fn n_constraints(&self) -> usize { - self.evaluate(InfoEvaluator::default()).n_constraints + self.eval.evaluate(InfoEvaluator::default()).n_constraints } fn max_constraint_log_degree_bound(&self) -> u32 { - FrameworkComponent::max_constraint_log_degree_bound(self) + self.eval.max_constraint_log_degree_bound() } fn trace_log_degree_bounds(&self) -> TreeVec> { TreeVec::new( - self.evaluate(InfoEvaluator::default()) + self.eval + .evaluate(InfoEvaluator::default()) .mask_offsets .iter() - .map(|tree_masks| vec![self.log_size(); tree_masks.len()]) + .map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()]) .collect(), ) } @@ -53,8 +106,8 @@ impl Component for C { &self, point: CirclePoint, ) -> TreeVec>>> { - let info = self.evaluate(InfoEvaluator::default()); - let trace_step = CanonicCoset::new(self.log_size()).step(); + let info = self.eval.evaluate(InfoEvaluator::default()); + let trace_step = CanonicCoset::new(self.eval.log_size()).step(); info.mask_offsets.map_cols(|col_mask| { col_mask .iter() @@ -71,30 +124,32 @@ impl Component for C { _interaction_elements: &InteractionElements, _lookup_values: &LookupValues, ) { - self.evaluate(PointEvaluator::new( - mask.as_ref(), + self.eval.evaluate(PointEvaluator::new( + mask.sub_tree(&self.trace_locations), evaluation_accumulator, - coset_vanishing(CanonicCoset::new(self.log_size()).coset, point).inverse(), + coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(), )); } } -impl ComponentProver for C { +impl ComponentProver for FrameworkComponent { fn evaluate_constraint_quotients_on_domain( &self, - trace: &ComponentTrace<'_, SimdBackend>, + trace: &Trace<'_, SimdBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, _interaction_elements: &InteractionElements, _lookup_values: &LookupValues, ) { let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); - let trace_domain = CanonicCoset::new(self.log_size()); + let trace_domain = CanonicCoset::new(self.eval.log_size()); + + let component_polys = trace.polys.sub_tree(&self.trace_locations); + let component_evals = trace.evals.sub_tree(&self.trace_locations); // Extend trace if necessary. // TODO(spapini): Don't extend when eval_size < committed_size. Instead, pick a good // subdomain. - let need_to_extend = trace - .evals + let need_to_extend = component_evals .iter() .flatten() .any(|c| c.domain != eval_domain); @@ -103,12 +158,11 @@ impl ComponentProver for C { > = if need_to_extend { let _span = span!(Level::INFO, "Extension").entered(); let twiddles = SimdBackend::precompute_twiddles(eval_domain.half_coset); - trace - .polys + component_polys .as_cols_ref() .map_cols(|col| Cow::Owned(col.evaluate_with_twiddles(eval_domain, &twiddles))) } else { - trace.evals.as_cols_ref().map_cols(|c| Cow::Borrowed(*c)) + component_evals.clone().map_cols(|c| Cow::Borrowed(*c)) }; // Denom inverses. @@ -137,7 +191,7 @@ impl ComponentProver for C { trace_domain.log_size(), eval_domain.log_size(), ); - let row_res = self.evaluate(eval).row_res; + let row_res = self.eval.evaluate(eval).row_res; // Finalize row. unsafe { @@ -150,7 +204,15 @@ impl ComponentProver for C { } } - fn lookup_values(&self, _trace: &ComponentTrace<'_, SimdBackend>) -> LookupValues { + fn lookup_values(&self, _trace: &Trace<'_, SimdBackend>) -> LookupValues { LookupValues::default() } } + +impl Deref for FrameworkComponent { + type Target = E; + + fn deref(&self) -> &E { + &self.eval + } +} diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index f0d6ca9be..87069d344 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -12,7 +12,7 @@ use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Neg, Sub}; pub use assert::{assert_constraints, AssertEvaluator}; -pub use component::FrameworkComponent; +pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator}; pub use info::InfoEvaluator; use num_traits::{One, Zero}; pub use point::PointEvaluator; diff --git a/crates/prover/src/constraint_framework/point.rs b/crates/prover/src/constraint_framework/point.rs index 5bbdb778d..6c6f72f81 100644 --- a/crates/prover/src/constraint_framework/point.rs +++ b/crates/prover/src/constraint_framework/point.rs @@ -9,14 +9,14 @@ use crate::core::ColumnVec; /// Evaluates expressions at a point out of domain. pub struct PointEvaluator<'a> { - pub mask: TreeVec<&'a ColumnVec>>, + pub mask: TreeVec>>, pub evaluation_accumulator: &'a mut PointEvaluationAccumulator, pub col_index: Vec, pub denom_inverse: SecureField, } impl<'a> PointEvaluator<'a> { pub fn new( - mask: TreeVec<&'a ColumnVec>>, + mask: TreeVec>>, evaluation_accumulator: &'a mut PointEvaluationAccumulator, denom_inverse: SecureField, ) -> Self { diff --git a/crates/prover/src/core/air/components.rs b/crates/prover/src/core/air/components.rs index 9b3d1200e..397468616 100644 --- a/crates/prover/src/core/air/components.rs +++ b/crates/prover/src/core/air/components.rs @@ -1,12 +1,11 @@ -use itertools::{zip_eq, Itertools}; +use itertools::Itertools; use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use super::{Component, ComponentProver, ComponentTrace}; -use crate::core::backend::{Backend, BackendForChannel}; -use crate::core::channel::MerkleChannel; +use super::{Component, ComponentProver, Trace}; +use crate::core::backend::Backend; use crate::core::circle::CirclePoint; use crate::core::fields::qm31::SecureField; -use crate::core::pcs::{CommitmentTreeProver, TreeVec}; +use crate::core::pcs::TreeVec; use crate::core::poly::circle::SecureCirclePoly; use crate::core::{ColumnVec, InteractionElements, LookupValues}; @@ -31,21 +30,21 @@ impl<'a> Components<'a> { pub fn eval_composition_polynomial_at_point( &self, point: CirclePoint, - mask_values: &Vec>>>, + mask_values: &TreeVec>>, random_coeff: SecureField, interaction_elements: &InteractionElements, lookup_values: &LookupValues, ) -> SecureField { let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff); - zip_eq(&self.0, mask_values).for_each(|(component, mask)| { + for component in &self.0 { component.evaluate_constraint_quotients_at_point( point, - mask, + mask_values, &mut evaluation_accumulator, interaction_elements, lookup_values, ) - }); + } evaluation_accumulator.finalize() } @@ -67,7 +66,7 @@ impl<'a, B: Backend> ComponentProvers<'a, B> { pub fn compute_composition_polynomial( &self, random_coeff: SecureField, - component_traces: &[ComponentTrace<'_, B>], + trace: &Trace<'_, B>, interaction_elements: &InteractionElements, lookup_values: &LookupValues, ) -> SecureCirclePoly { @@ -77,63 +76,22 @@ impl<'a, B: Backend> ComponentProvers<'a, B> { self.components().composition_log_degree_bound(), total_constraints, ); - zip_eq(&self.0, component_traces).for_each(|(component, trace)| { + for component in &self.0 { component.evaluate_constraint_quotients_on_domain( trace, &mut accumulator, interaction_elements, lookup_values, ) - }); + } accumulator.finalize() } - pub fn component_traces<'b, MC: MerkleChannel>( - &'b self, - trees: &'b [CommitmentTreeProver], - ) -> Vec> - where - B: BackendForChannel, - { - let mut poly_iters = trees - .iter() - .map(|tree| tree.polynomials.iter()) - .collect_vec(); - let mut eval_iters = trees - .iter() - .map(|tree| tree.evaluations.iter()) - .collect_vec(); - - self.0 - .iter() - .map(|component| { - let col_sizes_per_tree = component - .trace_log_degree_bounds() - .iter() - .map(|col_sizes| col_sizes.len()) - .collect_vec(); - let polys = col_sizes_per_tree - .iter() - .zip(poly_iters.iter_mut()) - .map(|(n_columns, iter)| iter.take(*n_columns).collect_vec()) - .collect_vec(); - let evals = col_sizes_per_tree - .iter() - .zip(eval_iters.iter_mut()) - .map(|(n_columns, iter)| iter.take(*n_columns).collect_vec()) - .collect_vec(); - ComponentTrace { - polys: TreeVec::new(polys), - evals: TreeVec::new(evals), - } - }) - .collect_vec() - } - - pub fn lookup_values(&self, component_traces: &[ComponentTrace<'_, B>]) -> LookupValues { + pub fn lookup_values(&self, trace: &Trace<'_, B>) -> LookupValues { let mut values = LookupValues::default(); - zip_eq(&self.0, component_traces) - .for_each(|(component, trace)| values.extend(component.lookup_values(trace))); + for component in &self.0 { + values.extend(component.lookup_values(trace)) + } values } } diff --git a/crates/prover/src/core/air/mod.rs b/crates/prover/src/core/air/mod.rs index efd2d23c5..39834296a 100644 --- a/crates/prover/src/core/air/mod.rs +++ b/crates/prover/src/core/air/mod.rs @@ -62,30 +62,22 @@ pub trait ComponentProver: Component { /// Accumulates quotients in `evaluation_accumulator`. fn evaluate_constraint_quotients_on_domain( &self, - trace: &ComponentTrace<'_, B>, + trace: &Trace<'_, B>, evaluation_accumulator: &mut DomainEvaluationAccumulator, interaction_elements: &InteractionElements, lookup_values: &LookupValues, ); /// Returns the values needed to evaluate the components lookup boundary constraints. - fn lookup_values(&self, _trace: &ComponentTrace<'_, B>) -> LookupValues; + fn lookup_values(&self, _trace: &Trace<'_, B>) -> LookupValues; } -/// A component trace is a set of polynomials for each column on that component. +/// The set of polynomials that make up the trace. +/// /// Each polynomial is stored both in a coefficients, and evaluations form (for efficiency) -pub struct ComponentTrace<'a, B: Backend> { +pub struct Trace<'a, B: Backend> { /// Polynomials for each column. pub polys: TreeVec>>, - /// Evaluations for each column (evaluated on the commitment domains). + /// Evaluations for each column (evaluated on their commitment domains). pub evals: TreeVec>>, } - -impl<'a, B: Backend> ComponentTrace<'a, B> { - pub fn new( - polys: TreeVec>>, - evals: TreeVec>>, - ) -> Self { - Self { polys, evals } - } -} diff --git a/crates/prover/src/core/pcs/mod.rs b/crates/prover/src/core/pcs/mod.rs index 07ddfd2d3..d9acf524b 100644 --- a/crates/prover/src/core/pcs/mod.rs +++ b/crates/prover/src/core/pcs/mod.rs @@ -19,7 +19,7 @@ pub use self::verifier::CommitmentSchemeVerifier; use super::fri::FriConfig; #[derive(Copy, Debug, Clone, PartialEq, Eq)] -pub struct TreeColumnSpan { +pub struct TreeSubspan { pub tree_index: usize, pub col_start: usize, pub col_end: usize, diff --git a/crates/prover/src/core/pcs/prover.rs b/crates/prover/src/core/pcs/prover.rs index 425b4581a..7a761eaa9 100644 --- a/crates/prover/src/core/pcs/prover.rs +++ b/crates/prover/src/core/pcs/prover.rs @@ -13,7 +13,8 @@ use super::super::poly::BitReversedOrder; use super::super::ColumnVec; use super::quotients::{compute_fri_quotients, PointSample}; use super::utils::TreeVec; -use super::{PcsConfig, TreeColumnSpan}; +use super::{PcsConfig, TreeSubspan}; +use crate::core::air::Trace; use crate::core::backend::BackendForChannel; use crate::core::channel::{Channel, MerkleChannel}; use crate::core::poly::circle::{CircleEvaluation, CirclePoly}; @@ -66,12 +67,20 @@ impl<'a, B: BackendForChannel, MC: MerkleChannel> CommitmentSchemeProver<'a, .map(|tree| tree.polynomials.iter().collect()) } - fn evaluations(&self) -> TreeVec>> { + pub fn evaluations( + &self, + ) -> TreeVec>> { self.trees .as_ref() .map(|tree| tree.evaluations.iter().collect()) } + pub fn trace(&self) -> Trace<'_, B> { + let polys = self.polynomials(); + let evals = self.evaluations(); + Trace { polys, evals } + } + pub fn prove_values( &self, sampled_points: TreeVec>>>, @@ -159,7 +168,7 @@ impl<'a, 'b, B: BackendForChannel, MC: MerkleChannel> TreeBuilder<'a, 'b, B, pub fn extend_evals( &mut self, columns: ColumnVec>, - ) -> TreeColumnSpan { + ) -> TreeSubspan { let span = span!(Level::INFO, "Interpolation for commitment").entered(); let col_start = self.polys.len(); let polys = columns @@ -168,17 +177,17 @@ impl<'a, 'b, B: BackendForChannel, MC: MerkleChannel> TreeBuilder<'a, 'b, B, .collect_vec(); span.exit(); self.polys.extend(polys); - TreeColumnSpan { + TreeSubspan { tree_index: self.tree_index, col_start, col_end: self.polys.len(), } } - pub fn extend_polys(&mut self, polys: ColumnVec>) -> TreeColumnSpan { + pub fn extend_polys(&mut self, polys: ColumnVec>) -> TreeSubspan { let col_start = self.polys.len(); self.polys.extend(polys); - TreeColumnSpan { + TreeSubspan { tree_index: self.tree_index, col_start, col_end: self.polys.len(), diff --git a/crates/prover/src/core/pcs/utils.rs b/crates/prover/src/core/pcs/utils.rs index bd1c6f9ca..bfdbdb5d9 100644 --- a/crates/prover/src/core/pcs/utils.rs +++ b/crates/prover/src/core/pcs/utils.rs @@ -1,8 +1,10 @@ +use std::collections::BTreeSet; use std::ops::{Deref, DerefMut}; use itertools::zip_eq; use serde::{Deserialize, Serialize}; +use super::TreeSubspan; use crate::core::ColumnVec; /// A container that holds an element for each commitment tree. @@ -67,6 +69,7 @@ impl TreeVec> { .collect(), ) } + /// Zips two [`TreeVec>`] with the same structure (number of columns in each tree). /// The resulting [`TreeVec>`] has the same structure, with each value being a tuple /// of the corresponding values from the input [`TreeVec>`]. @@ -81,9 +84,11 @@ impl TreeVec> { .collect(), ) } + pub fn as_cols_ref(&self) -> TreeVec> { TreeVec(self.iter().map(|column| column.iter().collect()).collect()) } + /// Flattens the [`TreeVec>`] into a single [`ColumnVec`] with all the columns /// combined. pub fn flatten(self) -> ColumnVec { @@ -110,6 +115,32 @@ impl TreeVec> { } result } + + /// Extracts a sub-tree based on the specified locations. + /// + /// # Panics + /// + /// If two or more locations have the same tree index. + pub fn sub_tree(&self, locations: &[TreeSubspan]) -> TreeVec> { + let tree_indicies: BTreeSet = locations.iter().map(|l| l.tree_index).collect(); + assert_eq!(tree_indicies.len(), locations.len()); + let max_tree_index = tree_indicies.iter().max().unwrap_or(&0); + let mut res = TreeVec(vec![Vec::new(); max_tree_index + 1]); + + for &location in locations { + // TODO(andrew): Throwing error here might be better instead. + let chunk = self.get_chunk(location).unwrap(); + res[location.tree_index] = chunk; + } + + res + } + + fn get_chunk(&self, location: TreeSubspan) -> Option> { + let tree = self.0.get(location.tree_index)?; + let chunk = tree.get(location.col_start..location.col_end)?; + Some(chunk.iter().collect()) + } } impl<'a, T> From<&'a TreeVec>> for TreeVec> { diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index b4cf0c537..d5b9c4c4a 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -1,4 +1,5 @@ -use itertools::Itertools; +use std::array; + use serde::{Deserialize, Serialize}; use thiserror::Error; use tracing::{span, Level}; @@ -10,7 +11,7 @@ use super::fields::secure_column::SECURE_EXTENSION_DEGREE; use super::fri::FriVerificationError; use super::pcs::{CommitmentSchemeProof, TreeVec}; use super::vcs::ops::MerkleHasher; -use super::{ColumnVec, InteractionElements, LookupValues}; +use super::{InteractionElements, LookupValues}; use crate::core::backend::CpuBackend; use crate::core::channel::Channel; use crate::core::circle::CirclePoint; @@ -42,8 +43,8 @@ pub fn prove, MC: MerkleChannel>( commitment_scheme: &mut CommitmentSchemeProver<'_, B, MC>, ) -> Result, ProvingError> { let component_provers = ComponentProvers(components.to_vec()); - let component_traces = component_provers.component_traces(&commitment_scheme.trees); - let lookup_values = component_provers.lookup_values(&component_traces); + let trace = commitment_scheme.trace(); + let lookup_values = component_provers.lookup_values(&trace); // Evaluate and commit on composition polynomial. let random_coeff = channel.draw_felt(); @@ -52,7 +53,7 @@ pub fn prove, MC: MerkleChannel>( let span1 = span!(Level::INFO, "Generation").entered(); let composition_polynomial_poly = component_provers.compute_composition_polynomial( random_coeff, - &component_traces, + &trace, interaction_elements, &lookup_values, ); @@ -74,21 +75,17 @@ pub fn prove, MC: MerkleChannel>( // Prove the trace and composition OODS values, and retrieve them. let commitment_scheme_proof = commitment_scheme.prove_values(sample_points, channel); + let sampled_oods_values = &commitment_scheme_proof.sampled_values; + let composition_oods_eval = extract_composition_eval(sampled_oods_values).unwrap(); + // Evaluate composition polynomial at OODS point and check that it matches the trace OODS // values. This is a sanity check. - // TODO(spapini): Save clone. - let (trace_oods_values, composition_oods_value) = sampled_values_to_mask( - &component_provers.components(), - &commitment_scheme_proof.sampled_values, - ) - .unwrap(); - - if composition_oods_value + if composition_oods_eval != component_provers .components() .eval_composition_polynomial_at_point( oods_point, - &trace_oods_values, + sampled_oods_values, random_coeff, interaction_elements, &lookup_values, @@ -129,19 +126,15 @@ pub fn verify( // Add the composition polynomial mask points. sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]); - // TODO(spapini): Save clone. - let (trace_oods_values, composition_oods_value) = - sampled_values_to_mask(&components, &proof.commitment_scheme_proof.sampled_values) - .map_err(|_| { - VerificationError::InvalidStructure( - "Unexpected sampled_values structure".to_string(), - ) - })?; + let sampled_oods_values = &proof.commitment_scheme_proof.sampled_values; + let composition_oods_eval = extract_composition_eval(sampled_oods_values).map_err(|_| { + VerificationError::InvalidStructure("Unexpected sampled_values structure".to_string()) + })?; - if composition_oods_value + if composition_oods_eval != components.eval_composition_polynomial_at_point( oods_point, - &trace_oods_values, + sampled_oods_values, random_coeff, interaction_elements, &proof.lookup_values, @@ -153,41 +146,24 @@ pub fn verify( commitment_scheme.verify_values(sample_points, proof.commitment_scheme_proof, channel) } -#[allow(clippy::type_complexity)] -/// Structures the tree-wise sampled values into component-wise OODS values and a composition -/// polynomial OODS value. -fn sampled_values_to_mask( - components: &Components<'_>, - sampled_values: &TreeVec>>, -) -> Result<(Vec>>>, SecureField), InvalidOodsSampleStructure> { - let mut sampled_values = sampled_values.as_ref(); - let composition_values = sampled_values.pop().ok_or(InvalidOodsSampleStructure)?; - - let mut sample_iters = sampled_values.map(|tree_value| tree_value.iter()); - let trace_oods_values = components - .0 - .iter() - .map(|component| { - component - .mask_points(CirclePoint::zero()) - .zip(sample_iters.as_mut()) - .map(|(mask_per_tree, tree_iter)| { - tree_iter.take(mask_per_tree.len()).cloned().collect_vec() - }) - }) - .collect_vec(); - - let composition_oods_value = SecureField::from_partial_evals( - composition_values - .iter() - .flatten() - .cloned() - .collect_vec() - .try_into() - .map_err(|_| InvalidOodsSampleStructure)?, - ); +/// Extracts the composition trace evaluation from the mask. +fn extract_composition_eval( + mask: &TreeVec>>, +) -> Result { + let mut composition_cols = mask.last().into_iter().flatten(); + + let coordinate_evals = array::try_from_fn(|_| { + let col = &**composition_cols.next().ok_or(InvalidOodsSampleStructure)?; + let [eval] = col.try_into().map_err(|_| InvalidOodsSampleStructure)?; + Ok(eval) + })?; + + // Too many columns. + if composition_cols.next().is_some() { + return Err(InvalidOodsSampleStructure); + } - Ok((trace_oods_values, composition_oods_value)) + Ok(SecureField::from_partial_evals(coordinate_evals)) } /// Error when the sampled values have an invalid structure. diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index 809c2ccde..9975f0687 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -5,9 +5,10 @@ use num_traits::Zero; use serde::Serialize; use tracing::{span, Level}; -use super::round::{blake_round_info, BlakeRoundComponent}; -use super::scheduler::BlakeSchedulerComponent; -use super::xor_table::XorTableComponent; +use super::round::{blake_round_info, BlakeRoundComponent, BlakeRoundEval}; +use super::scheduler::{BlakeSchedulerComponent, BlakeSchedulerEval}; +use super::xor_table::{XorTableComponent, XorTableEval}; +use crate::constraint_framework::TraceLocationAllocator; use crate::core::air::{Component, ComponentProver}; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::SimdBackend; @@ -120,43 +121,67 @@ pub struct BlakeComponents { } impl BlakeComponents { fn new(stmt0: &BlakeStatement0, all_elements: &AllElements, stmt1: &BlakeStatement1) -> Self { + let tree_span_provider = &mut TraceLocationAllocator::default(); Self { - scheduler_component: BlakeSchedulerComponent { - log_size: stmt0.log_size, - blake_lookup_elements: all_elements.blake_elements.clone(), - round_lookup_elements: all_elements.round_elements.clone(), - claimed_sum: stmt1.scheduler_claimed_sum, - }, + scheduler_component: BlakeSchedulerComponent::new( + tree_span_provider, + BlakeSchedulerEval { + log_size: stmt0.log_size, + blake_lookup_elements: all_elements.blake_elements.clone(), + round_lookup_elements: all_elements.round_elements.clone(), + claimed_sum: stmt1.scheduler_claimed_sum, + }, + ), round_components: ROUND_LOG_SPLIT .iter() .zip(stmt1.round_claimed_sums.clone()) - .map(|(l, claimed_sum)| BlakeRoundComponent { - log_size: stmt0.log_size + l, - xor_lookup_elements: all_elements.xor_elements.clone(), - round_lookup_elements: all_elements.round_elements.clone(), - claimed_sum, + .map(|(l, claimed_sum)| { + BlakeRoundComponent::new( + tree_span_provider, + BlakeRoundEval { + log_size: stmt0.log_size + l, + xor_lookup_elements: all_elements.xor_elements.clone(), + round_lookup_elements: all_elements.round_elements.clone(), + claimed_sum, + }, + ) }) .collect(), - xor12: XorTableComponent { - lookup_elements: all_elements.xor_elements.xor12.clone(), - claimed_sum: stmt1.xor12_claimed_sum, - }, - xor9: XorTableComponent { - lookup_elements: all_elements.xor_elements.xor9.clone(), - claimed_sum: stmt1.xor9_claimed_sum, - }, - xor8: XorTableComponent { - lookup_elements: all_elements.xor_elements.xor8.clone(), - claimed_sum: stmt1.xor8_claimed_sum, - }, - xor7: XorTableComponent { - lookup_elements: all_elements.xor_elements.xor7.clone(), - claimed_sum: stmt1.xor7_claimed_sum, - }, - xor4: XorTableComponent { - lookup_elements: all_elements.xor_elements.xor4.clone(), - claimed_sum: stmt1.xor4_claimed_sum, - }, + xor12: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor12.clone(), + claimed_sum: stmt1.xor12_claimed_sum, + }, + ), + xor9: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor9.clone(), + claimed_sum: stmt1.xor9_claimed_sum, + }, + ), + xor8: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor8.clone(), + claimed_sum: stmt1.xor8_claimed_sum, + }, + ), + xor7: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor7.clone(), + claimed_sum: stmt1.xor7_claimed_sum, + }, + ), + xor4: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor4.clone(), + claimed_sum: stmt1.xor4_claimed_sum, + }, + ), } } fn components(&self) -> Vec<&dyn Component> { diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs index c5123b5ce..cf8311339 100644 --- a/crates/prover/src/examples/blake/round/mod.rs +++ b/crates/prover/src/examples/blake/round/mod.rs @@ -1,34 +1,26 @@ mod constraints; mod gen; -use constraints::BlakeRoundEval; +pub use gen::{generate_interaction_trace, generate_trace, BlakeRoundInput}; use num_traits::Zero; -pub use r#gen::{generate_interaction_trace, generate_trace, BlakeRoundInput}; use super::{BlakeXorElements, N_ROUND_INPUT_FELTS}; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; -use crate::constraint_framework::{EvalAtRow, FrameworkComponent, InfoEvaluator}; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; use crate::core::fields::qm31::SecureField; -pub fn blake_round_info() -> InfoEvaluator { - let component = BlakeRoundComponent { - log_size: 1, - xor_lookup_elements: BlakeXorElements::dummy(), - round_lookup_elements: RoundElements::dummy(), - claimed_sum: SecureField::zero(), - }; - component.evaluate(InfoEvaluator::default()) -} +pub type BlakeRoundComponent = FrameworkComponent; pub type RoundElements = LookupElements; -pub struct BlakeRoundComponent { + +pub struct BlakeRoundEval { pub log_size: u32, pub xor_lookup_elements: BlakeXorElements, pub round_lookup_elements: RoundElements, pub claimed_sum: SecureField, } -impl FrameworkComponent for BlakeRoundComponent { +impl FrameworkEval for BlakeRoundEval { fn log_size(&self) -> u32 { self.log_size } @@ -36,7 +28,7 @@ impl FrameworkComponent for BlakeRoundComponent { self.log_size + 1 } fn evaluate(&self, eval: E) -> E { - let blake_eval = BlakeRoundEval { + let blake_eval = constraints::BlakeRoundEval { eval, xor_lookup_elements: &self.xor_lookup_elements, round_lookup_elements: &self.round_lookup_elements, @@ -46,6 +38,16 @@ impl FrameworkComponent for BlakeRoundComponent { } } +pub fn blake_round_info() -> InfoEvaluator { + let component = BlakeRoundEval { + log_size: 1, + xor_lookup_elements: BlakeXorElements::dummy(), + round_lookup_elements: RoundElements::dummy(), + claimed_sum: SecureField::zero(), + }; + component.evaluate(InfoEvaluator::default()) +} + #[cfg(test)] mod tests { use std::simd::Simd; @@ -53,12 +55,12 @@ mod tests { use itertools::Itertools; use crate::constraint_framework::constant_columns::gen_is_first; - use crate::constraint_framework::FrameworkComponent; + use crate::constraint_framework::FrameworkEval; use crate::core::poly::circle::CanonicCoset; use crate::examples::blake::round::r#gen::{ generate_interaction_trace, generate_trace, BlakeRoundInput, }; - use crate::examples::blake::round::{BlakeRoundComponent, RoundElements}; + use crate::examples::blake::round::{BlakeRoundEval, RoundElements}; use crate::examples::blake::{BlakeXorElements, XorAccums}; #[test] @@ -91,7 +93,7 @@ mod tests { let trace = TreeVec::new(vec![trace, interaction_trace, vec![gen_is_first(LOG_SIZE)]]); let trace_polys = trace.map_cols(|c| c.interpolate()); - let component = BlakeRoundComponent { + let component = BlakeRoundEval { log_size: LOG_SIZE, xor_lookup_elements, round_lookup_elements, diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index 9d5057fed..63b3cf696 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -8,60 +8,57 @@ use crate::core::vcs::blake2s_ref::SIGMA; use crate::examples::blake::round::RoundElements; use crate::examples::blake::{Fu32, N_ROUNDS, STATE_SIZE}; -pub struct BlakeSchedulerEval<'a, E: EvalAtRow> { - pub eval: E, - pub blake_lookup_elements: &'a BlakeElements, - pub round_lookup_elements: &'a RoundElements, - pub logup: LogupAtRow<2, E>, -} -impl<'a, E: EvalAtRow> BlakeSchedulerEval<'a, E> { - pub fn eval(mut self) -> E { - let messages: [Fu32; STATE_SIZE] = std::array::from_fn(|_| self.next_u32()); - let states: [[Fu32; STATE_SIZE]; N_ROUNDS + 1] = - std::array::from_fn(|_| std::array::from_fn(|_| self.next_u32())); - - // Schedule. - for i in 0..N_ROUNDS { - let input_state = &states[i]; - let output_state = &states[i + 1]; - let round_messages = SIGMA[i].map(|j| messages[j as usize]); - // Use triplet in round lookup. - self.logup.push_lookup( - &mut self.eval, - E::EF::one(), - &chain![ - input_state.iter().copied().flat_map(Fu32::to_felts), - output_state.iter().copied().flat_map(Fu32::to_felts), - round_messages.iter().copied().flat_map(Fu32::to_felts) - ] - .collect_vec(), - self.round_lookup_elements, - ) - } - - let input_state = &states[0]; - let output_state = &states[N_ROUNDS]; +pub fn eval_blake_scheduler_constraints( + eval: &mut E, + blake_lookup_elements: &BlakeElements, + round_lookup_elements: &RoundElements, + mut logup: LogupAtRow<2, E>, +) { + let messages: [Fu32; STATE_SIZE] = std::array::from_fn(|_| eval_next_u32(eval)); + let states: [[Fu32; STATE_SIZE]; N_ROUNDS + 1] = + std::array::from_fn(|_| std::array::from_fn(|_| eval_next_u32(eval))); - // TODO(spapini): Support multiplicities. - // TODO(spapini): Change to -1. - self.logup.push_lookup( - &mut self.eval, - E::EF::zero(), + // Schedule. + for i in 0..N_ROUNDS { + let input_state = &states[i]; + let output_state = &states[i + 1]; + let round_messages = SIGMA[i].map(|j| messages[j as usize]); + // Use triplet in round lookup. + logup.push_lookup( + eval, + E::EF::one(), &chain![ input_state.iter().copied().flat_map(Fu32::to_felts), output_state.iter().copied().flat_map(Fu32::to_felts), - messages.iter().copied().flat_map(Fu32::to_felts) + round_messages.iter().copied().flat_map(Fu32::to_felts) ] .collect_vec(), - self.blake_lookup_elements, - ); - - self.logup.finalize(&mut self.eval); - self.eval - } - fn next_u32(&mut self) -> Fu32 { - let l = self.eval.next_trace_mask(); - let h = self.eval.next_trace_mask(); - Fu32 { l, h } + round_lookup_elements, + ) } + + let input_state = &states[0]; + let output_state = &states[N_ROUNDS]; + + // TODO(spapini): Support multiplicities. + // TODO(spapini): Change to -1. + logup.push_lookup( + eval, + E::EF::zero(), + &chain![ + input_state.iter().copied().flat_map(Fu32::to_felts), + output_state.iter().copied().flat_map(Fu32::to_felts), + messages.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(), + blake_lookup_elements, + ); + + logup.finalize(eval); +} + +fn eval_next_u32(eval: &mut E) -> Fu32 { + let l = eval.next_trace_mask(); + let h = eval.next_trace_mask(); + Fu32 { l, h } } diff --git a/crates/prover/src/examples/blake/scheduler/mod.rs b/crates/prover/src/examples/blake/scheduler/mod.rs index 116053246..e8a8c32f3 100644 --- a/crates/prover/src/examples/blake/scheduler/mod.rs +++ b/crates/prover/src/examples/blake/scheduler/mod.rs @@ -1,63 +1,65 @@ mod constraints; mod gen; -use constraints::BlakeSchedulerEval; +use constraints::eval_blake_scheduler_constraints; pub use gen::{gen_interaction_trace, gen_trace, BlakeInput}; use num_traits::Zero; use super::round::RoundElements; use super::N_ROUND_INPUT_FELTS; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; -use crate::constraint_framework::{EvalAtRow, FrameworkComponent, InfoEvaluator}; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; use crate::core::fields::qm31::SecureField; -pub type BlakeElements = LookupElements; +pub type BlakeSchedulerComponent = FrameworkComponent; -pub fn blake_scheduler_info() -> InfoEvaluator { - let component = BlakeSchedulerComponent { - log_size: 1, - blake_lookup_elements: BlakeElements::dummy(), - round_lookup_elements: RoundElements::dummy(), - claimed_sum: SecureField::zero(), - }; - component.evaluate(InfoEvaluator::default()) -} +pub type BlakeElements = LookupElements; -pub struct BlakeSchedulerComponent { +pub struct BlakeSchedulerEval { pub log_size: u32, pub blake_lookup_elements: BlakeElements, pub round_lookup_elements: RoundElements, pub claimed_sum: SecureField, } -impl FrameworkComponent for BlakeSchedulerComponent { +impl FrameworkEval for BlakeSchedulerEval { fn log_size(&self) -> u32 { self.log_size } fn max_constraint_log_degree_bound(&self) -> u32 { self.log_size + 1 } - fn evaluate(&self, eval: E) -> E { - let blake_eval = BlakeSchedulerEval { - eval, - blake_lookup_elements: &self.blake_lookup_elements, - round_lookup_elements: &self.round_lookup_elements, - logup: LogupAtRow::new(1, self.claimed_sum, self.log_size), - }; - blake_eval.eval() + fn evaluate(&self, mut eval: E) -> E { + eval_blake_scheduler_constraints( + &mut eval, + &self.blake_lookup_elements, + &self.round_lookup_elements, + LogupAtRow::new(1, self.claimed_sum, self.log_size), + ); + eval } } +pub fn blake_scheduler_info() -> InfoEvaluator { + let component = BlakeSchedulerEval { + log_size: 1, + blake_lookup_elements: BlakeElements::dummy(), + round_lookup_elements: RoundElements::dummy(), + claimed_sum: SecureField::zero(), + }; + component.evaluate(InfoEvaluator::default()) +} + #[cfg(test)] mod tests { use std::simd::Simd; use itertools::Itertools; - use crate::constraint_framework::FrameworkComponent; + use crate::constraint_framework::FrameworkEval; use crate::core::poly::circle::CanonicCoset; use crate::examples::blake::round::RoundElements; use crate::examples::blake::scheduler::r#gen::{gen_interaction_trace, gen_trace, BlakeInput}; - use crate::examples::blake::scheduler::{BlakeElements, BlakeSchedulerComponent}; + use crate::examples::blake::scheduler::{BlakeElements, BlakeSchedulerEval}; #[test] fn test_blake_scheduler() { @@ -87,7 +89,7 @@ mod tests { let trace = TreeVec::new(vec![trace, interaction_trace]); let trace_polys = trace.map_cols(|c| c.interpolate()); - let component = BlakeSchedulerComponent { + let component = BlakeSchedulerEval { log_size: LOG_SIZE, blake_lookup_elements, round_lookup_elements, diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index 21c417cfd..877a65114 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -15,20 +15,19 @@ mod gen; use std::simd::u32x16; -use constraints::XorTableEval; use itertools::Itertools; use num_traits::Zero; pub use r#gen::{generate_constant_trace, generate_interaction_trace, generate_trace}; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; -use crate::constraint_framework::{EvalAtRow, FrameworkComponent, InfoEvaluator}; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::Column; use crate::core::fields::qm31::SecureField; -use crate::core::pcs::TreeVec; +use crate::core::pcs::{TreeSubspan, TreeVec}; pub fn trace_sizes() -> TreeVec> { - let component = XorTableComponent:: { + let component = XorTableEval:: { lookup_elements: LookupElements::<3>::dummy(), claimed_sum: SecureField::zero(), }; @@ -83,13 +82,19 @@ impl XorAccumulator = + FrameworkComponent>; + pub type XorElements = LookupElements<3>; -pub struct XorTableComponent { + +/// Evaluates the xor table. +pub struct XorTableEval { pub lookup_elements: XorElements, pub claimed_sum: SecureField, } -impl FrameworkComponent - for XorTableComponent + +impl FrameworkEval + for XorTableEval { fn log_size(&self) -> u32 { column_bits::() @@ -98,7 +103,7 @@ impl FrameworkComponent column_bits::() + 1 } fn evaluate(&self, mut eval: E) -> E { - let xor_eval = XorTableEval::<'_, _, ELEM_BITS, EXPAND_BITS> { + let xor_eval = constraints::XorTableEval::<'_, _, ELEM_BITS, EXPAND_BITS> { eval, lookup_elements: &self.lookup_elements, logup: LogupAtRow::new(1, self.claimed_sum, self.log_size()), @@ -112,12 +117,12 @@ mod tests { use std::simd::u32x16; use crate::constraint_framework::logup::LookupElements; - use crate::constraint_framework::{assert_constraints, FrameworkComponent}; + use crate::constraint_framework::{assert_constraints, FrameworkEval}; use crate::core::poly::circle::CanonicCoset; use crate::examples::blake::xor_table::r#gen::{ generate_constant_trace, generate_interaction_trace, generate_trace, }; - use crate::examples::blake::xor_table::{column_bits, XorAccumulator, XorTableComponent}; + use crate::examples::blake::xor_table::{column_bits, XorAccumulator, XorTableEval}; #[test] fn test_xor_table() { @@ -138,7 +143,7 @@ mod tests { let trace = TreeVec::new(vec![trace, interaction_trace, constant_trace]); let trace_polys = trace.map_cols(|c| c.interpolate()); - let component = XorTableComponent:: { + let component = XorTableEval:: { lookup_elements, claimed_sum, }; diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index 79a48b14f..14ccd1744 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -3,7 +3,9 @@ use num_traits::One; use tracing::{span, Level}; use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; -use crate::constraint_framework::{assert_constraints, EvalAtRow, FrameworkComponent}; +use crate::constraint_framework::{ + assert_constraints, EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator, +}; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::qm31::PackedSecureField; @@ -12,21 +14,26 @@ use crate::core::backend::Column; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::pcs::{CommitmentSchemeProver, PcsConfig}; +use crate::core::pcs::{CommitmentSchemeProver, PcsConfig, TreeSubspan}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; use crate::core::prover::{prove, StarkProof}; use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; use crate::core::{ColumnVec, InteractionElements}; +pub type PlonkComponent = FrameworkComponent; + #[derive(Clone)] -pub struct PlonkComponent { +pub struct PlonkEval { pub log_n_rows: u32, pub lookup_elements: LookupElements<2>, pub claimed_sum: SecureField, + pub base_trace_location: TreeSubspan, + pub interaction_trace_location: TreeSubspan, + pub constants_trace_location: TreeSubspan, } -impl FrameworkComponent for PlonkComponent { +impl FrameworkEval for PlonkEval { fn log_size(&self) -> u32 { self.log_n_rows } @@ -181,7 +188,7 @@ pub fn prove_fibonacci_plonk( let span = span!(Level::INFO, "Trace").entered(); let trace = gen_trace(log_n_rows, &circuit); let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace); + let base_trace_location = tree_builder.extend_evals(trace); tree_builder.commit(channel); span.exit(); @@ -192,14 +199,14 @@ pub fn prove_fibonacci_plonk( let span = span!(Level::INFO, "Interaction").entered(); let (trace, claimed_sum) = gen_interaction_trace(log_n_rows, &circuit, &lookup_elements); let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace); + let interaction_trace_location = tree_builder.extend_evals(trace); tree_builder.commit(channel); span.exit(); // Constant trace. let span = span!(Level::INFO, "Constant").entered(); let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals( + let constants_trace_location = tree_builder.extend_evals( chain!([circuit.a_wire, circuit.b_wire, circuit.c_wire, circuit.op] .into_iter() .map(|col| { @@ -214,11 +221,17 @@ pub fn prove_fibonacci_plonk( span.exit(); // Prove constraints. - let component = PlonkComponent { - log_n_rows, - lookup_elements, - claimed_sum, - }; + let component = PlonkComponent::new( + &mut TraceLocationAllocator::default(), + PlonkEval { + log_n_rows, + lookup_elements, + claimed_sum, + base_trace_location, + interaction_trace_location, + constants_trace_location, + }, + ); // Sanity check. Remove for production. let trace_polys = commitment_scheme diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 63af1d03f..83e17a88c 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -7,7 +7,9 @@ use num_traits::One; use tracing::{span, Level}; use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; -use crate::constraint_framework::{EvalAtRow, FrameworkComponent}; +use crate::constraint_framework::{ + EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator, +}; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; @@ -39,28 +41,27 @@ const EXTERNAL_ROUND_CONSTS: [[BaseField; N_STATE]; 2 * N_HALF_FULL_ROUNDS] = const INTERNAL_ROUND_CONSTS: [BaseField; N_PARTIAL_ROUNDS] = [BaseField::from_u32_unchecked(1234); N_PARTIAL_ROUNDS]; +pub type PoseidonComponent = FrameworkComponent; + pub type PoseidonElements = LookupElements<{ N_STATE * 2 }>; #[derive(Clone)] -pub struct PoseidonComponent { +pub struct PoseidonEval { pub log_n_rows: u32, pub lookup_elements: PoseidonElements, pub claimed_sum: SecureField, } -impl FrameworkComponent for PoseidonComponent { +impl FrameworkEval for PoseidonEval { fn log_size(&self) -> u32 { self.log_n_rows } fn max_constraint_log_degree_bound(&self) -> u32 { self.log_n_rows + LOG_EXPAND } - fn evaluate(&self, eval: E) -> E { - let poseidon_eval = PoseidonEval { - eval, - logup: LogupAtRow::new(1, self.claimed_sum, self.log_n_rows), - lookup_elements: &self.lookup_elements, - }; - poseidon_eval.eval() + fn evaluate(&self, mut eval: E) -> E { + let logup = LogupAtRow::new(1, self.claimed_sum, self.log_n_rows); + eval_poseidon_constraints(&mut eval, logup, &self.lookup_elements); + eval } } @@ -133,67 +134,60 @@ fn pow5(x: F) -> F { x4 * x } -struct PoseidonEval<'a, E: EvalAtRow> { - eval: E, - logup: LogupAtRow<2, E>, - lookup_elements: &'a PoseidonElements, -} - -impl<'a, E: EvalAtRow> PoseidonEval<'a, E> { - fn eval(mut self) -> E { - for _ in 0..N_INSTANCES_PER_ROW { - let mut state: [_; N_STATE] = std::array::from_fn(|_| self.eval.next_trace_mask()); +pub fn eval_poseidon_constraints( + eval: &mut E, + mut logup: LogupAtRow<2, E>, + lookup_elements: &PoseidonElements, +) { + for _ in 0..N_INSTANCES_PER_ROW { + let mut state: [_; N_STATE] = std::array::from_fn(|_| eval.next_trace_mask()); - // Require state lookup. - self.logup - .push_lookup(&mut self.eval, E::EF::one(), &state, self.lookup_elements); + // Require state lookup. + logup.push_lookup(eval, E::EF::one(), &state, lookup_elements); - // 4 full rounds. - (0..N_HALF_FULL_ROUNDS).for_each(|round| { - (0..N_STATE).for_each(|i| { - state[i] += EXTERNAL_ROUND_CONSTS[round][i]; - }); - apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); - state.iter_mut().for_each(|s| { - let m = self.eval.next_trace_mask(); - self.eval.add_constraint(*s - m); - *s = m; - }); + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += EXTERNAL_ROUND_CONSTS[round][i]; }); - - // Partial rounds. - (0..N_PARTIAL_ROUNDS).for_each(|round| { - state[0] += INTERNAL_ROUND_CONSTS[round]; - apply_internal_round_matrix(&mut state); - state[0] = pow5(state[0]); - let m = self.eval.next_trace_mask(); - self.eval.add_constraint(state[0] - m); - state[0] = m; + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter_mut().for_each(|s| { + let m = eval.next_trace_mask(); + eval.add_constraint(*s - m); + *s = m; }); + }); - // 4 full rounds. - (0..N_HALF_FULL_ROUNDS).for_each(|round| { - (0..N_STATE).for_each(|i| { - state[i] += EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i]; - }); - apply_external_round_matrix(&mut state); - state = std::array::from_fn(|i| pow5(state[i])); - state.iter_mut().for_each(|s| { - let m = self.eval.next_trace_mask(); - self.eval.add_constraint(*s - m); - *s = m; - }); - }); + // Partial rounds. + (0..N_PARTIAL_ROUNDS).for_each(|round| { + state[0] += INTERNAL_ROUND_CONSTS[round]; + apply_internal_round_matrix(&mut state); + state[0] = pow5(state[0]); + let m = eval.next_trace_mask(); + eval.add_constraint(state[0] - m); + state[0] = m; + }); - // Provide state lookup. - self.logup - .push_lookup(&mut self.eval, -E::EF::one(), &state, self.lookup_elements); - } + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i]; + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter_mut().for_each(|s| { + let m = eval.next_trace_mask(); + eval.add_constraint(*s - m); + *s = m; + }); + }); - self.logup.finalize(&mut self.eval); - self.eval + // Provide state lookup. + logup.push_lookup(eval, -E::EF::one(), &state, lookup_elements); } + + logup.finalize(eval); } pub struct LookupData { @@ -364,11 +358,14 @@ pub fn prove_poseidon( span.exit(); // Prove constraints. - let component = PoseidonComponent { - log_n_rows, - lookup_elements, - claimed_sum, - }; + let component = PoseidonComponent::new( + &mut TraceLocationAllocator::default(), + PoseidonEval { + log_n_rows, + lookup_elements, + claimed_sum, + }, + ); let proof = prove::( &[&component], channel, @@ -399,8 +396,8 @@ mod tests { use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; use crate::core::InteractionElements; use crate::examples::poseidon::{ - apply_internal_round_matrix, apply_m4, gen_interaction_trace, gen_trace, prove_poseidon, - PoseidonElements, PoseidonEval, + apply_internal_round_matrix, apply_m4, eval_poseidon_constraints, gen_interaction_trace, + gen_trace, prove_poseidon, PoseidonElements, }; use crate::math::matrix::{RowMajorMatrix, SquareMatrix}; @@ -467,13 +464,12 @@ mod tests { let traces = TreeVec::new(vec![trace0, trace1]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); - assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |eval| { - PoseidonEval { - eval, - logup: LogupAtRow::new(1, claimed_sum, LOG_N_ROWS), - lookup_elements: &lookup_elements, - } - .eval(); + assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |mut eval| { + eval_poseidon_constraints( + &mut eval, + LogupAtRow::new(1, claimed_sum, LOG_N_ROWS), + &lookup_elements, + ); }); } diff --git a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs b/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs index 3aa983d09..bd3f9025f 100644 --- a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs +++ b/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs @@ -9,7 +9,7 @@ use super::component::{ }; use super::trace_gen::write_trace_row; use crate::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator}; -use crate::core::air::{AirProver, Component, ComponentProver, ComponentTrace}; +use crate::core::air::{AirProver, Component, ComponentProver, Trace}; use crate::core::backend::CpuBackend; use crate::core::channel::Channel; use crate::core::circle::Coset; @@ -256,7 +256,7 @@ impl WideFibComponent { impl ComponentProver for WideFibComponent { fn evaluate_constraint_quotients_on_domain( &self, - trace: &ComponentTrace<'_, CpuBackend>, + trace: &Trace<'_, CpuBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, interaction_elements: &InteractionElements, lookup_values: &LookupValues, @@ -300,7 +300,7 @@ impl ComponentProver for WideFibComponent { ); } - fn lookup_values(&self, trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { + fn lookup_values(&self, trace: &Trace<'_, CpuBackend>) -> LookupValues { let domain = CanonicCoset::new(self.log_column_size()); let trace_poly = &trace.polys[BASE_TRACE]; let values = BTreeMap::from_iter([ diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index 7976324e6..15afa3294 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -13,7 +13,7 @@ mod tests { use super::component::{Input, WideFibAir, WideFibComponent, LOG_N_COLUMNS}; use super::constraint_eval::gen_trace; use crate::core::air::accumulation::DomainEvaluationAccumulator; - use crate::core::air::{Component, ComponentProver, ComponentTrace}; + use crate::core::air::{Component, ComponentProver, Trace}; use crate::core::backend::cpu::CpuCircleEvaluation; use crate::core::backend::CpuBackend; use crate::core::channel::Blake2sChannel; @@ -183,7 +183,7 @@ mod tests { .iter() .map(|poly| poly.evaluate(eval_domain)) .collect_vec(); - let trace = ComponentTrace { + let trace = Trace { polys: TreeVec::new(vec![ trace_polys.iter().collect_vec(), interaction_poly.iter().collect_vec(), diff --git a/crates/prover/src/examples/wide_fibonacci/simd.rs b/crates/prover/src/examples/wide_fibonacci/simd.rs index 8d12f0cd5..a7ed4dcec 100644 --- a/crates/prover/src/examples/wide_fibonacci/simd.rs +++ b/crates/prover/src/examples/wide_fibonacci/simd.rs @@ -5,7 +5,7 @@ use tracing::{span, Level}; use super::component::LOG_N_COLUMNS; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::mask::fixed_mask_points; -use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace}; +use crate::core::air::{Air, AirProver, Component, ComponentProver, Trace}; use crate::core::backend::simd::column::BaseColumn; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; @@ -193,7 +193,7 @@ impl ComponentTraceGenerator for SimdWideFibComponent { impl ComponentProver for SimdWideFibComponent { fn evaluate_constraint_quotients_on_domain( &self, - trace: &ComponentTrace<'_, SimdBackend>, + trace: &Trace<'_, SimdBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, _interaction_elements: &InteractionElements, _lookup_values: &LookupValues, @@ -250,7 +250,7 @@ impl ComponentProver for SimdWideFibComponent { } } - fn lookup_values(&self, _trace: &ComponentTrace<'_, SimdBackend>) -> LookupValues { + fn lookup_values(&self, _trace: &Trace<'_, SimdBackend>) -> LookupValues { LookupValues::default() } } diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index a06b0c2d1..9bb13a72b 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -1,7 +1,8 @@ #![allow(incomplete_features)] #![feature( - array_methods, array_chunks, + array_methods, + array_try_from_fn, assert_matches, exact_size_is_empty, generic_const_exprs, diff --git a/crates/prover/src/trace_generation/prove.rs b/crates/prover/src/trace_generation/prove.rs index 93d6e4bf5..9328ae293 100644 --- a/crates/prover/src/trace_generation/prove.rs +++ b/crates/prover/src/trace_generation/prove.rs @@ -62,7 +62,7 @@ pub fn commit_and_prove, MC: MerkleChannel>( let components = ComponentProvers(air_prover.component_provers()); channel.mix_felts( &components - .lookup_values(&components.component_traces(&commitment_scheme.trees)) + .lookup_values(&commitment_scheme.trace()) .0 .values() .map(|v| SecureField::from(*v)) @@ -175,7 +175,7 @@ mod tests { use num_traits::Zero; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; - use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace}; + use crate::core::air::{Air, AirProver, Component, ComponentProver, Trace}; use crate::core::backend::cpu::CpuCircleEvaluation; use crate::core::backend::CpuBackend; use crate::core::channel::Channel; @@ -310,7 +310,7 @@ mod tests { impl ComponentProver for TestComponent { fn evaluate_constraint_quotients_on_domain( &self, - _trace: &ComponentTrace<'_, CpuBackend>, + _trace: &Trace<'_, CpuBackend>, _evaluation_accumulator: &mut DomainEvaluationAccumulator, _interaction_elements: &InteractionElements, _lookup_values: &LookupValues, @@ -318,7 +318,7 @@ mod tests { // Does nothing. } - fn lookup_values(&self, _trace: &ComponentTrace<'_, CpuBackend>) -> LookupValues { + fn lookup_values(&self, _trace: &Trace<'_, CpuBackend>) -> LookupValues { LookupValues::default() } }