From 2ce2998b7136cf3eee457865e1425859b73504ca Mon Sep 17 00:00:00 2001 From: Alon Haramati Date: Sun, 23 Jun 2024 16:57:15 +0300 Subject: [PATCH] Merge trace writer and trace generator. --- crates/prover/src/core/air/mod.rs | 23 ------------ crates/prover/src/core/prover/mod.rs | 35 ++++++++++++++----- crates/prover/src/examples/fibonacci/air.rs | 9 +++-- .../src/examples/fibonacci/component.rs | 23 ++++++++++-- crates/prover/src/examples/poseidon/mod.rs | 28 +++++++++++---- .../src/examples/wide_fibonacci/component.rs | 23 ++++++++++-- .../wide_fibonacci/constraint_eval.rs | 8 ++--- .../prover/src/examples/wide_fibonacci/mod.rs | 3 +- .../src/examples/wide_fibonacci/simd.rs | 28 +++++++++++---- crates/prover/src/trace_generation/mod.rs | 28 +++++++++++++-- .../prover/src/trace_generation/registry.rs | 22 ++++++++++-- 11 files changed, 165 insertions(+), 65 deletions(-) diff --git a/crates/prover/src/core/air/mod.rs b/crates/prover/src/core/air/mod.rs index 91ddbc714..d6d980511 100644 --- a/crates/prover/src/core/air/mod.rs +++ b/crates/prover/src/core/air/mod.rs @@ -1,6 +1,5 @@ use self::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use super::backend::Backend; -use super::channel::Blake2sChannel; use super::circle::CirclePoint; use super::fields::m31::BaseField; use super::fields::qm31::SecureField; @@ -25,20 +24,6 @@ pub trait Air { fn components(&self) -> Vec<&dyn Component>; } -pub trait AirTraceVerifier { - fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements; -} - -pub trait AirTraceWriter: AirTraceVerifier { - fn interact( - &self, - trace: &ColumnVec>, - elements: &InteractionElements, - ) -> Vec>; - - fn to_air_prover(&self) -> &impl AirProver; -} - pub trait AirProver: Air { fn prover_components(&self) -> Vec<&dyn ComponentProver>; } @@ -78,14 +63,6 @@ pub trait Component { ); } -pub trait ComponentTraceWriter { - fn write_interaction_trace( - &self, - trace: &ColumnVec<&CircleEvaluation>, - elements: &InteractionElements, - ) -> ColumnVec>; -} - pub trait ComponentProver: Component { /// Evaluates the constraint quotients of the component on the evaluation domain. /// Accumulates quotients in `evaluation_accumulator`. diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index f8d19e7f5..fd68091de 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -2,7 +2,7 @@ use itertools::Itertools; use thiserror::Error; use tracing::{span, Level}; -use super::air::{AirProver, AirTraceVerifier, AirTraceWriter}; +use super::air::AirProver; use super::backend::Backend; use super::fields::secure_column::SECURE_EXTENSION_DEGREE; use super::fri::FriVerificationError; @@ -25,6 +25,7 @@ use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; use crate::core::vcs::hasher::Hasher; use crate::core::vcs::ops::MerkleOps; use crate::core::vcs::verifier::MerkleVerificationError; +use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier}; type Channel = Blake2sChannel; type ChannelHasher = Blake2sHasher; @@ -54,7 +55,7 @@ pub struct AdditionalProofData { } pub fn evaluate_and_commit_on_trace>( - air: &impl AirTraceWriter, + air: &impl AirTraceGenerator, channel: &mut Channel, twiddles: &TwiddleTree, trace: ColumnVec>, @@ -155,7 +156,7 @@ pub fn generate_proof>( } pub fn prove>( - air: &impl AirTraceWriter, + air: &impl AirTraceGenerator, channel: &mut Channel, trace: ColumnVec>, ) -> Result { @@ -348,10 +349,7 @@ mod tests { use num_traits::Zero; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; - use crate::core::air::{ - Air, AirProver, AirTraceVerifier, AirTraceWriter, Component, ComponentProver, - ComponentTrace, ComponentTraceWriter, - }; + use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace}; use crate::core::backend::cpu::CpuCircleEvaluation; use crate::core::backend::CpuBackend; use crate::core::channel::Blake2sChannel; @@ -367,6 +365,8 @@ mod tests { use crate::core::test_utils::test_channel; use crate::core::{ColumnVec, InteractionElements, LookupValues}; use crate::qm31; + use crate::trace_generation::registry::ComponentGenerationRegistry; + use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; struct TestAir> { component: C, @@ -384,7 +384,7 @@ mod tests { } } - impl AirTraceWriter for TestAir { + impl AirTraceGenerator for TestAir { fn interact( &self, _trace: &ColumnVec>, @@ -404,6 +404,7 @@ mod tests { } } + #[derive(Clone)] struct TestComponent { log_size: u32, max_constraint_log_degree_bound: u32, @@ -449,7 +450,19 @@ mod tests { } } - impl ComponentTraceWriter for TestComponent { + impl ComponentTraceGenerator for TestComponent { + type Component = Self; + type Inputs = (); + + fn add_inputs(&mut self, _inputs: &Self::Inputs) {} + + fn write_trace( + _component_id: &str, + _registry: &mut ComponentGenerationRegistry, + ) -> ColumnVec> { + vec![] + } + fn write_interaction_trace( &self, _trace: &ColumnVec<&CircleEvaluation>, @@ -457,6 +470,10 @@ mod tests { ) -> ColumnVec> { vec![] } + + fn component(&self) -> Self::Component { + self.clone() + } } impl ComponentProver for TestComponent { diff --git a/crates/prover/src/examples/fibonacci/air.rs b/crates/prover/src/examples/fibonacci/air.rs index c8c1b8c39..40c8eea6f 100644 --- a/crates/prover/src/examples/fibonacci/air.rs +++ b/crates/prover/src/examples/fibonacci/air.rs @@ -1,15 +1,14 @@ use itertools::{zip_eq, Itertools}; use super::component::FibonacciComponent; -use crate::core::air::{ - Air, AirProver, AirTraceVerifier, AirTraceWriter, Component, ComponentProver, -}; +use crate::core::air::{Air, AirProver, Component, ComponentProver}; use crate::core::backend::CpuBackend; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; use crate::core::poly::circle::CircleEvaluation; use crate::core::poly::BitReversedOrder; use crate::core::{ColumnVec, InteractionElements}; +use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier}; pub struct FibonacciAir { pub component: FibonacciComponent, @@ -33,7 +32,7 @@ impl AirTraceVerifier for FibonacciAir { } } -impl AirTraceWriter for FibonacciAir { +impl AirTraceGenerator for FibonacciAir { fn interact( &self, _trace: &ColumnVec>, @@ -82,7 +81,7 @@ impl AirTraceVerifier for MultiFibonacciAir { } } -impl AirTraceWriter for MultiFibonacciAir { +impl AirTraceGenerator for MultiFibonacciAir { fn interact( &self, _trace: &ColumnVec>, diff --git a/crates/prover/src/examples/fibonacci/component.rs b/crates/prover/src/examples/fibonacci/component.rs index cbe24012c..f683ae8b3 100644 --- a/crates/prover/src/examples/fibonacci/component.rs +++ b/crates/prover/src/examples/fibonacci/component.rs @@ -4,7 +4,7 @@ use num_traits::One; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::mask::shifted_mask_points; -use crate::core::air::{Component, ComponentProver, ComponentTrace, ComponentTraceWriter}; +use crate::core::air::{Component, ComponentProver, ComponentTrace}; use crate::core::backend::CpuBackend; use crate::core::circle::{CirclePoint, Coset}; use crate::core::constraints::{coset_vanishing, pair_vanishing}; @@ -17,7 +17,10 @@ use crate::core::poly::BitReversedOrder; use crate::core::prover::BASE_TRACE; use crate::core::utils::bit_reverse_index; use crate::core::{ColumnVec, InteractionElements, LookupValues}; +use crate::trace_generation::registry::ComponentGenerationRegistry; +use crate::trace_generation::ComponentTraceGenerator; +#[derive(Clone)] pub struct FibonacciComponent { pub log_size: u32, pub claim: BaseField, @@ -123,7 +126,19 @@ impl Component for FibonacciComponent { } } -impl ComponentTraceWriter for FibonacciComponent { +impl ComponentTraceGenerator for FibonacciComponent { + type Component = Self; + type Inputs = (); + + fn add_inputs(&mut self, _inputs: &Self::Inputs) {} + + fn write_trace( + _component_id: &str, + _registry: &mut ComponentGenerationRegistry, + ) -> ColumnVec> { + vec![] + } + fn write_interaction_trace( &self, _trace: &ColumnVec<&CircleEvaluation>, @@ -131,6 +146,10 @@ impl ComponentTraceWriter for FibonacciComponent { ) -> ColumnVec> { vec![] } + + fn component(&self) -> Self::Component { + self.clone() + } } impl ComponentProver for FibonacciComponent { diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 368fbbb50..5d5d65255 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -8,10 +8,7 @@ use tracing::{span, Level}; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::mask::fixed_mask_points; -use crate::core::air::{ - Air, AirProver, AirTraceVerifier, AirTraceWriter, Component, ComponentProver, ComponentTrace, - ComponentTraceWriter, -}; +use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace}; use crate::core::backend::simd::column::BaseFieldVec; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; @@ -27,6 +24,7 @@ use crate::core::pcs::TreeVec; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; use crate::core::{ColumnVec, InteractionElements, LookupValues}; +use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; const N_LOG_INSTANCES_PER_ROW: usize = 3; const N_INSTANCES_PER_ROW: usize = 1 << N_LOG_INSTANCES_PER_ROW; @@ -73,7 +71,7 @@ impl AirTraceVerifier for PoseidonAir { } } -impl AirTraceWriter for PoseidonAir { +impl AirTraceGenerator for PoseidonAir { fn interact( &self, _trace: &ColumnVec>, @@ -358,7 +356,21 @@ pub fn gen_trace( .collect_vec() } -impl ComponentTraceWriter for PoseidonComponent { +impl ComponentTraceGenerator for PoseidonComponent { + type Component = Self; + type Inputs = (); + + fn add_inputs(&mut self, _inputs: &Self::Inputs) { + todo!() + } + + fn write_trace( + _component_id: &str, + _registry: &mut crate::trace_generation::registry::ComponentGenerationRegistry, + ) -> ColumnVec> { + todo!() + } + fn write_interaction_trace( &self, _trace: &ColumnVec<&CircleEvaluation>, @@ -366,6 +378,10 @@ impl ComponentTraceWriter for PoseidonComponent { ) -> ColumnVec> { vec![] } + + fn component(&self) -> Self::Component { + todo!() + } } struct PoseidonEvalAtDomain<'a> { diff --git a/crates/prover/src/examples/wide_fibonacci/component.rs b/crates/prover/src/examples/wide_fibonacci/component.rs index d8e4dcada..4e0483826 100644 --- a/crates/prover/src/examples/wide_fibonacci/component.rs +++ b/crates/prover/src/examples/wide_fibonacci/component.rs @@ -2,7 +2,7 @@ use itertools::Itertools; use crate::core::air::accumulation::PointEvaluationAccumulator; use crate::core::air::mask::fixed_mask_points; -use crate::core::air::{Air, Component, ComponentTraceWriter}; +use crate::core::air::{Air, Component}; use crate::core::backend::cpu::CpuCircleEvaluation; use crate::core::backend::CpuBackend; use crate::core::circle::{CirclePoint, Coset}; @@ -17,6 +17,8 @@ use crate::core::poly::BitReversedOrder; use crate::core::utils::shifted_secure_combination; use crate::core::{ColumnVec, InteractionElements, LookupValues}; use crate::examples::wide_fibonacci::trace_gen::write_lookup_column; +use crate::trace_generation::registry::ComponentGenerationRegistry; +use crate::trace_generation::ComponentTraceGenerator; pub const LOG_N_COLUMNS: usize = 8; pub const N_COLUMNS: usize = 1 << LOG_N_COLUMNS; @@ -32,6 +34,7 @@ pub const LOOKUP_VALUE_N_MINUS_1_ID: &str = "wide_fibonacci_n-1"; /// 2^`self.log_fibonacci_size`. The numbers are computes over [N_COLUMNS] trace columns. The /// number of rows (i.e the size of the columns) is determined by the parameters above (see /// [WideFibComponent::log_column_size()]). +#[derive(Clone)] pub struct WideFibComponent { pub log_fibonacci_size: u32, pub log_n_instances: u32, @@ -251,7 +254,19 @@ impl Component for WideFibComponent { } } -impl ComponentTraceWriter for WideFibComponent { +impl ComponentTraceGenerator for WideFibComponent { + type Component = Self; + type Inputs = (); + + fn add_inputs(&mut self, _inputs: &Self::Inputs) {} + + fn write_trace( + _component_id: &str, + _registry: &mut ComponentGenerationRegistry, + ) -> ColumnVec> { + vec![] + } + fn write_interaction_trace( &self, trace: &ColumnVec<&CircleEvaluation>, @@ -271,6 +286,10 @@ impl ComponentTraceWriter for WideFibComponent { }) .collect_vec() } + + fn component(&self) -> Self::Component { + self.clone() + } } // Input for the fibonacci claim. diff --git a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs b/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs index 6e2d09e91..a75d28f67 100644 --- a/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs +++ b/crates/prover/src/examples/wide_fibonacci/constraint_eval.rs @@ -9,10 +9,7 @@ use super::component::{ }; use super::trace_gen::write_trace_row; use crate::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator}; -use crate::core::air::{ - AirProver, AirTraceVerifier, AirTraceWriter, Component, ComponentProver, ComponentTrace, - ComponentTraceWriter, -}; +use crate::core::air::{AirProver, Component, ComponentProver, ComponentTrace}; use crate::core::backend::CpuBackend; use crate::core::channel::{Blake2sChannel, Channel}; use crate::core::circle::Coset; @@ -29,6 +26,7 @@ use crate::core::utils::{ }; use crate::core::{ColumnVec, InteractionElements, LookupValues}; use crate::examples::wide_fibonacci::component::LOG_N_COLUMNS; +use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; // TODO(AlonH): Rename file to `cpu.rs`. @@ -40,7 +38,7 @@ impl AirTraceVerifier for WideFibAir { } } -impl AirTraceWriter for WideFibAir { +impl AirTraceGenerator for WideFibAir { fn interact( &self, trace: &ColumnVec>, diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index befc6ec0e..cc23d7c1e 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -13,7 +13,7 @@ mod tests { use super::component::{Input, WideFibAir, WideFibComponent, LOG_N_COLUMNS}; use super::constraint_eval::gen_trace; use crate::core::air::accumulation::DomainEvaluationAccumulator; - use crate::core::air::{Component, ComponentProver, ComponentTrace, ComponentTraceWriter}; + use crate::core::air::{Component, ComponentProver, ComponentTrace}; use crate::core::backend::cpu::CpuCircleEvaluation; use crate::core::backend::CpuBackend; use crate::core::channel::{Blake2sChannel, Channel}; @@ -30,6 +30,7 @@ mod tests { use crate::core::vcs::hasher::Hasher; use crate::core::InteractionElements; use crate::examples::wide_fibonacci::trace_gen::write_lookup_column; + use crate::trace_generation::ComponentTraceGenerator; use crate::{m31, qm31}; pub fn assert_constraints_on_row(row: &[BaseField]) { diff --git a/crates/prover/src/examples/wide_fibonacci/simd.rs b/crates/prover/src/examples/wide_fibonacci/simd.rs index d5a6ae55c..b9d325b61 100644 --- a/crates/prover/src/examples/wide_fibonacci/simd.rs +++ b/crates/prover/src/examples/wide_fibonacci/simd.rs @@ -5,10 +5,7 @@ use tracing::{span, Level}; use super::component::LOG_N_COLUMNS; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::mask::fixed_mask_points; -use crate::core::air::{ - Air, AirProver, AirTraceVerifier, AirTraceWriter, Component, ComponentProver, ComponentTrace, - ComponentTraceWriter, -}; +use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace}; use crate::core::backend::simd::column::BaseFieldVec; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; @@ -26,8 +23,11 @@ use crate::core::poly::BitReversedOrder; use crate::core::prover::BASE_TRACE; use crate::core::{ColumnVec, InteractionElements, LookupValues}; use crate::examples::wide_fibonacci::component::{ALPHA_ID, N_COLUMNS, Z_ID}; +use crate::trace_generation::registry::ComponentGenerationRegistry; +use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; // TODO(AlonH): Remove this once the Cpu and Simd implementations are aligned. +#[derive(Clone)] pub struct SimdWideFibComponent { pub log_fibonacci_size: u32, pub log_n_instances: u32, @@ -66,7 +66,7 @@ impl AirTraceVerifier for SimdWideFibAir { } } -impl AirTraceWriter for SimdWideFibAir { +impl AirTraceGenerator for SimdWideFibAir { fn interact( &self, _trace: &ColumnVec>, @@ -162,7 +162,19 @@ pub fn gen_trace( } // TODO(AlonH): Implement. -impl ComponentTraceWriter for SimdWideFibComponent { +impl ComponentTraceGenerator for SimdWideFibComponent { + type Component = Self; + type Inputs = (); + + fn add_inputs(&mut self, _inputs: &Self::Inputs) {} + + fn write_trace( + _component_id: &str, + _registry: &mut ComponentGenerationRegistry, + ) -> ColumnVec> { + vec![] + } + fn write_interaction_trace( &self, _trace: &ColumnVec<&CircleEvaluation>, @@ -170,6 +182,10 @@ impl ComponentTraceWriter for SimdWideFibComponent { ) -> ColumnVec> { vec![] } + + fn component(&self) -> Self::Component { + self.clone() + } } impl ComponentProver for SimdWideFibComponent { diff --git a/crates/prover/src/trace_generation/mod.rs b/crates/prover/src/trace_generation/mod.rs index ebcb1b1a9..a11b5da5b 100644 --- a/crates/prover/src/trace_generation/mod.rs +++ b/crates/prover/src/trace_generation/mod.rs @@ -3,19 +3,20 @@ pub mod registry; use downcast_rs::{impl_downcast, Downcast}; use registry::ComponentGenerationRegistry; -use crate::core::air::Component; +use crate::core::air::{AirProver, Component}; use crate::core::backend::Backend; +use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; use crate::core::poly::circle::CircleEvaluation; use crate::core::poly::BitReversedOrder; -use crate::core::ColumnVec; +use crate::core::{ColumnVec, InteractionElements}; pub trait ComponentGen: Downcast {} impl_downcast!(ComponentGen); // A trait to generate a a trace. // Generates the trace given a list of inputs collects inputs for subcomponents. -pub trait TraceGenerator { +pub trait ComponentTraceGenerator { type Component: Component; type Inputs; @@ -33,5 +34,26 @@ pub trait TraceGenerator { registry: &mut ComponentGenerationRegistry, ) -> ColumnVec>; + /// Allocates and returns the interaction trace of the component. + fn write_interaction_trace( + &self, + trace: &ColumnVec<&CircleEvaluation>, + elements: &InteractionElements, + ) -> ColumnVec>; + fn component(&self) -> Self::Component; } + +pub trait AirTraceVerifier { + fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements; +} + +pub trait AirTraceGenerator: AirTraceVerifier { + fn interact( + &self, + trace: &ColumnVec>, + elements: &InteractionElements, + ) -> Vec>; + + fn to_air_prover(&self) -> &impl AirProver; +} diff --git a/crates/prover/src/trace_generation/registry.rs b/crates/prover/src/trace_generation/registry.rs index e81577655..f5ab16654 100644 --- a/crates/prover/src/trace_generation/registry.rs +++ b/crates/prover/src/trace_generation/registry.rs @@ -46,7 +46,7 @@ mod tests { use crate::core::poly::BitReversedOrder; use crate::core::{ColumnVec, InteractionElements, LookupValues}; use crate::m31; - use crate::trace_generation::TraceGenerator; + use crate::trace_generation::ComponentTraceGenerator; pub struct ComponentA { pub n_instances: usize, } @@ -97,7 +97,7 @@ mod tests { } impl ComponentGen for ComponentACpuTraceGenerator {} - impl TraceGenerator for ComponentACpuTraceGenerator { + impl ComponentTraceGenerator for ComponentACpuTraceGenerator { type Component = ComponentA; type Inputs = ComponentACpuInputs; @@ -117,6 +117,14 @@ mod tests { n_instances: self.inputs.len(), } } + + fn write_interaction_trace( + &self, + _trace: &ColumnVec<&CircleEvaluation>, + _elements: &InteractionElements, + ) -> ColumnVec> { + unimplemented!("TestTraceGenerator::write_interaction_trace") + } } type ComponentASimdInputs = Vec<(PackedM31, PackedM31)>; @@ -125,7 +133,7 @@ mod tests { } impl ComponentGen for ComponentASimdTraceGenerator {} - impl TraceGenerator for ComponentASimdTraceGenerator { + impl ComponentTraceGenerator for ComponentASimdTraceGenerator { type Component = ComponentA; type Inputs = ComponentASimdInputs; @@ -145,6 +153,14 @@ mod tests { n_instances: self.inputs.len() * N_LANES, } } + + fn write_interaction_trace( + &self, + _trace: &ColumnVec<&CircleEvaluation>, + _elements: &InteractionElements, + ) -> ColumnVec> { + unimplemented!("TestTraceGenerator::write_interaction_trace") + } } #[test]