diff --git a/crates/prover/src/trace_generation/mod.rs b/crates/prover/src/trace_generation/mod.rs index b05f4d201..45cd35cee 100644 --- a/crates/prover/src/trace_generation/mod.rs +++ b/crates/prover/src/trace_generation/mod.rs @@ -1,8 +1,9 @@ pub mod registry; use downcast_rs::{impl_downcast, Downcast}; -use registry::ComponentRegistry; +use registry::ComponentGenerationRegistry; +use crate::core::air::Component; use crate::core::backend::Backend; use crate::core::fields::m31::BaseField; use crate::core::poly::circle::CircleEvaluation; @@ -15,19 +16,22 @@ impl_downcast!(ComponentGen); // A trait to generate a a trace. // Generates the trace given a list of inputs collects inputs for subcomponents. pub trait TraceGenerator { - type ComponentInputs; + type Component: Component; + type Inputs; /// Add inputs for the trace generation of the component. /// This function should be called from the caller components before calling `write_trace` of /// this component. - fn add_inputs(&mut self, inputs: &Self::ComponentInputs); + fn add_inputs(&mut self, inputs: &Self::Inputs); /// Allocates and returns the trace of the component and updates the /// subcomponents with the corresponding inputs. /// Should be called only after all the inputs are available. // TODO(ShaharS): change `component_id` to a struct that contains the id and the component name. fn write_trace( - component_id: &str, - registry: &mut ComponentRegistry, + &self, + registry: &mut ComponentGenerationRegistry, ) -> ColumnVec>; + + fn component(&self) -> Self::Component; } diff --git a/crates/prover/src/trace_generation/registry.rs b/crates/prover/src/trace_generation/registry.rs index 7da8a79a8..f3f05e39e 100644 --- a/crates/prover/src/trace_generation/registry.rs +++ b/crates/prover/src/trace_generation/registry.rs @@ -3,28 +3,28 @@ use std::collections::HashMap; use super::ComponentGen; #[derive(Default)] -pub struct ComponentRegistry { +pub struct ComponentGenerationRegistry { components: HashMap>, } -impl ComponentRegistry { - pub fn register_component(&mut self, component_id: &str, component: impl ComponentGen) { +impl ComponentGenerationRegistry { + pub fn register(&mut self, component_id: &str, component: impl ComponentGen) { self.components .insert(component_id.to_string(), Box::new(component)); } - pub fn get_component(&self, component_id: &str) -> &T { + pub fn get_generator(&self, component_id: &str) -> &T { self.components .get(component_id) - .unwrap_or_else(|| panic!("Component name {} not found.", component_id)) + .unwrap_or_else(|| panic!("Component ID: {} not found.", component_id)) .downcast_ref() .unwrap() } - pub fn get_component_mut(&mut self, component_id: &str) -> &mut T { + pub fn get_generator_mut(&mut self, component_id: &str) -> &mut T { self.components .get_mut(component_id) - .unwrap_or_else(|| panic!("Component name {} not found.", component_id)) + .unwrap_or_else(|| panic!("Component ID: {} not found.", component_id)) .downcast_mut() .unwrap() } @@ -33,42 +33,137 @@ impl ComponentRegistry { #[cfg(test)] mod tests { use super::*; + use crate::core::air::accumulation::PointEvaluationAccumulator; + use crate::core::air::Component; + use crate::core::backend::simd::m31::{PackedM31, N_LANES}; + use crate::core::backend::simd::SimdBackend; use crate::core::backend::CpuBackend; - use crate::core::fields::m31::BaseField; + use crate::core::circle::CirclePoint; + use crate::core::fields::m31::{BaseField, M31}; + use crate::core::fields::qm31::SecureField; + use crate::core::pcs::TreeVec; use crate::core::poly::circle::CircleEvaluation; use crate::core::poly::BitReversedOrder; - use crate::core::ColumnVec; + use crate::core::{ColumnVec, InteractionElements}; + use crate::m31; use crate::trace_generation::TraceGenerator; - - #[derive(Default)] - struct ComponentA { - inputs: Vec, + pub struct ComponentA { + pub n_instances: usize, } - impl ComponentGen for ComponentA {} + impl Component for ComponentA { + fn n_constraints(&self) -> usize { + todo!() + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + todo!() + } + + fn n_interaction_phases(&self) -> u32 { + todo!() + } + + fn trace_log_degree_bounds(&self) -> TreeVec> { + todo!() + } + + fn mask_points( + &self, + _point: CirclePoint, + ) -> TreeVec>>> { + todo!() + } - impl TraceGenerator for ComponentA { - type ComponentInputs = u32; + fn interaction_element_ids(&self) -> Vec { + todo!() + } - fn add_inputs(&mut self, _inputs: &Self::ComponentInputs) { - unimplemented!("TestTraceGenerator::add_inputs") + fn evaluate_constraint_quotients_at_point( + &self, + _point: CirclePoint, + _mask: &ColumnVec>, + _evaluation_accumulator: &mut PointEvaluationAccumulator, + _interaction_elements: &InteractionElements, + ) { + todo!() } + } + + type ComponentACpuInputs = Vec<(M31, M31)>; + struct ComponentACpuTraceGenerator { + inputs: ComponentACpuInputs, + } + impl ComponentGen for ComponentACpuTraceGenerator {} + + impl TraceGenerator for ComponentACpuTraceGenerator { + type Component = ComponentA; + type Inputs = ComponentACpuInputs; fn write_trace( - _component_id: &str, - _registry: &mut ComponentRegistry, + &self, + _registry: &mut ComponentGenerationRegistry, ) -> ColumnVec> { unimplemented!("TestTraceGenerator::write_trace") } + + fn add_inputs(&mut self, inputs: &ComponentACpuInputs) { + self.inputs.extend(inputs) + } + + fn component(&self) -> ComponentA { + ComponentA { + n_instances: self.inputs.len(), + } + } + } + + type ComponentASimdInputs = Vec<(PackedM31, PackedM31)>; + struct ComponentASimdTraceGenerator { + inputs: ComponentASimdInputs, + } + impl ComponentGen for ComponentASimdTraceGenerator {} + + impl TraceGenerator for ComponentASimdTraceGenerator { + type Component = ComponentA; + type Inputs = ComponentASimdInputs; + + fn write_trace( + &self, + _registry: &mut ComponentGenerationRegistry, + ) -> ColumnVec> { + unimplemented!("TestTraceGenerator::write_trace") + } + + fn add_inputs(&mut self, inputs: &ComponentASimdInputs) { + self.inputs.extend(inputs) + } + + fn component(&self) -> ComponentA { + ComponentA { + n_instances: self.inputs.len() * N_LANES, + } + } } #[test] fn test_component_registry() { - let mut registry = ComponentRegistry::default(); - let component = ComponentA { inputs: vec![1] }; + let mut registry = ComponentGenerationRegistry::default(); + let component_id = "componentA::0"; + + let component_a_cpu_trace_generator = ComponentACpuTraceGenerator { inputs: vec![] }; + registry.register(component_id, component_a_cpu_trace_generator); + let cpu_inputs = vec![(m31!(1), m31!(1)), (m31!(2), m31!(2))]; - registry.register_component("test", component); + registry + .get_generator_mut::(component_id) + .add_inputs(&cpu_inputs); - assert_eq!(registry.get_component::("test").inputs, vec![1]); + assert_eq!( + registry + .get_generator_mut::(component_id) + .inputs, + cpu_inputs + ); } }