From bdbb4a398a9c7e1cab55817bbdf7f5fde4dd0020 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Fri, 23 Aug 2024 22:20:20 -0400 Subject: [PATCH] Remove InteractionElements and LookupValues --- .../src/constraint_framework/component.rs | 10 +- .../prover/src/constraint_framework/logup.rs | 5 +- crates/prover/src/core/air/components.rs | 23 +- crates/prover/src/core/air/mod.rs | 9 +- .../prover/src/core/backend/simd/quotients.rs | 9 +- crates/prover/src/core/mod.rs | 52 +-- crates/prover/src/core/pcs/prover.rs | 20 +- .../src/core/poly/circle/secure_poly.rs | 4 + crates/prover/src/core/prover/mod.rs | 36 +- crates/prover/src/examples/blake/air.rs | 10 +- crates/prover/src/examples/blake/round/gen.rs | 4 +- .../src/examples/blake/scheduler/gen.rs | 4 +- crates/prover/src/examples/mod.rs | 1 - crates/prover/src/examples/plonk/mod.rs | 43 +- crates/prover/src/examples/poseidon/mod.rs | 24 +- .../src/examples/wide_fibonacci/component.rs | 297 ------------- .../wide_fibonacci/constraint_eval.rs | 369 ---------------- .../prover/src/examples/wide_fibonacci/mod.rs | 289 ------------ .../src/examples/wide_fibonacci/simd.rs | 293 ------------ .../src/examples/wide_fibonacci/trace_gen.rs | 79 ---- crates/prover/src/lib.rs | 1 - crates/prover/src/trace_generation/mod.rs | 75 ---- crates/prover/src/trace_generation/prove.rs | 417 ------------------ .../prover/src/trace_generation/registry.rs | 178 -------- 24 files changed, 47 insertions(+), 2205 deletions(-) delete mode 100644 crates/prover/src/examples/wide_fibonacci/component.rs delete mode 100644 crates/prover/src/examples/wide_fibonacci/constraint_eval.rs delete mode 100644 crates/prover/src/examples/wide_fibonacci/mod.rs delete mode 100644 crates/prover/src/examples/wide_fibonacci/simd.rs delete mode 100644 crates/prover/src/examples/wide_fibonacci/trace_gen.rs delete mode 100644 crates/prover/src/trace_generation/mod.rs delete mode 100644 crates/prover/src/trace_generation/prove.rs delete mode 100644 crates/prover/src/trace_generation/registry.rs diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 2b92a3648..c0d8319fa 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -20,7 +20,7 @@ use crate::core::fields::FieldExpOps; use crate::core::pcs::{TreeSubspan, TreeVec}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; -use crate::core::{utils, ColumnVec, InteractionElements, LookupValues}; +use crate::core::{utils, ColumnVec}; // TODO(andrew): Docs. // TODO(andrew): Consider better location for this. @@ -121,8 +121,6 @@ impl Component for FrameworkComponent { point: CirclePoint, mask: &TreeVec>>, evaluation_accumulator: &mut PointEvaluationAccumulator, - _interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, ) { self.eval.evaluate(PointEvaluator::new( mask.sub_tree(&self.trace_locations), @@ -137,8 +135,6 @@ impl ComponentProver for FrameworkComponent { &self, trace: &Trace<'_, SimdBackend>, evaluation_accumulator: &mut DomainEvaluationAccumulator, - _interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, ) { let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); let trace_domain = CanonicCoset::new(self.eval.log_size()); @@ -203,10 +199,6 @@ impl ComponentProver for FrameworkComponent { } } } - - fn lookup_values(&self, _trace: &Trace<'_, SimdBackend>) -> LookupValues { - LookupValues::default() - } } impl Deref for FrameworkComponent { diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index 52b87f951..44a461e56 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -227,10 +227,7 @@ impl LogupTraceGenerator { .into_iter() .flat_map(|eval| { eval.columns.map(|c| { - CircleEvaluation::::new( - CanonicCoset::new(self.log_size).circle_domain(), - c, - ) + CircleEvaluation::new(CanonicCoset::new(self.log_size).circle_domain(), c) }) }) .collect_vec(); diff --git a/crates/prover/src/core/air/components.rs b/crates/prover/src/core/air/components.rs index 397468616..320ec35cb 100644 --- a/crates/prover/src/core/air/components.rs +++ b/crates/prover/src/core/air/components.rs @@ -7,7 +7,7 @@ use crate::core::circle::CirclePoint; use crate::core::fields::qm31::SecureField; use crate::core::pcs::TreeVec; use crate::core::poly::circle::SecureCirclePoly; -use crate::core::{ColumnVec, InteractionElements, LookupValues}; +use crate::core::ColumnVec; pub struct Components<'a>(pub Vec<&'a dyn Component>); @@ -32,8 +32,6 @@ impl<'a> Components<'a> { point: CirclePoint, mask_values: &TreeVec>>, random_coeff: SecureField, - interaction_elements: &InteractionElements, - lookup_values: &LookupValues, ) -> SecureField { let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff); for component in &self.0 { @@ -41,8 +39,6 @@ impl<'a> Components<'a> { point, mask_values, &mut evaluation_accumulator, - interaction_elements, - lookup_values, ) } evaluation_accumulator.finalize() @@ -67,8 +63,6 @@ impl<'a, B: Backend> ComponentProvers<'a, B> { &self, random_coeff: SecureField, trace: &Trace<'_, B>, - interaction_elements: &InteractionElements, - lookup_values: &LookupValues, ) -> SecureCirclePoly { let total_constraints: usize = self.0.iter().map(|c| c.n_constraints()).sum(); let mut accumulator = DomainEvaluationAccumulator::new( @@ -77,21 +71,8 @@ impl<'a, B: Backend> ComponentProvers<'a, B> { total_constraints, ); for component in &self.0 { - component.evaluate_constraint_quotients_on_domain( - trace, - &mut accumulator, - interaction_elements, - lookup_values, - ) + component.evaluate_constraint_quotients_on_domain(trace, &mut accumulator) } accumulator.finalize() } - - pub fn lookup_values(&self, trace: &Trace<'_, B>) -> LookupValues { - let mut values = LookupValues::default(); - for component in &self.0 { - values.extend(component.lookup_values(trace)) - } - values - } } diff --git a/crates/prover/src/core/air/mod.rs b/crates/prover/src/core/air/mod.rs index 39834296a..fcdd4d5f8 100644 --- a/crates/prover/src/core/air/mod.rs +++ b/crates/prover/src/core/air/mod.rs @@ -8,7 +8,7 @@ use super::fields::qm31::SecureField; use super::pcs::TreeVec; use super::poly::circle::{CircleEvaluation, CirclePoly}; use super::poly::BitReversedOrder; -use super::{ColumnVec, InteractionElements, LookupValues}; +use super::ColumnVec; pub mod accumulation; mod components; @@ -52,8 +52,6 @@ pub trait Component { point: CirclePoint, mask: &TreeVec>>, evaluation_accumulator: &mut PointEvaluationAccumulator, - interaction_elements: &InteractionElements, - lookup_values: &LookupValues, ); } @@ -64,12 +62,7 @@ pub trait ComponentProver: Component { &self, trace: &Trace<'_, B>, evaluation_accumulator: &mut DomainEvaluationAccumulator, - interaction_elements: &InteractionElements, - lookup_values: &LookupValues, ); - - /// Returns the values needed to evaluate the components lookup boundary constraints. - fn lookup_values(&self, _trace: &Trace<'_, B>) -> LookupValues; } /// The set of polynomials that make up the trace. diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index 3cb664aeb..382d2a14c 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -282,13 +282,8 @@ mod tests { }]; let cpu_columns = columns .iter() - .map(|c| { - CircleEvaluation::::new( - c.domain, - c.values.to_cpu(), - ) - }) - .collect::>(); + .map(|c| CircleEvaluation::new(c.domain, c.values.to_cpu())) + .collect_vec(); let cpu_result = CpuBackend::accumulate_quotients( domain, &cpu_columns.iter().collect_vec(), diff --git a/crates/prover/src/core/mod.rs b/crates/prover/src/core/mod.rs index 1daa2b6e7..a00aad687 100644 --- a/crates/prover/src/core/mod.rs +++ b/crates/prover/src/core/mod.rs @@ -1,10 +1,4 @@ -use std::collections::BTreeMap; -use std::ops::{Deref, DerefMut, Index}; - -use fields::m31::BaseField; -use serde::{Deserialize, Serialize}; - -use self::fields::qm31::SecureField; +use std::ops::{Deref, DerefMut}; pub mod air; pub mod backend; @@ -63,47 +57,3 @@ impl DerefMut for ComponentVec { &mut self.0 } } - -#[derive(Default, Debug)] -pub struct InteractionElements(BTreeMap); - -impl InteractionElements { - pub fn new(elements: BTreeMap) -> Self { - Self(elements) - } - - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } -} - -impl Index<&str> for InteractionElements { - type Output = SecureField; - - fn index(&self, index: &str) -> &Self::Output { - // TODO(AlonH): Return an error if the key is not found. - &self.0[index] - } -} - -#[derive(Default, Debug, Serialize, Deserialize)] -pub struct LookupValues(pub BTreeMap); - -impl LookupValues { - pub fn new(values: BTreeMap) -> Self { - Self(values) - } - - pub fn extend(&mut self, other: Self) { - self.0.extend(other.0); - } -} - -impl Index<&str> for LookupValues { - type Output = BaseField; - - fn index(&self, index: &str) -> &Self::Output { - // TODO(AlonH): Return an error if the key is not found. - &self.0[index] - } -} diff --git a/crates/prover/src/core/pcs/prover.rs b/crates/prover/src/core/pcs/prover.rs index 7a761eaa9..6da991150 100644 --- a/crates/prover/src/core/pcs/prover.rs +++ b/crates/prover/src/core/pcs/prover.rs @@ -167,30 +167,28 @@ pub struct TreeBuilder<'a, 'b, B: BackendForChannel, MC: MerkleChannel> { impl<'a, 'b, B: BackendForChannel, MC: MerkleChannel> TreeBuilder<'a, 'b, B, MC> { pub fn extend_evals( &mut self, - columns: ColumnVec>, + columns: impl IntoIterator>, ) -> TreeSubspan { let span = span!(Level::INFO, "Interpolation for commitment").entered(); - let col_start = self.polys.len(); let polys = columns .into_iter() .map(|eval| eval.interpolate_with_twiddles(self.commitment_scheme.twiddles)) .collect_vec(); span.exit(); - self.polys.extend(polys); - TreeSubspan { - tree_index: self.tree_index, - col_start, - col_end: self.polys.len(), - } + self.extend_polys(polys) } - pub fn extend_polys(&mut self, polys: ColumnVec>) -> TreeSubspan { + pub fn extend_polys( + &mut self, + columns: impl IntoIterator>, + ) -> TreeSubspan { let col_start = self.polys.len(); - self.polys.extend(polys); + self.polys.extend(columns); + let col_end = self.polys.len(); TreeSubspan { tree_index: self.tree_index, col_start, - col_end: self.polys.len(), + col_end, } } diff --git a/crates/prover/src/core/poly/circle/secure_poly.rs b/crates/prover/src/core/poly/circle/secure_poly.rs index a503bd2c6..de6d75d40 100644 --- a/crates/prover/src/core/poly/circle/secure_poly.rs +++ b/crates/prover/src/core/poly/circle/secure_poly.rs @@ -43,6 +43,10 @@ impl SecureCirclePoly { let columns = polys.map(|poly| poly.evaluate_with_twiddles(domain, twiddles).values); SecureEvaluation::new(domain, SecureColumnByCoords { columns }) } + + pub fn into_coordinate_polys(self) -> [CirclePoly; SECURE_EXTENSION_DEGREE] { + self.0 + } } impl> Deref for SecureCirclePoly { diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index d5b9c4c4a..4ad9ac44a 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -11,56 +11,36 @@ use super::fields::secure_column::SECURE_EXTENSION_DEGREE; use super::fri::FriVerificationError; use super::pcs::{CommitmentSchemeProof, TreeVec}; use super::vcs::ops::MerkleHasher; -use super::{InteractionElements, LookupValues}; -use crate::core::backend::CpuBackend; use crate::core::channel::Channel; use crate::core::circle::CirclePoint; use crate::core::fields::qm31::SecureField; use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier}; -use crate::core::poly::circle::CircleEvaluation; -use crate::core::poly::BitReversedOrder; use crate::core::vcs::verifier::MerkleVerificationError; #[derive(Debug, Serialize, Deserialize)] pub struct StarkProof { pub commitments: TreeVec, - pub lookup_values: LookupValues, pub commitment_scheme_proof: CommitmentSchemeProof, } -#[derive(Debug)] -pub struct AdditionalProofData { - pub composition_polynomial_oods_value: SecureField, - pub composition_polynomial_random_coeff: SecureField, - pub oods_point: CirclePoint, - pub oods_quotients: Vec>, -} - pub fn prove, MC: MerkleChannel>( components: &[&dyn ComponentProver], channel: &mut MC::C, - interaction_elements: &InteractionElements, commitment_scheme: &mut CommitmentSchemeProver<'_, B, MC>, ) -> Result, ProvingError> { let component_provers = ComponentProvers(components.to_vec()); let trace = commitment_scheme.trace(); - let lookup_values = component_provers.lookup_values(&trace); // Evaluate and commit on composition polynomial. let random_coeff = channel.draw_felt(); let span = span!(Level::INFO, "Composition").entered(); let span1 = span!(Level::INFO, "Generation").entered(); - let composition_polynomial_poly = component_provers.compute_composition_polynomial( - random_coeff, - &trace, - interaction_elements, - &lookup_values, - ); + let composition_poly = component_provers.compute_composition_polynomial(random_coeff, &trace); span1.exit(); let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_polys(composition_polynomial_poly.to_vec()); + tree_builder.extend_polys(composition_poly.into_coordinate_polys()); tree_builder.commit(channel); span.exit(); @@ -83,20 +63,13 @@ pub fn prove, MC: MerkleChannel>( if composition_oods_eval != component_provers .components() - .eval_composition_polynomial_at_point( - oods_point, - sampled_oods_values, - random_coeff, - interaction_elements, - &lookup_values, - ) + .eval_composition_polynomial_at_point(oods_point, sampled_oods_values, random_coeff) { return Err(ProvingError::ConstraintsNotSatisfied); } Ok(StarkProof { commitments: commitment_scheme.roots(), - lookup_values, commitment_scheme_proof, }) } @@ -104,7 +77,6 @@ pub fn prove, MC: MerkleChannel>( pub fn verify( components: &[&dyn Component], channel: &mut MC::C, - interaction_elements: &InteractionElements, commitment_scheme: &mut CommitmentSchemeVerifier, proof: StarkProof, ) -> Result<(), VerificationError> { @@ -136,8 +108,6 @@ pub fn verify( oods_point, sampled_oods_values, random_coeff, - interaction_elements, - &proof.lookup_values, ) { return Err(VerificationError::OodsNotMatching); diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index 9975f0687..58bb8633b 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -19,7 +19,6 @@ use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConf use crate::core::poly::circle::{CanonicCoset, PolyOps}; use crate::core::prover::{prove, verify, StarkProof, VerificationError}; use crate::core::vcs::ops::MerkleHasher; -use crate::core::InteractionElements; use crate::examples::blake::round::RoundElements; use crate::examples::blake::scheduler::{self, blake_scheduler_info, BlakeElements, BlakeInput}; use crate::examples::blake::{ @@ -390,13 +389,7 @@ where // Prove constraints. let components = BlakeComponents::new(&stmt0, &all_elements, &stmt1); - let stark_proof = prove::( - &components.component_provers(), - channel, - &InteractionElements::default(), - commitment_scheme, - ) - .unwrap(); + let stark_proof = prove(&components.component_provers(), channel, commitment_scheme).unwrap(); BlakeProof { stmt0, @@ -450,7 +443,6 @@ pub fn verify_blake( verify( &components.components(), channel, - &InteractionElements::default(), // Not in use. commitment_scheme, stark_proof, ) diff --git a/crates/prover/src/examples/blake/round/gen.rs b/crates/prover/src/examples/blake/round/gen.rs index ce0966381..9bddcdd40 100644 --- a/crates/prover/src/examples/blake/round/gen.rs +++ b/crates/prover/src/examples/blake/round/gen.rs @@ -230,8 +230,8 @@ pub fn generate_trace( generator .trace .into_iter() - .map(|eval| CircleEvaluation::::new(domain, eval)) - .collect_vec(), + .map(|eval| CircleEvaluation::new(domain, eval)) + .collect(), BlakeRoundLookupData { xor_lookups: generator.xor_lookups, round_lookup: generator.round_lookup, diff --git a/crates/prover/src/examples/blake/scheduler/gen.rs b/crates/prover/src/examples/blake/scheduler/gen.rs index 0581b2fe1..cd6a99b2f 100644 --- a/crates/prover/src/examples/blake/scheduler/gen.rs +++ b/crates/prover/src/examples/blake/scheduler/gen.rs @@ -107,8 +107,8 @@ pub fn gen_trace( let domain = CanonicCoset::new(log_size).circle_domain(); let trace = trace .into_iter() - .map(|eval| CircleEvaluation::::new(domain, eval)) - .collect_vec(); + .map(|eval| CircleEvaluation::new(domain, eval)) + .collect(); (trace, lookup_data, round_inputs) } diff --git a/crates/prover/src/examples/mod.rs b/crates/prover/src/examples/mod.rs index 330662de9..c5e3a4eda 100644 --- a/crates/prover/src/examples/mod.rs +++ b/crates/prover/src/examples/mod.rs @@ -1,5 +1,4 @@ pub mod blake; pub mod plonk; pub mod poseidon; -pub mod wide_fibonacci; pub mod xor; diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index 14ccd1744..1aecd1066 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -19,7 +19,7 @@ use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; use crate::core::prover::{prove, StarkProof}; use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; -use crate::core::{ColumnVec, InteractionElements}; +use crate::core::ColumnVec; pub type PlonkComponent = FrameworkComponent; @@ -108,8 +108,8 @@ pub fn gen_trace( &circuit.c_val, ] .into_iter() - .map(|eval| CircleEvaluation::::new(domain, eval.clone())) - .collect_vec() + .map(|eval| CircleEvaluation::new(domain, eval.clone())) + .collect() } pub fn gen_interaction_trace( @@ -206,17 +206,14 @@ pub fn prove_fibonacci_plonk( // Constant trace. let span = span!(Level::INFO, "Constant").entered(); let mut tree_builder = commitment_scheme.tree_builder(); - let constants_trace_location = tree_builder.extend_evals( - chain!([circuit.a_wire, circuit.b_wire, circuit.c_wire, circuit.op] - .into_iter() - .map(|col| { - CircleEvaluation::::new( - CanonicCoset::new(log_n_rows).circle_domain(), - col, - ) - })) - .collect_vec(), - ); + let constants_trace_location = tree_builder.extend_evals(chain!([ + circuit.a_wire, + circuit.b_wire, + circuit.c_wire, + circuit.op + ] + .into_iter() + .map(|col| CircleEvaluation::new(CanonicCoset::new(log_n_rows).circle_domain(), col)))); tree_builder.commit(channel); span.exit(); @@ -242,13 +239,7 @@ pub fn prove_fibonacci_plonk( component.evaluate(eval); }); - let proof = prove::( - &[&component], - channel, - &InteractionElements::default(), - commitment_scheme, - ) - .unwrap(); + let proof = prove(&[&component], channel, commitment_scheme).unwrap(); (component, proof) } @@ -264,7 +255,6 @@ mod tests { use crate::core::pcs::{CommitmentSchemeVerifier, PcsConfig}; use crate::core::prover::verify; use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - use crate::core::InteractionElements; use crate::examples::plonk::prove_fibonacci_plonk; #[test_log::test] @@ -301,13 +291,6 @@ mod tests { // Constant columns. commitment_scheme.commit(proof.commitments[2], &sizes[2], channel); - verify( - &[&component], - channel, - &InteractionElements::default(), - commitment_scheme, - proof, - ) - .unwrap(); + verify(&[&component], channel, commitment_scheme, proof).unwrap(); } } diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 83e17a88c..c8538cad9 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -24,7 +24,7 @@ use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; use crate::core::prover::{prove, StarkProof}; use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; -use crate::core::{ColumnVec, InteractionElements}; +use crate::core::ColumnVec; const N_LOG_INSTANCES_PER_ROW: usize = 3; const N_INSTANCES_PER_ROW: usize = 1 << N_LOG_INSTANCES_PER_ROW; @@ -278,8 +278,8 @@ pub fn gen_trace( let domain = CanonicCoset::new(log_size).circle_domain(); let trace = trace .into_iter() - .map(|eval| CircleEvaluation::::new(domain, eval)) - .collect_vec(); + .map(|eval| CircleEvaluation::new(domain, eval)) + .collect(); (trace, lookup_data) } @@ -366,13 +366,7 @@ pub fn prove_poseidon( claimed_sum, }, ); - let proof = prove::( - &[&component], - channel, - &InteractionElements::default(), - commitment_scheme, - ) - .unwrap(); + let proof = prove(&[&component], channel, commitment_scheme).unwrap(); (component, proof) } @@ -394,7 +388,6 @@ mod tests { use crate::core::poly::circle::CanonicCoset; use crate::core::prover::verify; use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - use crate::core::InteractionElements; use crate::examples::poseidon::{ apply_internal_round_matrix, apply_m4, eval_poseidon_constraints, gen_interaction_trace, gen_trace, prove_poseidon, PoseidonElements, @@ -510,13 +503,6 @@ mod tests { // Interaction columns. commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); - verify( - &[&component], - channel, - &InteractionElements::default(), - commitment_scheme, - proof, - ) - .unwrap(); + verify(&[&component], channel, commitment_scheme, proof).unwrap(); } } diff --git a/crates/prover/src/examples/wide_fibonacci/component.rs b/crates/prover/src/examples/wide_fibonacci/component.rs deleted file mode 100644 index 4eb08740f..000000000 --- a/crates/prover/src/examples/wide_fibonacci/component.rs +++ /dev/null @@ -1,297 +0,0 @@ -use itertools::Itertools; - -use crate::core::air::accumulation::PointEvaluationAccumulator; -use crate::core::air::mask::fixed_mask_points; -use crate::core::air::{Air, Component}; -use crate::core::backend::cpu::CpuCircleEvaluation; -use crate::core::backend::CpuBackend; -use crate::core::circle::{CirclePoint, Coset}; -use crate::core::constraints::{coset_vanishing, point_excluder, point_vanishing}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; -use crate::core::fields::FieldExpOps; -use crate::core::pcs::TreeVec; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; -use crate::core::poly::BitReversedOrder; -use crate::core::utils::shifted_secure_combination; -use crate::core::{ColumnVec, InteractionElements, LookupValues}; -use crate::examples::wide_fibonacci::trace_gen::write_lookup_column; -use crate::trace_generation::registry::ComponentGenerationRegistry; -use crate::trace_generation::ComponentTraceGenerator; - -pub const LOG_N_COLUMNS: usize = 8; -pub const N_COLUMNS: usize = 1 << LOG_N_COLUMNS; - -pub const ALPHA_ID: &str = "wide_fibonacci_alpha"; -pub const Z_ID: &str = "wide_fibonacci_z"; -pub const LOOKUP_VALUE_0_ID: &str = "wide_fibonacci_0"; -pub const LOOKUP_VALUE_1_ID: &str = "wide_fibonacci_1"; -pub const LOOKUP_VALUE_N_MINUS_2_ID: &str = "wide_fibonacci_n-2"; -pub const LOOKUP_VALUE_N_MINUS_1_ID: &str = "wide_fibonacci_n-1"; - -/// Component that computes 2^`self.log_n_instances` instances of fibonacci sequences of size -/// 2^`self.log_fibonacci_size`. The numbers are computes over [N_COLUMNS] trace columns. The -/// number of rows (i.e the size of the columns) is determined by the parameters above (see -/// [WideFibComponent::log_column_size()]). -#[derive(Clone)] -pub struct WideFibComponent { - pub log_fibonacci_size: u32, - pub log_n_instances: u32, -} - -impl WideFibComponent { - /// Returns the log of the size of the columns in the trace (which could also be looked at as - /// the log number of rows). - pub fn log_column_size(&self) -> u32 { - self.log_n_instances + self.log_fibonacci_size - LOG_N_COLUMNS as u32 - } - - pub fn log_n_columns(&self) -> usize { - LOG_N_COLUMNS - } - - pub fn n_columns(&self) -> usize { - N_COLUMNS - } - - pub fn interaction_element_ids(&self) -> Vec { - vec![ALPHA_ID.to_string(), Z_ID.to_string()] - } - - fn evaluate_trace_boundary_constraints_at_point( - &self, - point: CirclePoint, - mask: &TreeVec>>, - evaluation_accumulator: &mut PointEvaluationAccumulator, - constraint_zero_domain: Coset, - lookup_values: &LookupValues, - ) { - let numerator = mask[0][0][0] - lookup_values[LOOKUP_VALUE_0_ID]; - let denom = point_vanishing(constraint_zero_domain.at(0), point); - evaluation_accumulator.accumulate(numerator / denom); - let numerator = mask[0][1][0] - lookup_values[LOOKUP_VALUE_1_ID]; - evaluation_accumulator.accumulate(numerator / denom); - - let numerator = mask[0][self.n_columns() - 2][0] - lookup_values[LOOKUP_VALUE_N_MINUS_2_ID]; - let denom = point_vanishing( - constraint_zero_domain.at(constraint_zero_domain.size() - 1), - point, - ); - evaluation_accumulator.accumulate(numerator / denom); - let numerator = mask[0][self.n_columns() - 1][0] - lookup_values[LOOKUP_VALUE_N_MINUS_1_ID]; - evaluation_accumulator.accumulate(numerator / denom); - } - - fn evaluate_trace_step_constraints_at_point( - &self, - point: CirclePoint, - mask: &ColumnVec>, - evaluation_accumulator: &mut PointEvaluationAccumulator, - constraint_zero_domain: Coset, - ) { - let denom = coset_vanishing(constraint_zero_domain, point); - let denom_inverse = denom.inverse(); - for i in 0..self.n_columns() - 2 { - let numerator = mask[i][0].square() + mask[i + 1][0].square() - mask[i + 2][0]; - evaluation_accumulator.accumulate(numerator * denom_inverse); - } - } - - fn evaluate_lookup_boundary_constraints_at_point( - &self, - point: CirclePoint, - mask: &TreeVec>>, - evaluation_accumulator: &mut PointEvaluationAccumulator, - constraint_zero_domain: Coset, - interaction_elements: &InteractionElements, - lookup_values: &LookupValues, - ) { - let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]); - let value = SecureField::from_partial_evals(std::array::from_fn(|i| mask[1][i][0])); - let numerator = (value - * shifted_secure_combination( - &[ - mask[0][self.n_columns() - 2][0], - mask[0][self.n_columns() - 1][0], - ], - alpha, - z, - )) - - shifted_secure_combination(&[mask[0][0][0], mask[0][1][0]], alpha, z); - let denom = point_vanishing(constraint_zero_domain.at(0), point); - evaluation_accumulator.accumulate(numerator / denom); - - let numerator = (value - * shifted_secure_combination( - &[ - lookup_values[LOOKUP_VALUE_N_MINUS_2_ID], - lookup_values[LOOKUP_VALUE_N_MINUS_1_ID], - ], - alpha, - z, - )) - - shifted_secure_combination( - &[ - lookup_values[LOOKUP_VALUE_0_ID], - lookup_values[LOOKUP_VALUE_1_ID], - ], - alpha, - z, - ); - let denom = point_vanishing( - constraint_zero_domain.at(constraint_zero_domain.size() - 1), - point, - ); - evaluation_accumulator.accumulate(numerator / denom); - } - - fn evaluate_lookup_step_constraints_at_point( - &self, - point: CirclePoint, - mask: &TreeVec>>, - evaluation_accumulator: &mut PointEvaluationAccumulator, - constraint_zero_domain: Coset, - interaction_elements: &InteractionElements, - ) { - let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]); - let value = SecureField::from_partial_evals(std::array::from_fn(|i| mask[1][i][0])); - let prev_value = SecureField::from_partial_evals(std::array::from_fn(|i| mask[1][i][1])); - let numerator = (value - * shifted_secure_combination( - &[ - mask[0][self.n_columns() - 2][0], - mask[0][self.n_columns() - 1][0], - ], - alpha, - z, - )) - - (prev_value * shifted_secure_combination(&[mask[0][0][0], mask[0][1][0]], alpha, z)); - let denom = coset_vanishing(constraint_zero_domain, point) - / point_excluder(constraint_zero_domain.at(0), point); - evaluation_accumulator.accumulate(numerator / denom); - } -} - -#[derive(Clone)] -pub struct WideFibAir { - pub component: WideFibComponent, -} - -impl Air for WideFibAir { - fn components(&self) -> Vec<&dyn Component> { - vec![&self.component] - } -} - -impl Component for WideFibComponent { - fn n_constraints(&self) -> usize { - self.n_columns() + 5 - } - - fn max_constraint_log_degree_bound(&self) -> u32 { - self.log_column_size() + 1 - } - - fn trace_log_degree_bounds(&self) -> TreeVec> { - TreeVec::new(vec![ - vec![self.log_column_size(); self.n_columns()], - vec![self.log_column_size(); SECURE_EXTENSION_DEGREE], - ]) - } - - fn mask_points( - &self, - point: CirclePoint, - ) -> TreeVec>>> { - let domain = CanonicCoset::new(self.log_column_size()); - TreeVec::new(vec![ - fixed_mask_points(&vec![vec![0_usize]; self.n_columns()], point), - vec![vec![point, point - domain.step().into_ef()]; SECURE_EXTENSION_DEGREE], - ]) - } - - fn evaluate_constraint_quotients_at_point( - &self, - point: CirclePoint, - mask: &TreeVec>>, - evaluation_accumulator: &mut PointEvaluationAccumulator, - interaction_elements: &InteractionElements, - lookup_values: &LookupValues, - ) { - let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset; - self.evaluate_trace_boundary_constraints_at_point( - point, - mask, - evaluation_accumulator, - constraint_zero_domain, - lookup_values, - ); - self.evaluate_lookup_step_constraints_at_point( - point, - mask, - evaluation_accumulator, - constraint_zero_domain, - interaction_elements, - ); - self.evaluate_lookup_boundary_constraints_at_point( - point, - mask, - evaluation_accumulator, - constraint_zero_domain, - interaction_elements, - lookup_values, - ); - self.evaluate_trace_step_constraints_at_point( - point, - &mask[0], - evaluation_accumulator, - constraint_zero_domain, - ); - } -} - -impl ComponentTraceGenerator for WideFibComponent { - type Component = Self; - type Inputs = (); - - fn add_inputs(&mut self, _inputs: &Self::Inputs) {} - - fn write_trace( - _component_id: &str, - _registry: &mut ComponentGenerationRegistry, - ) -> ColumnVec> { - vec![] - } - - fn write_interaction_trace( - &self, - trace: &ColumnVec<&CircleEvaluation>, - elements: &InteractionElements, - ) -> ColumnVec> { - let trace_values = trace.iter().map(|eval| &eval.values[..]).collect_vec(); - let (alpha, z) = (elements[ALPHA_ID], elements[Z_ID]); - // TODO(AlonH): Return a secure column directly. - let values = write_lookup_column(&trace_values, alpha, z); - let secure_column: SecureColumnByCoords = values.into_iter().collect(); - secure_column - .columns - .into_iter() - .map(|eval| { - let coset = CanonicCoset::new(trace[0].domain.log_size()); - CpuCircleEvaluation::new_canonical_ordered(coset, eval) - }) - .collect_vec() - } - - fn component(&self) -> Self::Component { - self.clone() - } -} - -// Input for the fibonacci claim. -#[derive(Debug, Clone, Copy)] -pub struct Input { - pub a: BaseField, - pub b: BaseField, -} diff --git a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs b/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs deleted file mode 100644 index bd3f9025f..000000000 --- a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs +++ /dev/null @@ -1,369 +0,0 @@ -use std::collections::BTreeMap; - -use itertools::{zip_eq, Itertools}; -use num_traits::Zero; - -use super::component::{ - Input, WideFibAir, WideFibComponent, ALPHA_ID, LOOKUP_VALUE_0_ID, LOOKUP_VALUE_1_ID, - LOOKUP_VALUE_N_MINUS_1_ID, LOOKUP_VALUE_N_MINUS_2_ID, Z_ID, -}; -use super::trace_gen::write_trace_row; -use crate::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator}; -use crate::core::air::{AirProver, Component, ComponentProver, Trace}; -use crate::core::backend::CpuBackend; -use crate::core::channel::Channel; -use crate::core::circle::Coset; -use crate::core::constraints::{coset_vanishing, point_excluder}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::FieldExpOps; -use crate::core::pcs::TreeVec; -use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation}; -use crate::core::poly::BitReversedOrder; -use crate::core::prover::VerificationError; -use crate::core::utils::{ - bit_reverse, point_vanish_denominator_inverses, previous_bit_reversed_circle_domain_index, - shifted_secure_combination, -}; -use crate::core::{ColumnVec, InteractionElements, LookupValues}; -use crate::examples::wide_fibonacci::component::LOG_N_COLUMNS; -use crate::trace_generation::{ - AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator, BASE_TRACE, INTERACTION_TRACE, -}; - -// TODO(AlonH): Rename file to `cpu.rs`. - -impl AirTraceVerifier for WideFibAir { - fn interaction_elements(&self, channel: &mut impl Channel) -> InteractionElements { - let ids = self.component.interaction_element_ids(); - let elements = channel.draw_felts(ids.len()); - InteractionElements::new(BTreeMap::from_iter(zip_eq(ids, elements))) - } - - fn verify_lookups(&self, _lookup_values: &LookupValues) -> Result<(), VerificationError> { - Ok(()) - } -} - -impl AirTraceGenerator for WideFibAir { - fn interact( - &self, - trace: &ColumnVec>, - elements: &InteractionElements, - ) -> Vec> { - self.component - .write_interaction_trace(&trace.iter().collect(), elements) - } - - fn to_air_prover(&self) -> impl AirProver { - self.clone() - } - - fn composition_log_degree_bound(&self) -> u32 { - self.component.max_constraint_log_degree_bound() - } -} - -impl AirProver for WideFibAir { - fn component_provers(&self) -> Vec<&dyn ComponentProver> { - vec![&self.component] - } -} - -impl WideFibComponent { - fn evaluate_trace_boundary_constraints( - &self, - trace_evals: &TreeVec>>, - trace_eval_domain: CircleDomain, - zero_domain: Coset, - accum: &mut ColumnAccumulator<'_, CpuBackend>, - lookup_values: &LookupValues, - ) { - let first_point_denom_inverses = - point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0)); - let last_point_denom_inverses = point_vanish_denominator_inverses( - trace_eval_domain, - zero_domain.at(zero_domain.size() - 1), - ); - let (lookup_value_0, lookup_value_1, lookup_value_n_minus_2, lookup_value_n_minus_1) = ( - lookup_values[LOOKUP_VALUE_0_ID], - lookup_values[LOOKUP_VALUE_1_ID], - lookup_values[LOOKUP_VALUE_N_MINUS_2_ID], - lookup_values[LOOKUP_VALUE_N_MINUS_1_ID], - ); - - for (i, (first_point_denom_inverse, last_point_denom_inverse)) in - zip_eq(first_point_denom_inverses, last_point_denom_inverses).enumerate() - { - let first_point_numerator = accum.random_coeff_powers[self.n_columns() + 4] - * (trace_evals[BASE_TRACE][0][i] - lookup_value_0) - + accum.random_coeff_powers[self.n_columns() + 3] - * (trace_evals[BASE_TRACE][1][i] - lookup_value_1); - let last_point_numerator = accum.random_coeff_powers[self.n_columns() + 2] - * (trace_evals[BASE_TRACE][self.n_columns() - 2][i] - lookup_value_n_minus_2) - + accum.random_coeff_powers[self.n_columns() + 1] - * (trace_evals[BASE_TRACE][self.n_columns() - 1][i] - lookup_value_n_minus_1); - accum.accumulate( - i, - first_point_numerator * first_point_denom_inverse - + last_point_numerator * last_point_denom_inverse, - ); - } - } - - fn evaluate_trace_step_constraints( - &self, - trace_evals: &TreeVec>>, - trace_eval_domain: CircleDomain, - zero_domain: Coset, - accum: &mut ColumnAccumulator<'_, CpuBackend>, - ) { - let max_constraint_degree = self.max_constraint_log_degree_bound(); - let mut denoms = vec![]; - for point in trace_eval_domain.iter() { - denoms.push(coset_vanishing(zero_domain, point)); - } - bit_reverse(&mut denoms); - let mut denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)]; - BaseField::batch_inverse(&denoms, &mut denom_inverses); - - for (i, denom_inverse) in denom_inverses.iter().enumerate() { - let mut numerator = SecureField::zero(); - for j in 0..self.n_columns() - 2 { - numerator += accum.random_coeff_powers[self.n_columns() - 3 - j] - * (trace_evals[BASE_TRACE][j][i].square() - + trace_evals[BASE_TRACE][j + 1][i].square() - - trace_evals[BASE_TRACE][j + 2][i]); - } - accum.accumulate(i, numerator * *denom_inverse) - } - } - - fn evaluate_lookup_boundary_constraints( - &self, - trace_evals: &TreeVec>>, - trace_eval_domain: CircleDomain, - zero_domain: Coset, - accum: &mut ColumnAccumulator<'_, CpuBackend>, - interaction_elements: &InteractionElements, - lookup_values: &LookupValues, - ) { - let first_point_denom_inverses = - point_vanish_denominator_inverses(trace_eval_domain, zero_domain.at(0)); - let last_point_denom_inverses = point_vanish_denominator_inverses( - trace_eval_domain, - zero_domain.at(zero_domain.size() - 1), - ); - let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]); - let (lookup_value_0, lookup_value_1, lookup_value_n_minus_2, lookup_value_n_minus_1) = ( - lookup_values[LOOKUP_VALUE_0_ID], - lookup_values[LOOKUP_VALUE_1_ID], - lookup_values[LOOKUP_VALUE_N_MINUS_2_ID], - lookup_values[LOOKUP_VALUE_N_MINUS_1_ID], - ); - - for (i, (first_point_denom_inverse, last_point_denom_inverse)) in - zip_eq(first_point_denom_inverses, last_point_denom_inverses).enumerate() - { - let value = SecureField::from_m31_array(std::array::from_fn(|j| { - trace_evals[INTERACTION_TRACE][j][i] - })); - let first_point_numerator = accum.random_coeff_powers[self.n_columns() - 1] - * ((value - * shifted_secure_combination( - &[ - trace_evals[BASE_TRACE][self.n_columns() - 2][i], - trace_evals[BASE_TRACE][self.n_columns() - 1][i], - ], - alpha, - z, - )) - - shifted_secure_combination( - &[trace_evals[BASE_TRACE][0][i], trace_evals[BASE_TRACE][1][i]], - alpha, - z, - )); - let last_point_numerator = accum.random_coeff_powers[self.n_columns() - 2] - * ((value - * shifted_secure_combination( - &[lookup_value_n_minus_2, lookup_value_n_minus_1], - alpha, - z, - )) - - shifted_secure_combination(&[lookup_value_0, lookup_value_1], alpha, z)); - accum.accumulate( - i, - first_point_numerator * first_point_denom_inverse - + last_point_numerator * last_point_denom_inverse, - ); - } - } - - // TODO(AlonH): Simplify this function by using utility functions. - fn evaluate_lookup_step_constraints( - &self, - trace_evals: &TreeVec>>, - trace_eval_domain: CircleDomain, - zero_domain: Coset, - accum: &mut ColumnAccumulator<'_, CpuBackend>, - interaction_elements: &InteractionElements, - ) { - let max_constraint_degree = self.max_constraint_log_degree_bound(); - let mut denoms = vec![]; - for point in trace_eval_domain.iter() { - denoms.push( - coset_vanishing(zero_domain, point) / point_excluder(zero_domain.at(0), point), - ); - } - bit_reverse(&mut denoms); - let mut denom_inverses = vec![BaseField::zero(); 1 << (max_constraint_degree)]; - BaseField::batch_inverse(&denoms, &mut denom_inverses); - let (alpha, z) = (interaction_elements[ALPHA_ID], interaction_elements[Z_ID]); - - for (i, denom_inverse) in denom_inverses.iter().enumerate() { - let value = SecureField::from_m31_array(std::array::from_fn(|j| { - trace_evals[INTERACTION_TRACE][j][i] - })); - let prev_index = previous_bit_reversed_circle_domain_index( - i, - zero_domain.log_size, - trace_eval_domain.log_size(), - ); - let prev_value = SecureField::from_m31_array(std::array::from_fn(|j| { - trace_evals[INTERACTION_TRACE][j][prev_index] - })); - let numerator = accum.random_coeff_powers[self.n_columns()] - * ((value - * shifted_secure_combination( - &[ - trace_evals[BASE_TRACE][self.n_columns() - 2][i], - trace_evals[BASE_TRACE][self.n_columns() - 1][i], - ], - alpha, - z, - )) - - (prev_value - * shifted_secure_combination( - &[trace_evals[BASE_TRACE][0][i], trace_evals[BASE_TRACE][1][i]], - alpha, - z, - ))); - accum.accumulate(i, numerator * *denom_inverse); - } - } -} - -impl ComponentProver for WideFibComponent { - fn evaluate_constraint_quotients_on_domain( - &self, - trace: &Trace<'_, CpuBackend>, - evaluation_accumulator: &mut DomainEvaluationAccumulator, - interaction_elements: &InteractionElements, - lookup_values: &LookupValues, - ) { - let max_constraint_degree = self.max_constraint_log_degree_bound(); - let trace_eval_domain = CanonicCoset::new(max_constraint_degree).circle_domain(); - let trace_evals = &trace.evals; - let zero_domain = CanonicCoset::new(self.log_column_size()).coset; - let [mut accum] = - evaluation_accumulator.columns([(max_constraint_degree, self.n_constraints())]); - - // TODO(AlonH): Evaluate the numerators together and the denominators together (i.e. in the - // same for loop) - self.evaluate_trace_boundary_constraints( - trace_evals, - trace_eval_domain, - zero_domain, - &mut accum, - lookup_values, - ); - self.evaluate_lookup_step_constraints( - trace_evals, - trace_eval_domain, - zero_domain, - &mut accum, - interaction_elements, - ); - self.evaluate_lookup_boundary_constraints( - trace_evals, - trace_eval_domain, - zero_domain, - &mut accum, - interaction_elements, - lookup_values, - ); - self.evaluate_trace_step_constraints( - trace_evals, - trace_eval_domain, - zero_domain, - &mut accum, - ); - } - - fn lookup_values(&self, trace: &Trace<'_, CpuBackend>) -> LookupValues { - let domain = CanonicCoset::new(self.log_column_size()); - let trace_poly = &trace.polys[BASE_TRACE]; - let values = BTreeMap::from_iter([ - ( - LOOKUP_VALUE_0_ID.to_string(), - trace_poly[0] - .eval_at_point(domain.at(0).into_ef()) - .try_into() - .unwrap(), - ), - ( - LOOKUP_VALUE_1_ID.to_string(), - trace_poly[1] - .eval_at_point(domain.at(0).into_ef()) - .try_into() - .unwrap(), - ), - ( - LOOKUP_VALUE_N_MINUS_2_ID.to_string(), - trace_poly[self.n_columns() - 2] - .eval_at_point(domain.at(domain.size() - 1).into_ef()) - .try_into() - .unwrap(), - ), - ( - LOOKUP_VALUE_N_MINUS_1_ID.to_string(), - trace_poly[self.n_columns() - 1] - .eval_at_point(domain.at(domain.size() - 1).into_ef()) - .try_into() - .unwrap(), - ), - ]); - LookupValues::new(values) - } -} - -/// Generates the trace for the wide Fibonacci example. -pub fn gen_trace( - wide_fib: &WideFibComponent, - private_input: Vec, -) -> ColumnVec> { - let n_instances = 1 << wide_fib.log_n_instances; - assert_eq!( - private_input.len(), - n_instances, - "The number of inputs must match the number of instances" - ); - assert!( - wide_fib.log_fibonacci_size >= LOG_N_COLUMNS as u32, - "The fibonacci size must be at least equal to the length of a row" - ); - let n_rows_per_instance = 1 << (wide_fib.log_fibonacci_size - wide_fib.log_n_columns() as u32); - let n_rows = n_instances * n_rows_per_instance; - let zero_vec = vec![BaseField::zero(); n_rows]; - let mut dst = vec![zero_vec; wide_fib.n_columns()]; - (0..n_rows_per_instance).fold(private_input, |input, row| { - (0..n_instances) - .map(|instance| { - let (a, b) = - write_trace_row(&mut dst, &input[instance], row * n_instances + instance); - Input { a, b } - }) - .collect_vec() - }); - dst -} diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs deleted file mode 100644 index 15afa3294..000000000 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ /dev/null @@ -1,289 +0,0 @@ -pub mod component; -pub mod constraint_eval; -pub mod simd; -pub mod trace_gen; - -#[cfg(test)] -mod tests { - use std::collections::BTreeMap; - - use itertools::Itertools; - use num_traits::{One, Zero}; - - use super::component::{Input, WideFibAir, WideFibComponent, LOG_N_COLUMNS}; - use super::constraint_eval::gen_trace; - use crate::core::air::accumulation::DomainEvaluationAccumulator; - use crate::core::air::{Component, ComponentProver, Trace}; - use crate::core::backend::cpu::CpuCircleEvaluation; - use crate::core::backend::CpuBackend; - use crate::core::channel::Blake2sChannel; - #[cfg(not(target_arch = "wasm32"))] - use crate::core::channel::Poseidon252Channel; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::pcs::{PcsConfig, TreeVec}; - use crate::core::poly::circle::CanonicCoset; - use crate::core::utils::{ - bit_reverse, circle_domain_order_to_coset_order, shifted_secure_combination, - }; - use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - #[cfg(not(target_arch = "wasm32"))] - use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel; - use crate::core::InteractionElements; - use crate::examples::wide_fibonacci::trace_gen::write_lookup_column; - use crate::trace_generation::{commit_and_prove, commit_and_verify, ComponentTraceGenerator}; - use crate::{m31, qm31}; - - pub fn assert_constraints_on_row(row: &[BaseField]) { - for i in 2..row.len() { - assert_eq!( - (row[i] - (row[i - 1] * row[i - 1] + row[i - 2] * row[i - 2])), - BaseField::zero() - ); - } - } - - pub fn assert_constraints_on_lookup_column( - column: &[SecureField], - input_trace: &[Vec], - alpha: SecureField, - z: SecureField, - ) { - let n_columns = input_trace.len(); - let column_length = column.len(); - assert_eq!(column_length, input_trace[0].len()); - let mut prev_value = SecureField::one(); - for (i, cell) in column.iter().enumerate() { - assert_eq!( - *cell - * shifted_secure_combination( - &[input_trace[n_columns - 2][i], input_trace[n_columns - 1][i]], - alpha, - z, - ), - shifted_secure_combination(&[input_trace[0][i], input_trace[1][i]], alpha, z) - * prev_value - ); - prev_value = *cell; - } - - // Assert the last cell in the column is equal to the combination of the first two values - // divided by the combination of the last two values in the sequence (all other values - // should cancel out). - assert_eq!( - column[column_length - 1] - * shifted_secure_combination( - &[input_trace[n_columns - 2][1], input_trace[n_columns - 1][1]], - alpha, - z, - ), - (shifted_secure_combination(&[input_trace[0][0], input_trace[1][0]], alpha, z)) - ); - } - - #[test] - fn test_trace_row_constraints() { - let wide_fib = WideFibComponent { - log_fibonacci_size: LOG_N_COLUMNS as u32, - log_n_instances: 1, - }; - let input = Input { - a: m31!(0x76), - b: m31!(0x483), - }; - - let trace = gen_trace(&wide_fib, vec![input, input]); - let row_0 = trace.iter().map(|col| col[0]).collect_vec(); - let row_1 = trace.iter().map(|col| col[1]).collect_vec(); - - assert_constraints_on_row(&row_0); - assert_constraints_on_row(&row_1); - } - - #[test] - fn test_lookup_column_constraints() { - let wide_fib = WideFibComponent { - log_fibonacci_size: 4 + LOG_N_COLUMNS as u32, - log_n_instances: 0, - }; - let input = Input { - a: m31!(1), - b: m31!(1), - }; - - let alpha = qm31!(7, 1, 3, 4); - let z = qm31!(11, 1, 2, 3); - let mut trace = gen_trace(&wide_fib, vec![input]); - let input_trace = trace.iter().map(|values| &values[..]).collect_vec(); - let lookup_column = write_lookup_column(&input_trace, alpha, z); - - trace = trace - .iter_mut() - .map(|column| { - bit_reverse(column); - circle_domain_order_to_coset_order(column) - }) - .collect_vec(); - assert_constraints_on_lookup_column(&lookup_column, &trace, alpha, z) - } - - #[test] - fn test_composition_is_low_degree() { - let wide_fib = WideFibComponent { - log_fibonacci_size: 3 + LOG_N_COLUMNS as u32, - log_n_instances: 0, - }; - let random_coeff = qm31!(1, 2, 3, 4); - let mut acc = DomainEvaluationAccumulator::new( - random_coeff, - wide_fib.max_constraint_log_degree_bound(), - wide_fib.n_constraints(), - ); - let inputs = (0..1 << wide_fib.log_n_instances) - .map(|i| Input { - a: m31!(1), - b: m31!(i + 1_u32), - }) - .collect_vec(); - - let trace_values = gen_trace(&wide_fib, inputs); - - let trace_domain = CanonicCoset::new(wide_fib.log_column_size()); - let trace = trace_values - .into_iter() - .map(|eval| CpuCircleEvaluation::new_canonical_ordered(trace_domain, eval)) - .collect_vec(); - let trace_polys = trace - .clone() - .into_iter() - .map(|eval| eval.interpolate()) - .collect_vec(); - let eval_domain = - CanonicCoset::new(wide_fib.max_constraint_log_degree_bound()).circle_domain(); - let trace_evals = trace_polys - .iter() - .map(|poly| poly.evaluate(eval_domain)) - .collect_vec(); - - let interaction_elements = InteractionElements::new(BTreeMap::from_iter( - wide_fib - .interaction_element_ids() - .iter() - .cloned() - .enumerate() - .map(|(i, id)| (id, qm31!(43 + i as u32, 1, 2, 3))), - )); - let interaction_poly = wide_fib - .write_interaction_trace(&trace.iter().collect(), &interaction_elements) - .into_iter() - .map(|eval| eval.interpolate()) - .collect_vec(); - - let interaction_trace = interaction_poly - .iter() - .map(|poly| poly.evaluate(eval_domain)) - .collect_vec(); - let trace = Trace { - polys: TreeVec::new(vec![ - trace_polys.iter().collect_vec(), - interaction_poly.iter().collect_vec(), - ]), - evals: TreeVec::new(vec![ - trace_evals.iter().collect_vec(), - interaction_trace.iter().collect_vec(), - ]), - }; - - let lookup_values = wide_fib.lookup_values(&trace); - wide_fib.evaluate_constraint_quotients_on_domain( - &trace, - &mut acc, - &interaction_elements, - &lookup_values, - ); - - let res = acc.finalize(); - let poly = res.0[0].clone(); - for coeff in poly.coeffs[(1 << wide_fib.max_constraint_log_degree_bound()) - 1..].iter() { - assert_eq!(*coeff, BaseField::zero()); - } - } - - #[test_log::test] - fn test_single_instance_wide_fib_prove() { - // Note: To see time measurement, run test with - // RUST_LOG_SPAN_EVENTS=enter,close RUST_LOG=info RUST_BACKTRACE=1 cargo test - // test_prove -- --nocapture - - const LOG_N_INSTANCES: u32 = 0; - let config = PcsConfig::default(); - let component = WideFibComponent { - log_fibonacci_size: 3 + LOG_N_COLUMNS as u32, - log_n_instances: LOG_N_INSTANCES, - }; - let private_input = (0..(1 << LOG_N_INSTANCES)) - .map(|i| Input { - a: m31!(1), - b: m31!(i), - }) - .collect(); - let trace = gen_trace(&component, private_input); - - let trace_domain = CanonicCoset::new(component.log_column_size()); - let trace = trace - .into_iter() - .map(|eval| CpuCircleEvaluation::new_canonical_ordered(trace_domain, eval)) - .collect_vec(); - let air = WideFibAir { component }; - let prover_channel = &mut Blake2sChannel::default(); - let proof = commit_and_prove::( - &air, - prover_channel, - trace, - config, - ) - .unwrap(); - - let verifier_channel = &mut Blake2sChannel::default(); - commit_and_verify::(proof, &air, verifier_channel, config).unwrap(); - } - - #[cfg(not(target_arch = "wasm32"))] - #[test_log::test] - fn test_single_instance_wide_fib_prove_with_poseidon() { - use crate::core::backend::CpuBackend; - - const LOG_N_INSTANCES: u32 = 0; - let config = PcsConfig::default(); - let component = WideFibComponent { - log_fibonacci_size: 3 + LOG_N_COLUMNS as u32, - log_n_instances: LOG_N_INSTANCES, - }; - let private_input = (0..(1 << LOG_N_INSTANCES)) - .map(|i| Input { - a: m31!(1), - b: m31!(i), - }) - .collect(); - let trace = gen_trace(&component, private_input); - - let trace_domain = CanonicCoset::new(component.log_column_size()); - let trace = trace - .into_iter() - .map(|eval| CpuCircleEvaluation::new_canonical_ordered(trace_domain, eval)) - .collect_vec(); - let air = WideFibAir { component }; - let prover_channel = &mut Poseidon252Channel::default(); - let proof = commit_and_prove::( - &air, - prover_channel, - trace, - config, - ) - .unwrap(); - - let verifier_channel = &mut Poseidon252Channel::default(); - commit_and_verify::(proof, &air, verifier_channel, config) - .unwrap(); - } -} diff --git a/crates/prover/src/examples/wide_fibonacci/simd.rs b/crates/prover/src/examples/wide_fibonacci/simd.rs deleted file mode 100644 index a7ed4dcec..000000000 --- a/crates/prover/src/examples/wide_fibonacci/simd.rs +++ /dev/null @@ -1,293 +0,0 @@ -use itertools::Itertools; -use num_traits::{One, Zero}; -use tracing::{span, Level}; - -use super::component::LOG_N_COLUMNS; -use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use crate::core::air::mask::fixed_mask_points; -use crate::core::air::{Air, AirProver, Component, ComponentProver, Trace}; -use crate::core::backend::simd::column::BaseColumn; -use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; -use crate::core::backend::simd::qm31::PackedSecureField; -use crate::core::backend::simd::SimdBackend; -use crate::core::backend::{Col, Column, ColumnOps}; -use crate::core::channel::Channel; -use crate::core::circle::CirclePoint; -use crate::core::constraints::coset_vanishing; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::{FieldExpOps, FieldOps}; -use crate::core::pcs::TreeVec; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; -use crate::core::poly::BitReversedOrder; -use crate::core::prover::VerificationError; -use crate::core::{ColumnVec, InteractionElements, LookupValues}; -use crate::examples::wide_fibonacci::component::N_COLUMNS; -use crate::trace_generation::registry::ComponentGenerationRegistry; -use crate::trace_generation::{ - AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator, BASE_TRACE, -}; - -// TODO(AlonH): Remove this once the Cpu and Simd implementations are aligned. -#[derive(Clone)] -pub struct SimdWideFibComponent { - pub log_fibonacci_size: u32, - pub log_n_instances: u32, -} - -impl SimdWideFibComponent { - /// Returns the log of the size of the columns in the trace (which could also be looked at as - /// the log number of rows). - pub fn log_column_size(&self) -> u32 { - self.log_n_instances + self.log_fibonacci_size - LOG_N_COLUMNS as u32 - } - - pub fn log_n_columns(&self) -> usize { - LOG_N_COLUMNS - } - - pub fn n_columns(&self) -> usize { - N_COLUMNS - } -} - -// TODO(AlonH): Remove this once the Cpu and Simd implementations are aligned. -#[derive(Clone)] -pub struct SimdWideFibAir { - pub component: SimdWideFibComponent, -} - -impl Air for SimdWideFibAir { - fn components(&self) -> Vec<&dyn Component> { - vec![&self.component] - } -} - -impl AirTraceVerifier for SimdWideFibAir { - fn interaction_elements(&self, _channel: &mut impl Channel) -> InteractionElements { - InteractionElements::default() - } - - fn verify_lookups(&self, _lookup_values: &LookupValues) -> Result<(), VerificationError> { - Ok(()) - } -} - -impl AirTraceGenerator for SimdWideFibAir { - fn interact( - &self, - _trace: &ColumnVec>, - _elements: &InteractionElements, - ) -> Vec> { - vec![] - } - - fn to_air_prover(&self) -> impl AirProver { - self.clone() - } - - fn composition_log_degree_bound(&self) -> u32 { - self.component.max_constraint_log_degree_bound() - } -} - -impl Component for SimdWideFibComponent { - fn n_constraints(&self) -> usize { - self.n_columns() - 2 - } - - fn max_constraint_log_degree_bound(&self) -> u32 { - self.log_column_size() + 1 - } - - fn trace_log_degree_bounds(&self) -> TreeVec> { - TreeVec::new(vec![vec![self.log_column_size(); self.n_columns()]]) - } - - fn mask_points( - &self, - point: CirclePoint, - ) -> TreeVec>>> { - TreeVec::new(vec![fixed_mask_points( - &vec![vec![0_usize]; self.n_columns()], - point, - )]) - } - - fn evaluate_constraint_quotients_at_point( - &self, - point: CirclePoint, - mask: &TreeVec>>, - evaluation_accumulator: &mut PointEvaluationAccumulator, - _interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, - ) { - let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset; - let denom = coset_vanishing(constraint_zero_domain, point); - let denom_inverse = denom.inverse(); - for i in 0..self.n_columns() - 2 { - let numerator = mask[0][i][0].square() + mask[0][i + 1][0].square() - mask[0][i + 2][0]; - evaluation_accumulator.accumulate(numerator * denom_inverse); - } - } -} - -impl AirProver for SimdWideFibAir { - fn component_provers(&self) -> Vec<&dyn ComponentProver> { - vec![&self.component] - } -} - -pub fn gen_trace( - log_size: u32, -) -> ColumnVec> { - assert!(log_size >= LOG_N_LANES); - let mut trace = (0..N_COLUMNS) - .map(|_| Col::::zeros(1 << log_size)) - .collect_vec(); - for vec_index in 0..(1 << (log_size - LOG_N_LANES)) { - let mut a = PackedBaseField::one(); - let mut b = PackedBaseField::from_array(std::array::from_fn(|i| { - BaseField::from_u32_unchecked((vec_index * 16 + i) as u32) - })); - trace[0].data[vec_index] = a; - trace[1].data[vec_index] = b; - trace.iter_mut().skip(2).for_each(|col| { - (a, b) = (b, a.square() + b.square()); - col.data[vec_index] = b; - }); - } - let domain = CanonicCoset::new(log_size).circle_domain(); - trace - .into_iter() - .map(|eval| CircleEvaluation::::new(domain, eval)) - .collect_vec() -} - -impl ComponentTraceGenerator for SimdWideFibComponent { - type Component = Self; - type Inputs = (); - - fn add_inputs(&mut self, _inputs: &Self::Inputs) {} - - fn write_trace( - _component_id: &str, - _registry: &mut ComponentGenerationRegistry, - ) -> ColumnVec> { - vec![] - } - - fn write_interaction_trace( - &self, - _trace: &ColumnVec<&CircleEvaluation>, - _elements: &InteractionElements, - ) -> ColumnVec> { - vec![] - } - - fn component(&self) -> Self::Component { - self.clone() - } -} - -impl ComponentProver for SimdWideFibComponent { - fn evaluate_constraint_quotients_on_domain( - &self, - trace: &Trace<'_, SimdBackend>, - evaluation_accumulator: &mut DomainEvaluationAccumulator, - _interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, - ) { - assert_eq!(trace.polys[BASE_TRACE].len(), self.n_columns()); - // TODO(spapini): Steal evaluation from commitment. - let eval_domain = CanonicCoset::new(self.log_column_size() + 1).circle_domain(); - let trace_eval = &trace.evals; - - // Denoms. - let span = span!(Level::INFO, "Constraint eval denominators").entered(); - // TODO(spapini): Make this prettier. - let zero_domain = CanonicCoset::new(self.log_column_size()).coset; - let mut denoms = - BaseColumn::from_iter(eval_domain.iter().map(|p| coset_vanishing(zero_domain, p))); - >::bit_reverse_column(&mut denoms); - let mut denom_inverses = BaseColumn::zeros(denoms.len()); - >::batch_inverse(&denoms, &mut denom_inverses); - span.exit(); - - let _span = span!(Level::INFO, "Constraint pointwise eval").entered(); - - let constraint_log_degree_bound = self.max_constraint_log_degree_bound(); - let n_constraints = self.n_constraints(); - let [accum] = - evaluation_accumulator.columns([(constraint_log_degree_bound, n_constraints)]); - - for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES)) { - // Numerator. - let a = trace_eval[BASE_TRACE][0].data[vec_row]; - let mut row_res = PackedSecureField::zero(); - let mut a_sq = a.square(); - let mut b_sq = trace_eval[BASE_TRACE][1].data[vec_row].square(); - #[allow(clippy::needless_range_loop)] - for i in 0..(self.n_columns() - 2) { - unsafe { - let c = *trace_eval[BASE_TRACE] - .get_unchecked(i + 2) - .data - .get_unchecked(vec_row); - row_res += PackedSecureField::broadcast( - accum.random_coeff_powers[self.n_columns() - 3 - i], - ) * (a_sq + b_sq - c); - (a_sq, b_sq) = (b_sq, c.square()); - } - } - - unsafe { - accum.col.set_packed( - vec_row, - accum.col.packed_at(vec_row) + row_res * denom_inverses.data[vec_row], - ) - } - } - } - - fn lookup_values(&self, _trace: &Trace<'_, SimdBackend>) -> LookupValues { - LookupValues::default() - } -} - -#[cfg(test)] -mod tests { - use tracing::{span, Level}; - - use crate::core::channel::Blake2sChannel; - use crate::core::pcs::PcsConfig; - use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - use crate::examples::wide_fibonacci::component::LOG_N_COLUMNS; - use crate::examples::wide_fibonacci::simd::{gen_trace, SimdWideFibAir, SimdWideFibComponent}; - use crate::trace_generation::{commit_and_prove, commit_and_verify}; - - #[test_log::test] - fn test_simd_wide_fib_prove() { - // Note: To see time measurement, run test with - // RUST_LOG_SPAN_EVENTS=enter,close RUST_LOG=info RUST_BACKTRACE=1 RUSTFLAGS=" - // -C target-cpu=native -C target-feature=+avx512f -C opt-level=3" cargo test - // test_simd_wide_fib_prove -- --nocapture - - // Note: 17 means 128MB of trace. - const LOG_N_ROWS: u32 = 12; - let config = PcsConfig::default(); - let component = SimdWideFibComponent { - log_fibonacci_size: LOG_N_COLUMNS as u32, - log_n_instances: LOG_N_ROWS, - }; - let span = span!(Level::INFO, "Trace generation").entered(); - let trace = gen_trace(component.log_column_size()); - span.exit(); - let channel = &mut Blake2sChannel::default(); - let air = SimdWideFibAir { component }; - let proof = commit_and_prove(&air, channel, trace, config).unwrap(); - - let channel = &mut Blake2sChannel::default(); - commit_and_verify::(proof, &air, channel, config).unwrap(); - } -} diff --git a/crates/prover/src/examples/wide_fibonacci/trace_gen.rs b/crates/prover/src/examples/wide_fibonacci/trace_gen.rs deleted file mode 100644 index 550aa0635..000000000 --- a/crates/prover/src/examples/wide_fibonacci/trace_gen.rs +++ /dev/null @@ -1,79 +0,0 @@ -use itertools::Itertools; -use num_traits::{One, Zero}; - -use super::component::Input; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::FieldExpOps; -use crate::core::utils::{ - bit_reverse, circle_domain_order_to_coset_order, shifted_secure_combination, -}; - -/// Writes the trace row for the wide Fibonacci example to dst, given a private input. Returns the -/// last two elements of the row in case the sequence is continued. -pub fn write_trace_row( - dst: &mut [Vec], - private_input: &Input, - row_index: usize, -) -> (BaseField, BaseField) { - let n_columns = dst.len(); - dst[0][row_index] = private_input.a; - dst[1][row_index] = private_input.b; - for i in 2..n_columns { - dst[i][row_index] = dst[i - 1][row_index].square() + dst[i - 2][row_index].square(); - } - - (dst[n_columns - 2][row_index], dst[n_columns - 1][row_index]) -} - -/// Writes and returns the lookup column for the wide Fibonacci example, which is the partial -/// product of the shifted secure combination of the first two elements in each row divided by the -/// the shifted secure combination of the last two elements in each row. -pub fn write_lookup_column( - input_trace: &[&[BaseField]], - alpha: SecureField, - z: SecureField, -) -> Vec { - let n_rows = input_trace[0].len(); - let n_columns = input_trace.len(); - let mut prev_value = SecureField::one(); - let mut input_trace = input_trace - .iter() - .map(|column| column.to_vec()) - .collect_vec(); - let natural_ordered_trace = input_trace - .iter_mut() - .map(|column| { - bit_reverse(column); - circle_domain_order_to_coset_order(column) - }) - .collect_vec(); - - let denominators = (0..n_rows) - .map(|i| { - shifted_secure_combination( - &[ - natural_ordered_trace[n_columns - 2][i], - natural_ordered_trace[n_columns - 1][i], - ], - alpha, - z, - ) - }) - .collect_vec(); - let mut denominator_inverses = vec![SecureField::zero(); denominators.len()]; - SecureField::batch_inverse(&denominators, &mut denominator_inverses); - - (0..n_rows) - .map(|i| { - let numerator = shifted_secure_combination( - &[natural_ordered_trace[0][i], natural_ordered_trace[1][i]], - alpha, - z, - ); - let cell = (numerator * denominator_inverses[i]) * prev_value; - prev_value = cell; - cell - }) - .collect_vec() -} diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index 9bb13a72b..1e9c3be74 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -21,4 +21,3 @@ pub mod constraint_framework; pub mod core; pub mod examples; pub mod math; -pub mod trace_generation; diff --git a/crates/prover/src/trace_generation/mod.rs b/crates/prover/src/trace_generation/mod.rs deleted file mode 100644 index a25f72e97..000000000 --- a/crates/prover/src/trace_generation/mod.rs +++ /dev/null @@ -1,75 +0,0 @@ -mod prove; -pub mod registry; - -use downcast_rs::{impl_downcast, Downcast}; -pub use prove::{commit_and_prove, commit_and_verify}; -use registry::ComponentGenerationRegistry; - -use crate::core::air::{AirProver, Component}; -use crate::core::backend::Backend; -use crate::core::channel::Channel; -use crate::core::fields::m31::BaseField; -use crate::core::poly::circle::CircleEvaluation; -use crate::core::poly::BitReversedOrder; -use crate::core::prover::VerificationError; -use crate::core::{ColumnVec, InteractionElements, LookupValues}; - -pub const BASE_TRACE: usize = 0; -pub const INTERACTION_TRACE: usize = 1; - -pub trait ComponentGen: Downcast {} -impl_downcast!(ComponentGen); - -// A trait to generate a a trace. -// Generates the trace given a list of inputs collects inputs for subcomponents. -pub trait ComponentTraceGenerator { - type Component: Component; - type Inputs; - - /// Add inputs for the trace generation of the component. - /// This function should be called from the caller components before calling `write_trace` of - /// this component. - fn add_inputs(&mut self, inputs: &Self::Inputs); - - /// Allocates and returns the trace of the component and updates the - /// subcomponents with the corresponding inputs. - /// Should be called only after all the inputs are available. - // TODO(ShaharS): change `component_id` to a struct that contains the id and the component name. - fn write_trace( - component_id: &str, - registry: &mut ComponentGenerationRegistry, - ) -> ColumnVec>; - - /// Allocates and returns the interaction trace of the component. - fn write_interaction_trace( - &self, - trace: &ColumnVec<&CircleEvaluation>, - elements: &InteractionElements, - ) -> ColumnVec>; - - fn component(&self) -> Self::Component; -} - -pub trait AirTraceVerifier { - fn interaction_elements(&self, channel: &mut impl Channel) -> InteractionElements; - - /// Verifies the lookups done in the Air. - fn verify_lookups(&self, lookup_values: &LookupValues) -> Result<(), VerificationError>; -} - -pub trait AirTraceGenerator: AirTraceVerifier { - fn composition_log_degree_bound(&self) -> u32; - - // TODO(AlonH): Remove default implementation once all the components are implemented. - fn write_trace(&mut self) -> Vec> { - vec![] - } - - fn interact( - &self, - trace: &ColumnVec>, - elements: &InteractionElements, - ) -> Vec>; - - fn to_air_prover(&self) -> impl AirProver; -} diff --git a/crates/prover/src/trace_generation/prove.rs b/crates/prover/src/trace_generation/prove.rs deleted file mode 100644 index 9328ae293..000000000 --- a/crates/prover/src/trace_generation/prove.rs +++ /dev/null @@ -1,417 +0,0 @@ -use itertools::Itertools; -use thiserror::Error; -use tracing::{span, Level}; - -use super::{AirTraceGenerator, AirTraceVerifier, BASE_TRACE, INTERACTION_TRACE}; -use crate::core::air::{Air, AirProver, ComponentProvers, Components}; -use crate::core::backend::BackendForChannel; -use crate::core::channel::{Channel, MerkleChannel}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; -use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, MAX_CIRCLE_DOMAIN_LOG_SIZE}; -use crate::core::poly::twiddles::TwiddleTree; -use crate::core::poly::BitReversedOrder; -use crate::core::prover::{prove, verify, ProvingError, StarkProof, VerificationError}; -use crate::core::{ColumnVec, InteractionElements}; - -pub fn commit_and_prove, MC: MerkleChannel>( - air: &impl AirTraceGenerator, - channel: &mut MC::C, - trace: ColumnVec>, - config: PcsConfig, -) -> Result, CommitAndProveError> { - // Check that traces are not too big. - for (i, trace) in trace.iter().enumerate() { - if trace.domain.log_size() + config.fri_config.log_blowup_factor - > MAX_CIRCLE_DOMAIN_LOG_SIZE - { - return Err(CommitAndProveError::MaxTraceDegreeExceeded { - trace_index: i, - degree: trace.domain.log_size(), - }); - } - } - - // Check that the composition polynomial is not too big. - // TODO(AlonH): Get traces log degree bounds from trace writer. - let composition_polynomial_log_degree_bound = air.composition_log_degree_bound(); - if composition_polynomial_log_degree_bound + config.fri_config.log_blowup_factor - > MAX_CIRCLE_DOMAIN_LOG_SIZE - { - return Err(CommitAndProveError::MaxCompositionDegreeExceeded { - degree: composition_polynomial_log_degree_bound, - }); - } - - let span = span!(Level::INFO, "Precompute twiddle").entered(); - let composition_polynomial_log_degree_bound = air.composition_log_degree_bound(); - let twiddles = B::precompute_twiddles( - CanonicCoset::new( - composition_polynomial_log_degree_bound + config.fri_config.log_blowup_factor, - ) - .circle_domain() - .half_coset, - ); - span.exit(); - - let (mut commitment_scheme, interaction_elements) = - evaluate_and_commit_on_trace(air, channel, &twiddles, trace, config)?; - - let air_prover = &air.to_air_prover(); - let components = ComponentProvers(air_prover.component_provers()); - channel.mix_felts( - &components - .lookup_values(&commitment_scheme.trace()) - .0 - .values() - .map(|v| SecureField::from(*v)) - .collect_vec(), - ); - - Ok(prove( - &components.0, - channel, - &interaction_elements, - &mut commitment_scheme, - )?) -} - -pub fn evaluate_and_commit_on_trace<'a, B: BackendForChannel, MC: MerkleChannel>( - air: &impl AirTraceGenerator, - channel: &mut MC::C, - twiddles: &'a TwiddleTree, - trace: ColumnVec>, - config: PcsConfig, -) -> Result<(CommitmentSchemeProver<'a, B, MC>, InteractionElements), ProvingError> { - let mut commitment_scheme = CommitmentSchemeProver::new(config, twiddles); - // TODO(spapini): Remove clone. - let span = span!(Level::INFO, "Trace").entered(); - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace.clone()); - tree_builder.commit(channel); - span.exit(); - - let interaction_elements = air.interaction_elements(channel); - let interaction_trace = air.interact(&trace, &interaction_elements); - // TODO(spapini): Make this symmetric with verify, once the TraceGenerator traits support - // retrieveing the column log sizes. - if !interaction_trace.is_empty() { - let _span = span!(Level::INFO, "Interaction").entered(); - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(interaction_trace); - tree_builder.commit(channel); - } - - Ok((commitment_scheme, interaction_elements)) -} - -pub fn commit_and_verify( - proof: StarkProof, - air: &(impl Air + AirTraceVerifier), - channel: &mut MC::C, - config: PcsConfig, -) -> Result<(), VerificationError> { - // Read trace commitment. - let mut commitment_scheme = CommitmentSchemeVerifier::::new(config); - - // TODO(spapini): Retrieve column_log_sizes from AirTraceVerifier, and remove the dependency on - // Air. - let components = Components(air.components()); - let column_log_sizes = components.column_log_sizes(); - commitment_scheme.commit( - proof.commitments[BASE_TRACE], - &column_log_sizes[BASE_TRACE], - channel, - ); - let interaction_elements = air.interaction_elements(channel); - - if components.column_log_sizes().len() == 2 { - commitment_scheme.commit( - proof.commitments[INTERACTION_TRACE], - &column_log_sizes[INTERACTION_TRACE], - channel, - ); - } - - channel.mix_felts( - &proof - .lookup_values - .0 - .values() - .map(|v| SecureField::from(*v)) - .collect_vec(), - ); - air.verify_lookups(&proof.lookup_values)?; - verify( - &components.0, - channel, - &interaction_elements, - &mut commitment_scheme, - proof, - ) -} - -#[derive(Clone, Copy, Debug, Error)] -pub enum CommitAndProveError { - #[error(transparent)] - ProvingError(#[from] ProvingError), - #[error( - "Expanded trace column {trace_index} log degree bound ({degree}) exceeded max log degree \ - ({MAX_CIRCLE_DOMAIN_LOG_SIZE})." - )] - MaxTraceDegreeExceeded { trace_index: usize, degree: u32 }, - #[error( - "Expanded composition polynomial log degree bound ({degree}) exceeded max log degree \ - ({MAX_CIRCLE_DOMAIN_LOG_SIZE})." - )] - MaxCompositionDegreeExceeded { degree: u32 }, -} - -#[cfg(test)] -mod tests { - use std::assert_matches::assert_matches; - - use num_traits::Zero; - - use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; - use crate::core::air::{Air, AirProver, Component, ComponentProver, Trace}; - use crate::core::backend::cpu::CpuCircleEvaluation; - use crate::core::backend::CpuBackend; - use crate::core::channel::Channel; - use crate::core::circle::{CirclePoint, CirclePointIndex, Coset}; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::pcs::{PcsConfig, TreeVec}; - use crate::core::poly::circle::{ - CanonicCoset, CircleDomain, CircleEvaluation, MAX_CIRCLE_DOMAIN_LOG_SIZE, - }; - use crate::core::poly::BitReversedOrder; - use crate::core::prover::{ProvingError, VerificationError}; - use crate::core::test_utils::test_channel; - use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; - use crate::core::{ColumnVec, InteractionElements, LookupValues}; - use crate::qm31; - use crate::trace_generation::prove::CommitAndProveError; - use crate::trace_generation::registry::ComponentGenerationRegistry; - use crate::trace_generation::{ - commit_and_prove, AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator, - }; - - #[derive(Clone)] - struct TestAir> { - component: C, - } - - impl Air for TestAir { - fn components(&self) -> Vec<&dyn Component> { - vec![&self.component] - } - } - - impl AirTraceVerifier for TestAir { - fn interaction_elements(&self, _channel: &mut impl Channel) -> InteractionElements { - InteractionElements::default() - } - - fn verify_lookups(&self, _lookup_values: &LookupValues) -> Result<(), VerificationError> { - Ok(()) - } - } - - impl AirTraceGenerator for TestAir { - fn interact( - &self, - _trace: &ColumnVec>, - _elements: &InteractionElements, - ) -> Vec> { - vec![] - } - - fn to_air_prover(&self) -> impl AirProver { - self.clone() - } - - fn composition_log_degree_bound(&self) -> u32 { - self.component.max_constraint_log_degree_bound() - } - } - - impl AirProver for TestAir { - fn component_provers(&self) -> Vec<&dyn ComponentProver> { - vec![&self.component] - } - } - - #[derive(Clone)] - struct TestComponent { - log_size: u32, - max_constraint_log_degree_bound: u32, - } - - impl Component for TestComponent { - fn n_constraints(&self) -> usize { - 0 - } - - fn max_constraint_log_degree_bound(&self) -> u32 { - self.max_constraint_log_degree_bound - } - - fn trace_log_degree_bounds(&self) -> TreeVec> { - TreeVec::new(vec![vec![self.log_size]]) - } - - fn mask_points( - &self, - point: CirclePoint, - ) -> TreeVec>>> { - TreeVec::new(vec![vec![vec![point]]]) - } - - fn evaluate_constraint_quotients_at_point( - &self, - _point: CirclePoint, - _mask: &TreeVec>>, - evaluation_accumulator: &mut PointEvaluationAccumulator, - _interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, - ) { - evaluation_accumulator.accumulate(qm31!(0, 0, 0, 1)) - } - } - - impl ComponentTraceGenerator for TestComponent { - type Component = Self; - type Inputs = (); - - fn add_inputs(&mut self, _inputs: &Self::Inputs) {} - - fn write_trace( - _component_id: &str, - _registry: &mut ComponentGenerationRegistry, - ) -> ColumnVec> { - vec![] - } - - fn write_interaction_trace( - &self, - _trace: &ColumnVec<&CircleEvaluation>, - _elements: &InteractionElements, - ) -> ColumnVec> { - vec![] - } - - fn component(&self) -> Self::Component { - self.clone() - } - } - - impl ComponentProver for TestComponent { - fn evaluate_constraint_quotients_on_domain( - &self, - _trace: &Trace<'_, CpuBackend>, - _evaluation_accumulator: &mut DomainEvaluationAccumulator, - _interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, - ) { - // Does nothing. - } - - fn lookup_values(&self, _trace: &Trace<'_, CpuBackend>) -> LookupValues { - LookupValues::default() - } - } - - // Ignored because it takes too long and too much memory (in the CI) to run. - #[test] - #[cfg_attr(not(feature = "slow-tests"), ignore)] - fn test_trace_too_big() { - const LOG_DOMAIN_SIZE: u32 = MAX_CIRCLE_DOMAIN_LOG_SIZE; - let air = TestAir { - component: TestComponent { - log_size: LOG_DOMAIN_SIZE, - max_constraint_log_degree_bound: LOG_DOMAIN_SIZE, - }, - }; - let domain = CircleDomain::new(Coset::new( - CirclePointIndex::generator(), - LOG_DOMAIN_SIZE - 1, - )); - let values = vec![BaseField::zero(); 1 << LOG_DOMAIN_SIZE]; - let trace = vec![CpuCircleEvaluation::new(domain, values)]; - - let proof_error = commit_and_prove::<_, Blake2sMerkleChannel>( - &air, - &mut test_channel(), - trace, - PcsConfig::default(), - ) - .unwrap_err(); - assert_matches!( - proof_error, - CommitAndProveError::MaxTraceDegreeExceeded { - trace_index: 0, - degree: LOG_DOMAIN_SIZE - } - ); - } - - #[test] - fn test_composition_polynomial_too_big() { - const COMPOSITION_POLYNOMIAL_DEGREE: u32 = MAX_CIRCLE_DOMAIN_LOG_SIZE; - const LOG_DOMAIN_SIZE: u32 = 5; - let air = TestAir { - component: TestComponent { - log_size: LOG_DOMAIN_SIZE, - max_constraint_log_degree_bound: COMPOSITION_POLYNOMIAL_DEGREE, - }, - }; - let domain = CircleDomain::new(Coset::new( - CirclePointIndex::generator(), - LOG_DOMAIN_SIZE - 1, - )); - let values = vec![BaseField::zero(); 1 << LOG_DOMAIN_SIZE]; - let trace = vec![CpuCircleEvaluation::new(domain, values)]; - - let proof_error = commit_and_prove::<_, Blake2sMerkleChannel>( - &air, - &mut test_channel(), - trace, - PcsConfig::default(), - ) - .unwrap_err(); - assert_matches!( - proof_error, - CommitAndProveError::MaxCompositionDegreeExceeded { - degree: COMPOSITION_POLYNOMIAL_DEGREE - } - ); - } - - #[test] - fn test_constraints_not_satisfied() { - const LOG_DOMAIN_SIZE: u32 = 5; - let air = TestAir { - component: TestComponent { - log_size: LOG_DOMAIN_SIZE, - max_constraint_log_degree_bound: LOG_DOMAIN_SIZE + 1, - }, - }; - let domain = CanonicCoset::new(LOG_DOMAIN_SIZE).circle_domain(); - let values = vec![BaseField::zero(); 1 << LOG_DOMAIN_SIZE]; - let trace = vec![CpuCircleEvaluation::new(domain, values)]; - - let proof = commit_and_prove::<_, Blake2sMerkleChannel>( - &air, - &mut test_channel(), - trace, - PcsConfig::default(), - ) - .unwrap_err(); - assert_matches!( - proof, - CommitAndProveError::ProvingError(ProvingError::ConstraintsNotSatisfied) - ); - } -} diff --git a/crates/prover/src/trace_generation/registry.rs b/crates/prover/src/trace_generation/registry.rs deleted file mode 100644 index 2dfdef45f..000000000 --- a/crates/prover/src/trace_generation/registry.rs +++ /dev/null @@ -1,178 +0,0 @@ -use std::collections::HashMap; - -use super::ComponentGen; - -#[derive(Default)] -pub struct ComponentGenerationRegistry { - components: HashMap>, -} - -impl ComponentGenerationRegistry { - pub fn register(&mut self, component_id: &str, component_generator: impl ComponentGen) { - self.components - .insert(component_id.to_string(), Box::new(component_generator)); - } - - pub fn get_generator(&self, component_id: &str) -> &T { - self.components - .get(component_id) - .unwrap_or_else(|| panic!("Component ID: {} not found.", component_id)) - .downcast_ref() - .unwrap() - } - - pub fn get_generator_mut(&mut self, component_id: &str) -> &mut T { - self.components - .get_mut(component_id) - .unwrap_or_else(|| panic!("Component ID: {} not found.", component_id)) - .downcast_mut() - .unwrap() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::core::air::accumulation::PointEvaluationAccumulator; - use crate::core::air::Component; - use crate::core::backend::simd::m31::{PackedM31, N_LANES}; - use crate::core::backend::simd::SimdBackend; - use crate::core::backend::CpuBackend; - use crate::core::circle::CirclePoint; - use crate::core::fields::m31::{BaseField, M31}; - use crate::core::fields::qm31::SecureField; - use crate::core::pcs::TreeVec; - use crate::core::poly::circle::CircleEvaluation; - use crate::core::poly::BitReversedOrder; - use crate::core::{ColumnVec, InteractionElements, LookupValues}; - use crate::m31; - use crate::trace_generation::ComponentTraceGenerator; - pub struct ComponentA { - pub n_instances: usize, - } - - impl Component for ComponentA { - fn n_constraints(&self) -> usize { - todo!() - } - - fn max_constraint_log_degree_bound(&self) -> u32 { - todo!() - } - - fn trace_log_degree_bounds(&self) -> TreeVec> { - todo!() - } - - fn mask_points( - &self, - _point: CirclePoint, - ) -> TreeVec>>> { - todo!() - } - - fn evaluate_constraint_quotients_at_point( - &self, - _point: CirclePoint, - _mask: &TreeVec>>, - _evaluation_accumulator: &mut PointEvaluationAccumulator, - _interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, - ) { - todo!() - } - } - - type ComponentACpuInputs = Vec<(M31, M31)>; - struct ComponentACpuTraceGenerator { - inputs: ComponentACpuInputs, - } - impl ComponentGen for ComponentACpuTraceGenerator {} - - impl ComponentTraceGenerator for ComponentACpuTraceGenerator { - type Component = ComponentA; - type Inputs = ComponentACpuInputs; - - fn write_trace( - _component_id: &str, - _registry: &mut ComponentGenerationRegistry, - ) -> ColumnVec> { - unimplemented!("TestTraceGenerator::write_trace") - } - - fn add_inputs(&mut self, inputs: &ComponentACpuInputs) { - self.inputs.extend(inputs) - } - - fn component(&self) -> ComponentA { - ComponentA { - n_instances: self.inputs.len(), - } - } - - fn write_interaction_trace( - &self, - _trace: &ColumnVec<&CircleEvaluation>, - _elements: &InteractionElements, - ) -> ColumnVec> { - unimplemented!("TestTraceGenerator::write_interaction_trace") - } - } - - type ComponentASimdInputs = Vec<(PackedM31, PackedM31)>; - struct ComponentASimdTraceGenerator { - inputs: ComponentASimdInputs, - } - impl ComponentGen for ComponentASimdTraceGenerator {} - - impl ComponentTraceGenerator for ComponentASimdTraceGenerator { - type Component = ComponentA; - type Inputs = ComponentASimdInputs; - - fn write_trace( - _component_id: &str, - _registry: &mut ComponentGenerationRegistry, - ) -> ColumnVec> { - unimplemented!("TestTraceGenerator::write_trace") - } - - fn add_inputs(&mut self, inputs: &ComponentASimdInputs) { - self.inputs.extend(inputs) - } - - fn component(&self) -> ComponentA { - ComponentA { - n_instances: self.inputs.len() * N_LANES, - } - } - - fn write_interaction_trace( - &self, - _trace: &ColumnVec<&CircleEvaluation>, - _elements: &InteractionElements, - ) -> ColumnVec> { - unimplemented!("TestTraceGenerator::write_interaction_trace") - } - } - - #[test] - fn test_component_registry() { - let mut registry = ComponentGenerationRegistry::default(); - let component_id = "componentA::0"; - - let component_a_cpu_trace_generator = ComponentACpuTraceGenerator { inputs: vec![] }; - registry.register(component_id, component_a_cpu_trace_generator); - let cpu_inputs = vec![(m31!(1), m31!(1)), (m31!(2), m31!(2))]; - - registry - .get_generator_mut::(component_id) - .add_inputs(&cpu_inputs); - - assert_eq!( - registry - .get_generator_mut::(component_id) - .inputs, - cpu_inputs - ); - } -}