diff --git a/crates/prover/src/core/air/air_ext.rs b/crates/prover/src/core/air/air_ext.rs index ff43c33d5..577497595 100644 --- a/crates/prover/src/core/air/air_ext.rs +++ b/crates/prover/src/core/air/air_ext.rs @@ -1,17 +1,16 @@ -use std::iter::zip; - -use itertools::Itertools; +use itertools::{zip_eq, Itertools}; use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use super::{Air, AirProver, ComponentTrace}; use crate::core::backend::Backend; use crate::core::circle::CirclePoint; -use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, SecureCirclePoly}; -use crate::core::poly::BitReversedOrder; +use crate::core::pcs::{CommitmentTreeProver, TreeVec}; +use crate::core::poly::circle::{CanonicCoset, SecureCirclePoly}; use crate::core::prover::LOG_BLOWUP_FACTOR; -use crate::core::ComponentVec; +use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; +use crate::core::vcs::ops::MerkleOps; +use crate::core::{ComponentVec, InteractionElements}; pub trait AirExt: Air { fn composition_log_degree_bound(&self) -> u32 { @@ -46,13 +45,15 @@ pub trait AirExt: Air { point: CirclePoint, mask_values: &ComponentVec>, random_coeff: SecureField, + interaction_elements: &InteractionElements, ) -> SecureField { let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff); - zip(self.components(), &mask_values.0).for_each(|(component, mask)| { + zip_eq(self.components(), &mask_values.0).for_each(|(component, mask)| { component.evaluate_constraint_quotients_at_point( point, mask, &mut evaluation_accumulator, + interaction_elements, ) }); evaluation_accumulator.finalize() @@ -65,24 +66,42 @@ pub trait AirExt: Air { .collect() } - fn component_traces<'a, B: Backend>( + fn component_traces<'a, B: Backend + MerkleOps>( &'a self, - polynomials: &'a [CirclePoly], - evals: &'a [CircleEvaluation], + trees: &'a [CommitmentTreeProver], ) -> Vec> { - let poly_iter = &mut polynomials.iter(); - let eval_iter = &mut evals.iter(); - self.components() - .iter() - .map(|component| { - let n_columns = component.trace_log_degree_bounds().len(); - let polys = poly_iter.take(n_columns).collect(); - let evals = eval_iter.take(n_columns).collect(); - ComponentTrace::new(polys, evals) - }) - .collect() + let poly_iter = &mut trees[0].polynomials.iter(); + let eval_iter = &mut trees[0].evaluations.iter(); + let mut component_traces = vec![]; + self.components().iter().for_each(|component| { + let n_columns = component.trace_log_degree_bounds().len(); + let polys = poly_iter.take(n_columns).collect_vec(); + let evals = eval_iter.take(n_columns).collect_vec(); + + component_traces.push(ComponentTrace { + polys: TreeVec::new(vec![polys]), + evals: TreeVec::new(vec![evals]), + }); + }); + + if trees.len() > 1 { + let poly_iter = &mut trees[1].polynomials.iter(); + let eval_iter = &mut trees[1].evaluations.iter(); + self.components() + .iter() + .zip_eq(&mut component_traces) + .for_each(|(_component, component_trace)| { + // TODO(AlonH): Implement n_interaction_columns() for component. + let polys = poly_iter.take(1).collect_vec(); + let evals = eval_iter.take(1).collect_vec(); + component_trace.polys.push(polys); + component_trace.evals.push(evals); + }); + } + component_traces } } + impl AirExt for A {} pub trait AirProverExt: AirProver { @@ -90,6 +109,7 @@ pub trait AirProverExt: AirProver { &self, random_coeff: SecureField, component_traces: &[ComponentTrace<'_, B>], + interaction_elements: &InteractionElements, ) -> SecureCirclePoly { let total_constraints: usize = self .prover_components() @@ -101,8 +121,12 @@ pub trait AirProverExt: AirProver { self.composition_log_degree_bound(), total_constraints, ); - zip(self.prover_components(), component_traces).for_each(|(component, trace)| { - component.evaluate_constraint_quotients_on_domain(trace, &mut accumulator) + zip_eq(self.prover_components(), component_traces).for_each(|(component, trace)| { + component.evaluate_constraint_quotients_on_domain( + trace, + &mut accumulator, + interaction_elements, + ) }); accumulator.finalize() } diff --git a/crates/prover/src/core/air/mod.rs b/crates/prover/src/core/air/mod.rs index e4bafba18..ba43d3ef8 100644 --- a/crates/prover/src/core/air/mod.rs +++ b/crates/prover/src/core/air/mod.rs @@ -4,6 +4,7 @@ use super::channel::Blake2sChannel; use super::circle::CirclePoint; use super::fields::m31::BaseField; use super::fields::qm31::SecureField; +use super::pcs::TreeVec; use super::poly::circle::{CircleEvaluation, CirclePoly}; use super::poly::BitReversedOrder; use super::{ColumnVec, ComponentVec, InteractionElements}; @@ -35,7 +36,7 @@ pub trait AirTraceWriter: AirTraceVerifier { elements: &InteractionElements, ) -> ComponentVec>; - fn to_air_prover(&self) -> &dyn AirProver; + fn to_air_prover(&self) -> &impl AirProver; } pub trait AirProver: Air { @@ -66,6 +67,7 @@ pub trait Component { point: CirclePoint, mask: &ColumnVec>, evaluation_accumulator: &mut PointEvaluationAccumulator, + interaction_elements: &InteractionElements, ); } @@ -84,6 +86,7 @@ pub trait ComponentProver: Component { &self, trace: &ComponentTrace<'_, B>, evaluation_accumulator: &mut DomainEvaluationAccumulator, + interaction_elements: &InteractionElements, ); } @@ -91,16 +94,16 @@ pub trait ComponentProver: Component { /// Each polynomial is stored both in a coefficients, and evaluations form (for efficiency) pub struct ComponentTrace<'a, B: Backend> { /// Polynomials for each column. - pub polys: Vec<&'a CirclePoly>, + pub polys: TreeVec>>, /// Evaluations for each column. The evaluation domain is the commitment domain for that column /// obtained from [AirExt::trace_commitment_domains()]. - pub evals: Vec<&'a CircleEvaluation>, + pub evals: TreeVec>>, } impl<'a, B: Backend> ComponentTrace<'a, B> { pub fn new( - polys: Vec<&'a CirclePoly>, - evals: Vec<&'a CircleEvaluation>, + polys: TreeVec>>, + evals: TreeVec>>, ) -> Self { Self { polys, evals } } diff --git a/crates/prover/src/core/mod.rs b/crates/prover/src/core/mod.rs index 899c584f6..f127f6515 100644 --- a/crates/prover/src/core/mod.rs +++ b/crates/prover/src/core/mod.rs @@ -61,8 +61,19 @@ impl DerefMut for ComponentVec { } } +#[derive(Default)] pub struct InteractionElements(BTreeMap); +impl InteractionElements { + pub fn new(elements: BTreeMap) -> Self { + Self(elements) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + impl Index<&str> for InteractionElements { type Output = BaseField; diff --git a/crates/prover/src/core/pcs/mod.rs b/crates/prover/src/core/pcs/mod.rs index 6dfcc11e6..def968218 100644 --- a/crates/prover/src/core/pcs/mod.rs +++ b/crates/prover/src/core/pcs/mod.rs @@ -11,6 +11,6 @@ pub mod quotients; mod utils; mod verifier; -pub use self::prover::{CommitmentSchemeProof, CommitmentSchemeProver}; +pub use self::prover::{CommitmentSchemeProof, CommitmentSchemeProver, CommitmentTreeProver}; pub use self::utils::TreeVec; pub use self::verifier::CommitmentSchemeVerifier; diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index 6d5fffcba..078a16810 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -2,14 +2,14 @@ use itertools::Itertools; use thiserror::Error; use tracing::{span, Level}; -use super::air::AirProver; +use super::air::{AirProver, AirTraceVerifier, AirTraceWriter}; use super::backend::Backend; use super::fri::FriVerificationError; use super::pcs::{CommitmentSchemeProof, TreeVec}; use super::poly::circle::{CanonicCoset, SecureCirclePoly, MAX_CIRCLE_DOMAIN_LOG_SIZE}; use super::poly::twiddles::TwiddleTree; use super::proof_of_work::ProofOfWorkVerificationError; -use super::ColumnVec; +use super::{ColumnVec, InteractionElements}; use crate::core::air::{Air, AirExt, AirProverExt}; use crate::core::backend::CpuBackend; use crate::core::channel::{Blake2sChannel, Channel as ChannelTrait}; @@ -50,12 +50,15 @@ pub struct AdditionalProofData { } pub fn evaluate_and_commit_on_trace>( + air: &impl AirTraceWriter, channel: &mut Channel, twiddles: &TwiddleTree, trace: ColumnVec>, -) -> Result, ProvingError> { +) -> Result<(CommitmentSchemeProver, InteractionElements), ProvingError> { let span = span!(Level::INFO, "Trace interpolation").entered(); + // TODO(AlonH): Remove clone. let trace_polys = trace + .clone() .into_iter() .map(|poly| poly.interpolate_with_twiddles(twiddles)) .collect(); @@ -66,12 +69,29 @@ pub fn evaluate_and_commit_on_trace>( commitment_scheme.commit(trace_polys, channel, twiddles); span.exit(); - Ok(commitment_scheme) + let interaction_elements = air.interaction_elements(channel); + let interaction_traces = air.interact(&trace, &interaction_elements); + let interaction_trace_polys = interaction_traces + .0 + .into_iter() + .flat_map(|trace| { + trace + .into_iter() + .map(|poly| poly.interpolate_with_twiddles(twiddles)) + }) + .collect_vec(); + let n_interaction_traces = interaction_trace_polys.len(); + if n_interaction_traces > 0 { + commitment_scheme.commit(interaction_trace_polys, channel, twiddles); + } + + Ok((commitment_scheme, interaction_elements)) } pub fn generate_proof>( air: &impl AirProver, channel: &mut Channel, + interaction_elements: &InteractionElements, twiddles: &TwiddleTree, commitment_scheme: &mut CommitmentSchemeProver, ) -> Result { @@ -81,10 +101,8 @@ pub fn generate_proof>( let span = span!(Level::INFO, "Composition generation").entered(); let composition_polynomial_poly = air.compute_composition_polynomial( random_coeff, - &air.component_traces( - &commitment_scheme.trees[0].polynomials, - &commitment_scheme.trees[0].evaluations, - ), + &air.component_traces(&commitment_scheme.trees), + interaction_elements, ); span.exit(); @@ -101,7 +119,14 @@ pub fn generate_proof>( // TODO(spapini): Change when we support multiple interactions. // First tree - trace. let mut sample_points = TreeVec::new(vec![sample_points.flatten()]); - // Second tree - composition polynomial. + if commitment_scheme.trees.len() > 2 { + // Second tree - interaction trace. + sample_points.push(vec![ + vec![oods_point]; + commitment_scheme.trees[1].polynomials.len() + ]); + } + // Final tree - composition polynomial. sample_points.push(vec![vec![oods_point]; 4]); // Prove the trace and composition OODS values, and retrieve them. @@ -111,10 +136,15 @@ pub fn generate_proof>( // values. This is a sanity check. // TODO(spapini): Save clone. let (trace_oods_values, composition_oods_value) = - sampled_values_to_mask(air, commitment_scheme_proof.sampled_values.clone()).unwrap(); + sampled_values_to_mask(air, &commitment_scheme_proof.sampled_values).unwrap(); if composition_oods_value - != air.eval_composition_polynomial_at_point(oods_point, &trace_oods_values, random_coeff) + != air.eval_composition_polynomial_at_point( + oods_point, + &trace_oods_values, + random_coeff, + interaction_elements, + ) { return Err(ProvingError::ConstraintsNotSatisfied); } @@ -126,7 +156,7 @@ pub fn generate_proof>( } pub fn prove>( - air: &impl AirProver, + air: &impl AirTraceWriter, channel: &mut Channel, trace: ColumnVec>, ) -> Result { @@ -141,7 +171,9 @@ pub fn prove>( } // Check that the composition polynomial is not too big. - let composition_polynomial_log_degree_bound = air.composition_log_degree_bound(); + // TODO(AlonH): Get traces log degree bounds from trace writer. + let composition_polynomial_log_degree_bound = + air.to_air_prover().composition_log_degree_bound(); if composition_polynomial_log_degree_bound + LOG_BLOWUP_FACTOR > MAX_CIRCLE_DOMAIN_LOG_SIZE { return Err(ProvingError::MaxCompositionDegreeExceeded { degree: composition_polynomial_log_degree_bound, @@ -150,30 +182,47 @@ pub fn prove>( let span = span!(Level::INFO, "Precompute twiddle").entered(); let twiddles = B::precompute_twiddles( - CanonicCoset::new(air.composition_log_degree_bound() + LOG_BLOWUP_FACTOR) + CanonicCoset::new(composition_polynomial_log_degree_bound + LOG_BLOWUP_FACTOR) .circle_domain() .half_coset, ); span.exit(); - let mut commitment_scheme = evaluate_and_commit_on_trace(channel, &twiddles, trace)?; + let (mut commitment_scheme, interaction_elements) = + evaluate_and_commit_on_trace(air, channel, &twiddles, trace)?; - generate_proof(air, channel, &twiddles, &mut commitment_scheme) + generate_proof( + air.to_air_prover(), + channel, + &interaction_elements, + &twiddles, + &mut commitment_scheme, + ) } pub fn verify( proof: StarkProof, - air: &impl Air, + air: &(impl Air + AirTraceVerifier), channel: &mut Channel, ) -> Result<(), VerificationError> { // Read trace commitment. let mut commitment_scheme = CommitmentSchemeVerifier::new(); commitment_scheme.commit(proof.commitments[0], air.column_log_sizes(), channel); + let interaction_elements = air.interaction_elements(channel); + + if proof.commitments.len() > 2 { + commitment_scheme.commit( + proof.commitments[1], + air.column_log_sizes()[..1].to_vec(), + channel, + ); + } + let random_coeff = channel.draw_felt(); // Read composition polynomial commitment. commitment_scheme.commit( - proof.commitments[1], + *proof.commitments.last().unwrap(), vec![air.composition_log_degree_bound(); 4], channel, ); @@ -187,20 +236,30 @@ pub fn verify( // TODO(spapini): Change when we support multiple interactions. // First tree - trace. let mut sample_points = TreeVec::new(vec![trace_sample_points.flatten()]); - // Second tree - composition polynomial. + if proof.commitments.len() > 2 { + // Second tree - interaction trace. + // TODO(AlonH): Get the number of interaction traces from the air. + sample_points.push(vec![vec![oods_point]; 1]); + } + // Final tree - composition polynomial. sample_points.push(vec![vec![oods_point]; 4]); // TODO(spapini): Save clone. let (trace_oods_values, composition_oods_value) = sampled_values_to_mask( air, - proof.commitment_scheme_proof.sampled_values.clone(), + &proof.commitment_scheme_proof.sampled_values, ) .map_err(|_| { VerificationError::InvalidStructure("Unexpected sampled_values structure".to_string()) })?; if composition_oods_value - != air.eval_composition_polynomial_at_point(oods_point, &trace_oods_values, random_coeff) + != air.eval_composition_polynomial_at_point( + oods_point, + &trace_oods_values, + random_coeff, + &interaction_elements, + ) { return Err(VerificationError::OodsNotMatching); } @@ -212,10 +271,40 @@ pub fn verify( /// polynomial OODS value. fn sampled_values_to_mask( air: &impl Air, - mut sampled_values: TreeVec>>, + sampled_values: &TreeVec>>, ) -> Result<(ComponentVec>, SecureField), InvalidOodsSampleStructure> { + // Retrieve sampled mask values for each component. + let flat_trace_values = &mut sampled_values + .first() + .ok_or(InvalidOodsSampleStructure)? + .iter(); + let mut trace_oods_values = vec![]; + air.components().iter().for_each(|component| { + trace_oods_values.push( + flat_trace_values + .take(component.mask_points(CirclePoint::zero()).len()) + .cloned() + .collect_vec(), + ) + }); + + if sampled_values.len() > 2 { + let interaction_values = &mut sampled_values + .get(1) + .ok_or(InvalidOodsSampleStructure)? + .iter(); + + air.components() + .iter() + .zip_eq(&mut trace_oods_values) + .for_each(|(_component, values)| { + // TODO(AlonH): Implement n_interaction_columns() for component. + values.extend(interaction_values.take(1).cloned().collect_vec()) + }); + } + let composition_partial_sampled_values = - sampled_values.pop().ok_or(InvalidOodsSampleStructure)?; + sampled_values.last().ok_or(InvalidOodsSampleStructure)?; let composition_oods_value = SecureCirclePoly::::eval_from_partial_evals( composition_partial_sampled_values .iter() @@ -226,23 +315,7 @@ fn sampled_values_to_mask( .map_err(|_| InvalidOodsSampleStructure)?, ); - // Retrieve sampled mask values for each component. - let flat_trace_values = &mut sampled_values - .pop() - .ok_or(InvalidOodsSampleStructure)? - .into_iter(); - let trace_oods_values = ComponentVec( - air.components() - .iter() - .map(|c| { - flat_trace_values - .take(c.mask_points(CirclePoint::zero()).len()) - .collect_vec() - }) - .collect(), - ); - - Ok((trace_oods_values, composition_oods_value)) + Ok((ComponentVec(trace_oods_values), composition_oods_value)) } /// Error when the sampled values have an invalid structure. @@ -288,10 +361,12 @@ mod tests { use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::{ - Air, AirProver, Component, ComponentProver, ComponentTrace, ComponentTraceWriter, + Air, AirProver, AirTraceVerifier, AirTraceWriter, Component, ComponentProver, + ComponentTrace, ComponentTraceWriter, }; use crate::core::backend::cpu::CpuCircleEvaluation; use crate::core::backend::CpuBackend; + use crate::core::channel::Blake2sChannel; use crate::core::circle::{CirclePoint, CirclePointIndex, Coset}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; @@ -301,7 +376,7 @@ mod tests { use crate::core::poly::BitReversedOrder; use crate::core::prover::{prove, ProvingError}; use crate::core::test_utils::test_channel; - use crate::core::{ColumnVec, InteractionElements}; + use crate::core::{ColumnVec, ComponentVec, InteractionElements}; use crate::qm31; struct TestAir> { @@ -314,6 +389,26 @@ mod tests { } } + impl AirTraceVerifier for TestAir { + fn interaction_elements(&self, _channel: &mut Blake2sChannel) -> InteractionElements { + InteractionElements::default() + } + } + + impl AirTraceWriter for TestAir { + fn interact( + &self, + _trace: &ColumnVec>, + _elements: &InteractionElements, + ) -> ComponentVec> { + ComponentVec(vec![vec![]]) + } + + fn to_air_prover(&self) -> &impl AirProver { + self + } + } + impl AirProver for TestAir { fn prover_components(&self) -> Vec<&dyn ComponentProver> { vec![&self.component] @@ -354,6 +449,7 @@ mod tests { _point: CirclePoint, _mask: &crate::core::ColumnVec>, evaluation_accumulator: &mut PointEvaluationAccumulator, + _interaction_elements: &InteractionElements, ) { evaluation_accumulator.accumulate(qm31!(0, 0, 0, 1)) } @@ -374,6 +470,7 @@ mod tests { &self, _trace: &ComponentTrace<'_, CpuBackend>, _evaluation_accumulator: &mut DomainEvaluationAccumulator, + _interaction_elements: &InteractionElements, ) { // Does nothing. } diff --git a/crates/prover/src/examples/fibonacci/air.rs b/crates/prover/src/examples/fibonacci/air.rs index c76ba1f79..dbec91ecf 100644 --- a/crates/prover/src/examples/fibonacci/air.rs +++ b/crates/prover/src/examples/fibonacci/air.rs @@ -1,9 +1,15 @@ use itertools::{zip_eq, Itertools}; use super::component::FibonacciComponent; -use crate::core::air::{Air, AirProver, Component, ComponentProver}; +use crate::core::air::{ + Air, AirProver, AirTraceVerifier, AirTraceWriter, Component, ComponentProver, +}; use crate::core::backend::CpuBackend; +use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; +use crate::core::{ColumnVec, ComponentVec, InteractionElements}; pub struct FibonacciAir { pub component: FibonacciComponent, @@ -14,11 +20,33 @@ impl FibonacciAir { Self { component } } } + impl Air for FibonacciAir { fn components(&self) -> Vec<&dyn Component> { vec![&self.component] } } + +impl AirTraceVerifier for FibonacciAir { + fn interaction_elements(&self, _channel: &mut Blake2sChannel) -> InteractionElements { + InteractionElements::default() + } +} + +impl AirTraceWriter for FibonacciAir { + fn interact( + &self, + _trace: &ColumnVec>, + _elements: &InteractionElements, + ) -> ComponentVec> { + ComponentVec(vec![vec![]]) + } + + fn to_air_prover(&self) -> &impl AirProver { + self + } +} + impl AirProver for FibonacciAir { fn prover_components(&self) -> Vec<&dyn ComponentProver> { vec![&self.component] @@ -28,6 +56,7 @@ impl AirProver for FibonacciAir { pub struct MultiFibonacciAir { pub components: Vec, } + impl MultiFibonacciAir { pub fn new(log_sizes: &[u32], claim: &[BaseField]) -> Self { let mut components = Vec::new(); @@ -37,6 +66,7 @@ impl MultiFibonacciAir { Self { components } } } + impl Air for MultiFibonacciAir { fn components(&self) -> Vec<&dyn Component> { self.components @@ -45,6 +75,27 @@ impl Air for MultiFibonacciAir { .collect_vec() } } + +impl AirTraceVerifier for MultiFibonacciAir { + fn interaction_elements(&self, _channel: &mut Blake2sChannel) -> InteractionElements { + InteractionElements::default() + } +} + +impl AirTraceWriter for MultiFibonacciAir { + fn interact( + &self, + _trace: &ColumnVec>, + _elements: &InteractionElements, + ) -> ComponentVec> { + ComponentVec(vec![vec![]]) + } + + fn to_air_prover(&self) -> &impl AirProver { + self + } +} + impl AirProver for MultiFibonacciAir { fn prover_components(&self) -> Vec<&dyn ComponentProver> { self.components diff --git a/crates/prover/src/examples/fibonacci/component.rs b/crates/prover/src/examples/fibonacci/component.rs index b61bcfc4e..243669d99 100644 --- a/crates/prover/src/examples/fibonacci/component.rs +++ b/crates/prover/src/examples/fibonacci/component.rs @@ -104,6 +104,7 @@ impl Component for FibonacciComponent { point: CirclePoint, mask: &ColumnVec>, evaluation_accumulator: &mut PointEvaluationAccumulator, + _interaction_elements: &InteractionElements, ) { evaluation_accumulator.accumulate( self.step_constraint_eval_quotient_by_mask(point, &mask[0][..].try_into().unwrap()), @@ -132,8 +133,9 @@ impl ComponentProver for FibonacciComponent { &self, trace: &ComponentTrace<'_, CpuBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, + _interaction_elements: &InteractionElements, ) { - let poly = &trace.polys[0]; + let poly = &trace.polys[0][0]; let trace_domain = CanonicCoset::new(self.log_size); let trace_eval_domain = CanonicCoset::new(self.log_size + 1).circle_domain(); let trace_eval = poly.evaluate(trace_eval_domain).bit_reverse(); diff --git a/crates/prover/src/examples/fibonacci/mod.rs b/crates/prover/src/examples/fibonacci/mod.rs index 64f97a39b..940f979b2 100644 --- a/crates/prover/src/examples/fibonacci/mod.rs +++ b/crates/prover/src/examples/fibonacci/mod.rs @@ -110,6 +110,7 @@ impl MultiFibonacci { #[cfg(test)] mod tests { use std::assert_matches::assert_matches; + use std::collections::BTreeMap; use std::iter::zip; use itertools::Itertools; @@ -123,10 +124,12 @@ mod tests { use crate::core::circle::CirclePoint; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; + use crate::core::pcs::TreeVec; use crate::core::poly::circle::CanonicCoset; use crate::core::prover::VerificationError; use crate::core::queries::Queries; use crate::core::utils::bit_reverse; + use crate::core::InteractionElements; use crate::{m31, qm31}; pub fn generate_test_queries(n_queries: usize, trace_length: usize) -> Vec { @@ -146,20 +149,25 @@ mod tests { let trace_poly = trace.interpolate(); let trace_eval = trace_poly.evaluate(CanonicCoset::new(trace_poly.log_size() + 1).circle_domain()); - let trace = ComponentTrace::new(vec![&trace_poly], vec![&trace_eval]); + let trace = ComponentTrace::new( + TreeVec::new(vec![vec![&trace_poly]]), + TreeVec::new(vec![vec![&trace_eval]]), + ); let random_coeff = qm31!(2213980, 2213981, 2213982, 2213983); let component_traces = vec![trace]; - let composition_polynomial_poly = fib - .air - .compute_composition_polynomial(random_coeff, &component_traces); + let composition_polynomial_poly = fib.air.compute_composition_polynomial( + random_coeff, + &component_traces, + &InteractionElements::new(BTreeMap::new()), + ); // Evaluate this polynomial at another point out of the evaluation domain and compare to // what we expect. let point = CirclePoint::::get_point(98989892); let points = fib.air.mask_points(point); - let mask_values = zip(&component_traces[0].polys, &points[0]) + let mask_values = zip(&component_traces[0].polys[0], &points[0]) .map(|(poly, points)| { points .iter() @@ -173,8 +181,10 @@ mod tests { point, &mask_values, &mut evaluation_accumulator, + &InteractionElements::new(BTreeMap::new()), ); let oods_value = evaluation_accumulator.finalize(); + assert_eq!(oods_value, composition_polynomial_poly.eval_at_point(point)); } diff --git a/crates/prover/src/examples/wide_fibonacci/component.rs b/crates/prover/src/examples/wide_fibonacci/component.rs index 2c19edbee..b8d6ed19b 100644 --- a/crates/prover/src/examples/wide_fibonacci/component.rs +++ b/crates/prover/src/examples/wide_fibonacci/component.rs @@ -84,6 +84,7 @@ impl Component for WideFibComponent { point: CirclePoint, mask: &ColumnVec>, evaluation_accumulator: &mut PointEvaluationAccumulator, + _interaction_elements: &InteractionElements, ) { let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset; let denom = coset_vanishing(constraint_zero_domain, point); diff --git a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs b/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs index 8b488a2ea..18fc41b71 100644 --- a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs +++ b/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs @@ -1,22 +1,53 @@ -use itertools::Itertools; +use std::collections::BTreeMap; + +use itertools::{zip_eq, Itertools}; use num_traits::Zero; use super::component::{Input, WideFibAir, WideFibComponent}; use super::trace_gen::write_trace_row; use crate::core::air::accumulation::DomainEvaluationAccumulator; -use crate::core::air::{AirProver, Component, ComponentProver, ComponentTrace}; -use crate::core::backend::{Column, CpuBackend}; +use crate::core::air::{ + AirProver, AirTraceVerifier, AirTraceWriter, Component, ComponentProver, ComponentTrace, + ComponentTraceWriter, +}; +use crate::core::backend::CpuBackend; +use crate::core::channel::{Blake2sChannel, Channel}; use crate::core::constraints::coset_vanishing; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; -use crate::core::poly::circle::CanonicCoset; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; use crate::core::utils::bit_reverse; -use crate::core::ColumnVec; +use crate::core::{ColumnVec, ComponentVec, InteractionElements}; use crate::examples::wide_fibonacci::component::LOG_N_COLUMNS; // TODO(AlonH): Rename file to `cpu.rs`. +impl AirTraceVerifier for WideFibAir { + fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements { + let ids = self.component.interaction_element_ids(); + let elements = channel.draw_felts(ids.len()).into_iter().map(|e| e.0 .0); + InteractionElements::new(BTreeMap::from_iter(zip_eq(ids, elements))) + } +} + +impl AirTraceWriter for WideFibAir { + fn interact( + &self, + trace: &ColumnVec>, + elements: &InteractionElements, + ) -> ComponentVec> { + ComponentVec(vec![self + .component + .write_interaction_trace(&trace.iter().collect(), elements)]) + } + + fn to_air_prover(&self) -> &impl AirProver { + self + } +} + impl AirProver for WideFibAir { fn prover_components(&self) -> Vec<&dyn ComponentProver> { vec![&self.component] @@ -28,6 +59,7 @@ impl ComponentProver for WideFibComponent { &self, trace: &ComponentTrace<'_, CpuBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, + _interaction_elements: &InteractionElements, ) { let max_constraint_degree = self.max_constraint_log_degree_bound(); let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain(); @@ -49,9 +81,8 @@ impl ComponentProver for WideFibComponent { // Step constraints. for j in 0..self.n_columns() - 2 { numerators[i] += accum.random_coeff_powers[self.n_columns() - 3 - j] - * (trace_evals[j].values.at(i).square() - + trace_evals[j + 1].values.at(i).square() - - trace_evals[j + 2].values.at(i)); + * (trace_evals[0][j][i].square() + trace_evals[0][j + 1][i].square() + - trace_evals[0][j + 2][i]); } } for (i, (num, denom)) in numerators.iter().zip(denom_inverses.iter()).enumerate() { diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index 59486b9ac..809a35616 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -5,27 +5,30 @@ pub mod trace_gen; #[cfg(test)] mod tests { + use std::collections::BTreeMap; + use itertools::Itertools; use num_traits::{One, Zero}; use super::component::{Input, WideFibAir, WideFibComponent, LOG_N_COLUMNS}; use super::constraint_eval::gen_trace; use crate::core::air::accumulation::DomainEvaluationAccumulator; - use crate::core::air::{Component, ComponentProver, ComponentTrace}; + use crate::core::air::{Component, ComponentProver, ComponentTrace, ComponentTraceWriter}; use crate::core::backend::cpu::CpuCircleEvaluation; use crate::core::backend::CpuBackend; use crate::core::channel::{Blake2sChannel, Channel}; use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::QM31; use crate::core::fields::IntoSlice; + use crate::core::pcs::TreeVec; use crate::core::poly::circle::CanonicCoset; use crate::core::poly::BitReversedOrder; use crate::core::prover::{prove, verify}; use crate::core::utils::shifted_secure_combination; use crate::core::vcs::blake2_hash::Blake2sHasher; use crate::core::vcs::hasher::Hasher; + use crate::core::InteractionElements; use crate::examples::wide_fibonacci::trace_gen::write_lookup_column; - use crate::m31; + use crate::{m31, qm31}; pub fn assert_constraints_on_row(row: &[BaseField]) { for i in 2..row.len() { @@ -119,11 +122,12 @@ mod tests { #[test] fn test_composition_is_low_degree() { let wide_fib = WideFibComponent { - log_fibonacci_size: LOG_N_COLUMNS as u32, - log_n_instances: 7, + log_fibonacci_size: 3 + LOG_N_COLUMNS as u32, + log_n_instances: 0, }; + let random_coeff = qm31!(1, 2, 3, 4); let mut acc = DomainEvaluationAccumulator::new( - QM31::from_u32_unchecked(1, 2, 3, 4), + random_coeff, wide_fib.max_constraint_log_degree_bound(), wide_fib.n_constraints(), ); @@ -136,33 +140,59 @@ mod tests { let trace = gen_trace(&wide_fib, inputs); - let trace_domain = CanonicCoset::new(wide_fib.log_column_size()); + let trace_domain = CanonicCoset::new(wide_fib.log_column_size()).circle_domain(); let trace = trace .into_iter() - .map(|col| CpuCircleEvaluation::new_canonical_ordered(trace_domain, col)) + .map(|eval| CpuCircleEvaluation::<_, BitReversedOrder>::new(trace_domain, eval)) .collect_vec(); let trace_polys = trace + .clone() .into_iter() .map(|eval| eval.interpolate()) .collect_vec(); - let eval_domain = CanonicCoset::new(wide_fib.log_column_size() + 1).circle_domain(); + let eval_domain = + CanonicCoset::new(wide_fib.max_constraint_log_degree_bound()).circle_domain(); let trace_evals = trace_polys .iter() .map(|poly| poly.evaluate(eval_domain)) .collect_vec(); + let interaction_elements = InteractionElements::new(BTreeMap::from_iter( + wide_fib + .interaction_element_ids() + .iter() + .cloned() + .enumerate() + .map(|(i, id)| (id, m31!(43 + i as u32))), + )); + let interaction_trace = + wide_fib.write_interaction_trace(&trace.iter().collect(), &interaction_elements); + + let interaction_poly = interaction_trace + .iter() + .map(|trace| trace.clone().interpolate()) + .collect_vec(); + let interaction_trace = interaction_poly + .iter() + .map(|poly| poly.evaluate(eval_domain)) + .collect_vec(); let trace = ComponentTrace { - polys: trace_polys.iter().collect(), - evals: trace_evals.iter().collect(), + polys: TreeVec::new(vec![ + trace_polys.iter().collect_vec(), + interaction_poly.iter().collect_vec(), + ]), + evals: TreeVec::new(vec![ + trace_evals.iter().collect_vec(), + interaction_trace.iter().collect_vec(), + ]), }; - wide_fib.evaluate_constraint_quotients_on_domain(&trace, &mut acc); + wide_fib.evaluate_constraint_quotients_on_domain(&trace, &mut acc, &interaction_elements); let res = acc.finalize(); let poly = res.0[0].clone(); - for coeff in - poly.coeffs[(1 << (wide_fib.max_constraint_log_degree_bound() - 1)) + 1..].iter() - { + + for coeff in poly.coeffs[1 << wide_fib.max_constraint_log_degree_bound()..].iter() { assert_eq!(*coeff, BaseField::zero()); } } diff --git a/crates/prover/src/examples/wide_fibonacci/simd.rs b/crates/prover/src/examples/wide_fibonacci/simd.rs index 246e37372..114a7fee2 100644 --- a/crates/prover/src/examples/wide_fibonacci/simd.rs +++ b/crates/prover/src/examples/wide_fibonacci/simd.rs @@ -5,7 +5,7 @@ use tracing::{span, Level}; use super::component::{WideFibAir, WideFibComponent}; use crate::core::air::accumulation::DomainEvaluationAccumulator; use crate::core::air::{ - AirProver, Component, ComponentProver, ComponentTrace, ComponentTraceWriter, + AirProver, AirTraceWriter, Component, ComponentProver, ComponentTrace, ComponentTraceWriter, }; use crate::core::backend::simd::column::BaseFieldVec; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; @@ -17,9 +17,23 @@ use crate::core::fields::m31::BaseField; use crate::core::fields::{FieldExpOps, FieldOps}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; use crate::core::poly::BitReversedOrder; -use crate::core::{ColumnVec, InteractionElements}; +use crate::core::{ColumnVec, ComponentVec, InteractionElements}; use crate::examples::wide_fibonacci::component::N_COLUMNS; +impl AirTraceWriter for WideFibAir { + fn interact( + &self, + _trace: &ColumnVec>, + _elements: &InteractionElements, + ) -> ComponentVec> { + ComponentVec(vec![vec![]]) + } + + fn to_air_prover(&self) -> &impl AirProver { + self + } +} + impl AirProver for WideFibAir { fn prover_components(&self) -> Vec<&dyn ComponentProver> { vec![&self.component] @@ -67,8 +81,9 @@ impl ComponentProver for WideFibComponent { &self, trace: &ComponentTrace<'_, SimdBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, + _interaction_elements: &InteractionElements, ) { - assert_eq!(trace.polys.len(), self.n_columns()); + assert_eq!(trace.polys[0].len(), self.n_columns()); // TODO(spapini): Steal evaluation from commitment. let eval_domain = CanonicCoset::new(self.log_column_size() + 1).circle_domain(); let trace_eval = &trace.evals; @@ -93,14 +108,17 @@ impl ComponentProver for WideFibComponent { for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES)) { // Numerator. - let a = trace_eval[0].data[vec_row]; + let a = trace_eval[0][0].data[vec_row]; let mut row_res = PackedSecureField::zero(); let mut a_sq = a.square(); - let mut b_sq = trace_eval[1].data[vec_row].square(); + let mut b_sq = trace_eval[0][1].data[vec_row].square(); #[allow(clippy::needless_range_loop)] for i in 0..(self.n_columns() - 2) { unsafe { - let c = *trace_eval.get_unchecked(i + 2).data.get_unchecked(vec_row); + let c = *trace_eval[0] + .get_unchecked(i + 2) + .data + .get_unchecked(vec_row); row_res += PackedSecureField::broadcast( accum.random_coeff_powers[self.n_columns() - 3 - i], ) * (a_sq + b_sq - c);