From 2d40280e1b16b3ceb5da125358933e7807968bcb Mon Sep 17 00:00:00 2001 From: Alon Haramati Date: Mon, 24 Jun 2024 16:47:40 +0300 Subject: [PATCH] Implement trace generator for fibonacci. --- crates/prover/benches/poseidon.rs | 2 +- crates/prover/src/core/prover/mod.rs | 23 +++-- crates/prover/src/examples/fibonacci/air.rs | 86 ++++++++++++++++++- .../src/examples/fibonacci/component.rs | 64 ++++++++++++-- crates/prover/src/examples/fibonacci/mod.rs | 28 +++++- crates/prover/src/examples/poseidon/mod.rs | 10 ++- .../src/examples/wide_fibonacci/component.rs | 2 + .../wide_fibonacci/constraint_eval.rs | 6 +- .../prover/src/examples/wide_fibonacci/mod.rs | 2 +- .../src/examples/wide_fibonacci/simd.rs | 10 ++- crates/prover/src/trace_generation/mod.rs | 10 ++- .../prover/src/trace_generation/registry.rs | 12 +-- 12 files changed, 215 insertions(+), 40 deletions(-) diff --git a/crates/prover/benches/poseidon.rs b/crates/prover/benches/poseidon.rs index 7f219663b..0e7d6b2b3 100644 --- a/crates/prover/benches/poseidon.rs +++ b/crates/prover/benches/poseidon.rs @@ -20,7 +20,7 @@ pub fn simd_poseidon(c: &mut Criterion) { let trace = gen_trace(component.log_column_size()); let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); let air = PoseidonAir { component }; - prove::(&air, channel, trace).unwrap() + prove::(air, channel, trace).unwrap() }); }); } diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index 6dfd4e94b..86ed32e82 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -157,7 +157,7 @@ pub fn generate_proof>( } pub fn prove>( - air: &impl AirTraceGenerator, + air: impl AirTraceGenerator, channel: &mut Channel, trace: ColumnVec>, ) -> Result { @@ -173,8 +173,7 @@ pub fn prove>( // Check that the composition polynomial is not too big. // TODO(AlonH): Get traces log degree bounds from trace writer. - let composition_polynomial_log_degree_bound = - air.to_air_prover().composition_log_degree_bound(); + let composition_polynomial_log_degree_bound = air.composition_log_degree_bound(); if composition_polynomial_log_degree_bound + LOG_BLOWUP_FACTOR > MAX_CIRCLE_DOMAIN_LOG_SIZE { return Err(ProvingError::MaxCompositionDegreeExceeded { degree: composition_polynomial_log_degree_bound, @@ -190,10 +189,10 @@ pub fn prove>( span.exit(); let (mut commitment_scheme, interaction_elements) = - evaluate_and_commit_on_trace(air, channel, &twiddles, trace)?; + evaluate_and_commit_on_trace(&air, channel, &twiddles, trace)?; generate_proof( - air.to_air_prover(), + &air.to_air_prover(), channel, &interaction_elements, &twiddles, @@ -390,6 +389,7 @@ mod tests { use crate::trace_generation::registry::ComponentGenerationRegistry; use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; + #[derive(Clone)] struct TestAir> { component: C, } @@ -415,9 +415,13 @@ mod tests { vec![] } - fn to_air_prover(&self) -> &impl AirProver { + fn to_air_prover(self) -> impl AirProver { self } + + fn composition_log_degree_bound(&self) -> u32 { + self.component.max_constraint_log_degree_bound() + } } impl AirProver for TestAir { @@ -426,6 +430,7 @@ mod tests { } } + #[derive(Clone)] struct TestComponent { log_size: u32, max_constraint_log_degree_bound: u32, @@ -527,7 +532,7 @@ mod tests { let values = vec![BaseField::zero(); 1 << LOG_DOMAIN_SIZE]; let trace = vec![CpuCircleEvaluation::new(domain, values)]; - let proof_error = prove(&air, &mut test_channel(), trace).unwrap_err(); + let proof_error = prove(air, &mut test_channel(), trace).unwrap_err(); assert!(matches!( proof_error, ProvingError::MaxTraceDegreeExceeded { @@ -554,7 +559,7 @@ mod tests { let values = vec![BaseField::zero(); 1 << LOG_DOMAIN_SIZE]; let trace = vec![CpuCircleEvaluation::new(domain, values)]; - let proof_error = prove(&air, &mut test_channel(), trace).unwrap_err(); + let proof_error = prove(air, &mut test_channel(), trace).unwrap_err(); assert!(matches!( proof_error, ProvingError::MaxCompositionDegreeExceeded { @@ -576,7 +581,7 @@ mod tests { let values = vec![BaseField::zero(); 1 << LOG_DOMAIN_SIZE]; let trace = vec![CpuCircleEvaluation::new(domain, values)]; - let proof = prove(&air, &mut test_channel(), trace).unwrap_err(); + let proof = prove(air, &mut test_channel(), trace).unwrap_err(); assert!(matches!(proof, ProvingError::ConstraintsNotSatisfied)); } } diff --git a/crates/prover/src/examples/fibonacci/air.rs b/crates/prover/src/examples/fibonacci/air.rs index 40c8eea6f..ab5ffd55a 100644 --- a/crates/prover/src/examples/fibonacci/air.rs +++ b/crates/prover/src/examples/fibonacci/air.rs @@ -1,6 +1,6 @@ use itertools::{zip_eq, Itertools}; -use super::component::FibonacciComponent; +use super::component::{FibonacciComponent, FibonacciInput, FibonacciTraceGenerator}; use crate::core::air::{Air, AirProver, Component, ComponentProver}; use crate::core::backend::CpuBackend; use crate::core::channel::Blake2sChannel; @@ -8,8 +8,73 @@ use crate::core::fields::m31::BaseField; use crate::core::poly::circle::CircleEvaluation; use crate::core::poly::BitReversedOrder; use crate::core::{ColumnVec, InteractionElements}; -use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier}; +use crate::trace_generation::registry::ComponentGenerationRegistry; +use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; +pub struct FibonacciAirGenerator { + pub registry: ComponentGenerationRegistry, +} + +impl Clone for FibonacciAirGenerator { + fn clone(&self) -> Self { + Self { + registry: ComponentGenerationRegistry::default(), + } + } +} + +impl FibonacciAirGenerator { + pub fn new(inputs: &FibonacciInput) -> Self { + let mut component_generator = FibonacciTraceGenerator::new(); + component_generator.add_inputs(inputs); + let mut registry = ComponentGenerationRegistry::default(); + registry.register("fibonacci", component_generator); + Self { registry } + } +} + +impl AirTraceVerifier for FibonacciAirGenerator { + fn interaction_elements(&self, _channel: &mut Blake2sChannel) -> InteractionElements { + InteractionElements::default() + } +} + +impl AirTraceGenerator for FibonacciAirGenerator { + fn write_trace(&mut self) -> Vec> { + FibonacciTraceGenerator::write_trace("fibonacci", &mut self.registry) + } + + fn interact( + &self, + _trace: &ColumnVec>, + _elements: &InteractionElements, + ) -> Vec> { + vec![] + } + + fn to_air_prover(self) -> impl AirProver { + let component_generator = self + .registry + .get_generator::("fibonacci"); + // TODO(AlonH): Take instead of clone. + FibonacciAir { + component: component_generator.clone().component(), + } + } + + fn composition_log_degree_bound(&self) -> u32 { + let component_generator = self + .registry + .get_generator::("fibonacci"); + assert!(component_generator.inputs_set(), "Fibonacci input not set."); + component_generator + .clone() + .component() + .max_constraint_log_degree_bound() + } +} + +#[derive(Clone)] pub struct FibonacciAir { pub component: FibonacciComponent, } @@ -41,9 +106,13 @@ impl AirTraceGenerator for FibonacciAir { vec![] } - fn to_air_prover(&self) -> &impl AirProver { + fn to_air_prover(self) -> impl AirProver { self } + + fn composition_log_degree_bound(&self) -> u32 { + self.component.max_constraint_log_degree_bound() + } } impl AirProver for FibonacciAir { @@ -52,6 +121,7 @@ impl AirProver for FibonacciAir { } } +#[derive(Clone)] pub struct MultiFibonacciAir { pub components: Vec, } @@ -90,9 +160,17 @@ impl AirTraceGenerator for MultiFibonacciAir { vec![] } - fn to_air_prover(&self) -> &impl AirProver { + fn to_air_prover(self) -> impl AirProver { self } + + fn composition_log_degree_bound(&self) -> u32 { + self.components + .iter() + .map(|component| component.max_constraint_log_degree_bound()) + .max() + .unwrap() + } } impl AirProver for MultiFibonacciAir { diff --git a/crates/prover/src/examples/fibonacci/component.rs b/crates/prover/src/examples/fibonacci/component.rs index 73a8161b5..25d782d30 100644 --- a/crates/prover/src/examples/fibonacci/component.rs +++ b/crates/prover/src/examples/fibonacci/component.rs @@ -18,8 +18,9 @@ use crate::core::prover::BASE_TRACE; use crate::core::utils::bit_reverse_index; use crate::core::{ColumnVec, InteractionElements, LookupValues}; use crate::trace_generation::registry::ComponentGenerationRegistry; -use crate::trace_generation::ComponentTraceGenerator; +use crate::trace_generation::{ComponentGen, ComponentTraceGenerator}; +#[derive(Clone)] pub struct FibonacciComponent { pub log_size: u32, pub claim: BaseField, @@ -130,17 +131,61 @@ impl Component for FibonacciComponent { } } -impl ComponentTraceGenerator for FibonacciComponent { - type Component = Self; - type Inputs = (); +pub struct FibonacciInput(pub u32, pub BaseField); - fn add_inputs(&mut self, _inputs: &Self::Inputs) {} +#[derive(Clone)] +pub struct FibonacciTraceGenerator { + log_size: Option, + claim: Option, +} + +impl ComponentGen for FibonacciTraceGenerator {} + +impl FibonacciTraceGenerator { + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + Self { + log_size: None, + claim: None, + } + } + + pub fn inputs_set(&self) -> bool { + self.log_size.is_some() && self.claim.is_some() + } +} + +impl ComponentTraceGenerator for FibonacciTraceGenerator { + type Component = FibonacciComponent; + type Inputs = FibonacciInput; + + fn add_inputs(&mut self, inputs: &Self::Inputs) { + assert!(!self.inputs_set(), "Fibonacci input already set."); + self.log_size = Some(inputs.0); + self.claim = Some(inputs.1); + } fn write_trace( - _component_id: &str, - _registry: &mut ComponentGenerationRegistry, + component_id: &str, + registry: &mut ComponentGenerationRegistry, ) -> ColumnVec> { - vec![] + let component = registry.get_generator_mut::(component_id); + assert!(component.inputs_set(), "Fibonacci input not set."); + let trace_domain = CanonicCoset::new(component.log_size.unwrap()); + let mut trace = Vec::with_capacity(trace_domain.size()); + + // Fill trace with fibonacci squared. + let mut a = BaseField::one(); + let mut b = BaseField::one(); + for _ in 0..trace_domain.size() { + trace.push(a); + let tmp = a.square() + b.square(); + a = b; + b = tmp; + } + + // Returns as a CircleEvaluation. + vec![CircleEvaluation::new_canonical_ordered(trace_domain, trace)] } fn write_interaction_trace( @@ -152,7 +197,8 @@ impl ComponentTraceGenerator for FibonacciComponent { } fn component(self) -> Self::Component { - self + assert!(self.inputs_set(), "Fibonacci input not set."); + FibonacciComponent::new(self.log_size.unwrap(), self.claim.unwrap()) } } diff --git a/crates/prover/src/examples/fibonacci/mod.rs b/crates/prover/src/examples/fibonacci/mod.rs index cf2862162..910cf364a 100644 --- a/crates/prover/src/examples/fibonacci/mod.rs +++ b/crates/prover/src/examples/fibonacci/mod.rs @@ -17,6 +17,7 @@ use crate::core::vcs::hasher::Hasher; pub mod air; mod component; +#[derive(Clone)] pub struct Fibonacci { pub air: FibonacciAir, } @@ -55,7 +56,7 @@ impl Fibonacci { .air .component .claim]))); - prove(&self.air, channel, vec![trace]) + prove(self.air.clone(), channel, vec![trace]) } pub fn verify(&self, proof: StarkProof) -> Result<(), VerificationError> { @@ -97,7 +98,8 @@ impl MultiFibonacci { pub fn prove(&self) -> Result { let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&self.claims))); - prove(&self.air, channel, self.get_trace()) + let trace = self.get_trace(); + prove(self.air.clone(), channel, trace) } pub fn verify(&self, proof: StarkProof) -> Result<(), VerificationError> { @@ -120,15 +122,22 @@ mod tests { use super::{Fibonacci, MultiFibonacci}; use crate::core::air::accumulation::PointEvaluationAccumulator; use crate::core::air::{AirExt, AirProverExt, Component, ComponentTrace}; + use crate::core::channel::{Blake2sChannel, Channel}; use crate::core::circle::CirclePoint; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; + use crate::core::fields::IntoSlice; use crate::core::pcs::TreeVec; use crate::core::poly::circle::CanonicCoset; - use crate::core::prover::{VerificationError, BASE_TRACE}; + use crate::core::prover::{prove, VerificationError, BASE_TRACE}; use crate::core::queries::Queries; use crate::core::utils::bit_reverse; + use crate::core::vcs::blake2_hash::Blake2sHasher; + use crate::core::vcs::hasher::Hasher; use crate::core::{InteractionElements, LookupValues}; + use crate::examples::fibonacci::air::FibonacciAirGenerator; + use crate::examples::fibonacci::component::FibonacciInput; + use crate::trace_generation::AirTraceGenerator; use crate::{m31, qm31}; pub fn generate_test_queries(n_queries: usize, trace_length: usize) -> Vec { @@ -232,6 +241,19 @@ mod tests { fib.verify(proof).unwrap(); } + #[test] + fn test_fib_prove_2() { + const FIB_LOG_SIZE: u32 = 5; + const CLAIM: BaseField = m31!(443693538); + let mut fib_trace_generator = + FibonacciAirGenerator::new(&FibonacciInput(FIB_LOG_SIZE, CLAIM)); + + let trace = fib_trace_generator.write_trace(); + let channel = + &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[CLAIM]))); + prove(fib_trace_generator, channel, trace).unwrap(); + } + #[test] fn test_prove_invalid_trace_value() { const FIB_LOG_SIZE: u32 = 5; diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 6057f211a..0e529228e 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -41,6 +41,7 @@ const EXTERNAL_ROUND_CONSTS: [[BaseField; N_STATE]; 2 * N_HALF_FULL_ROUNDS] = const INTERNAL_ROUND_CONSTS: [BaseField; N_PARTIAL_ROUNDS] = [BaseField::from_u32_unchecked(1234); N_PARTIAL_ROUNDS]; +#[derive(Clone)] pub struct PoseidonComponent { pub log_n_instances: u32, } @@ -55,6 +56,7 @@ impl PoseidonComponent { } } +#[derive(Clone)] pub struct PoseidonAir { pub component: PoseidonComponent, } @@ -80,9 +82,13 @@ impl AirTraceGenerator for PoseidonAir { vec![] } - fn to_air_prover(&self) -> &impl AirProver { + fn to_air_prover(self) -> impl AirProver { self } + + fn composition_log_degree_bound(&self) -> u32 { + self.component.max_constraint_log_degree_bound() + } } impl Component for PoseidonComponent { @@ -561,7 +567,7 @@ mod tests { let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); let air = PoseidonAir { component }; - let proof = prove::(&air, channel, trace).unwrap(); + let proof = prove::(air.clone(), channel, trace).unwrap(); let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); verify(proof, &air, channel).unwrap(); diff --git a/crates/prover/src/examples/wide_fibonacci/component.rs b/crates/prover/src/examples/wide_fibonacci/component.rs index 714ada265..067d162d4 100644 --- a/crates/prover/src/examples/wide_fibonacci/component.rs +++ b/crates/prover/src/examples/wide_fibonacci/component.rs @@ -34,6 +34,7 @@ pub const LOOKUP_VALUE_N_MINUS_1_ID: &str = "wide_fibonacci_n-1"; /// 2^`self.log_fibonacci_size`. The numbers are computes over [N_COLUMNS] trace columns. The /// number of rows (i.e the size of the columns) is determined by the parameters above (see /// [WideFibComponent::log_column_size()]). +#[derive(Clone)] pub struct WideFibComponent { pub log_fibonacci_size: u32, pub log_n_instances: u32, @@ -169,6 +170,7 @@ impl WideFibComponent { } } +#[derive(Clone)] pub struct WideFibAir { pub component: WideFibComponent, } diff --git a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs b/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs index a7ff01af8..d4d193932 100644 --- a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs +++ b/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs @@ -48,9 +48,13 @@ impl AirTraceGenerator for WideFibAir { .write_interaction_trace(&trace.iter().collect(), elements) } - fn to_air_prover(&self) -> &impl AirProver { + fn to_air_prover(self) -> impl AirProver { self } + + fn composition_log_degree_bound(&self) -> u32 { + self.component.max_constraint_log_degree_bound() + } } impl AirProver for WideFibAir { diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index cc23d7c1e..bb6cd06b1 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -235,7 +235,7 @@ mod tests { let air = WideFibAir { component }; let prover_channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); - let proof = prove::(&air, prover_channel, trace).unwrap(); + let proof = prove::(air.clone(), prover_channel, trace).unwrap(); let verifier_channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); diff --git a/crates/prover/src/examples/wide_fibonacci/simd.rs b/crates/prover/src/examples/wide_fibonacci/simd.rs index d3ba4adbd..5962f2b64 100644 --- a/crates/prover/src/examples/wide_fibonacci/simd.rs +++ b/crates/prover/src/examples/wide_fibonacci/simd.rs @@ -27,6 +27,7 @@ use crate::trace_generation::registry::ComponentGenerationRegistry; use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; // TODO(AlonH): Remove this once the Cpu and Simd implementations are aligned. +#[derive(Clone)] pub struct SimdWideFibComponent { pub log_fibonacci_size: u32, pub log_n_instances: u32, @@ -49,6 +50,7 @@ impl SimdWideFibComponent { } // TODO(AlonH): Remove this once the Cpu and Simd implementations are aligned. +#[derive(Clone)] pub struct SimdWideFibAir { pub component: SimdWideFibComponent, } @@ -74,9 +76,13 @@ impl AirTraceGenerator for SimdWideFibAir { vec![] } - fn to_air_prover(&self) -> &impl AirProver { + fn to_air_prover(self) -> impl AirProver { self } + + fn composition_log_degree_bound(&self) -> u32 { + self.component.max_constraint_log_degree_bound() + } } impl Component for SimdWideFibComponent { @@ -280,7 +286,7 @@ mod tests { span.exit(); let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); let air = SimdWideFibAir { component }; - let proof = prove::(&air, channel, trace).unwrap(); + let proof = prove::(air.clone(), channel, trace).unwrap(); let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[]))); verify(proof, &air, channel).unwrap(); diff --git a/crates/prover/src/trace_generation/mod.rs b/crates/prover/src/trace_generation/mod.rs index fe594ed84..c4f2186ff 100644 --- a/crates/prover/src/trace_generation/mod.rs +++ b/crates/prover/src/trace_generation/mod.rs @@ -47,12 +47,18 @@ pub trait AirTraceVerifier { fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements; } -pub trait AirTraceGenerator: AirTraceVerifier { +pub trait AirTraceGenerator: AirTraceVerifier + Clone { + fn composition_log_degree_bound(&self) -> u32; + + fn write_trace(&mut self) -> Vec> { + vec![] + } + fn interact( &self, trace: &ColumnVec>, elements: &InteractionElements, ) -> Vec>; - fn to_air_prover(&self) -> &impl AirProver; + fn to_air_prover(self) -> impl AirProver; } diff --git a/crates/prover/src/trace_generation/registry.rs b/crates/prover/src/trace_generation/registry.rs index dcb2d7e66..b8533e929 100644 --- a/crates/prover/src/trace_generation/registry.rs +++ b/crates/prover/src/trace_generation/registry.rs @@ -112,12 +112,6 @@ mod tests { self.inputs.extend(inputs) } - fn component(self) -> ComponentA { - ComponentA { - n_instances: self.inputs.len(), - } - } - fn write_interaction_trace( &self, _trace: &ColumnVec<&CircleEvaluation>, @@ -125,6 +119,12 @@ mod tests { ) -> ColumnVec> { unimplemented!("TestTraceGenerator::write_interaction_trace") } + + fn component(self) -> ComponentA { + ComponentA { + n_instances: self.inputs.len(), + } + } } type ComponentASimdInputs = Vec<(PackedM31, PackedM31)>;