Skip to content

Commit

Permalink
seperated trace gen trait (#678)
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware authored Jun 26, 2024
1 parent 1b443ac commit 7d2397f
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 29 deletions.
14 changes: 9 additions & 5 deletions crates/prover/src/trace_generation/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
pub mod registry;

use downcast_rs::{impl_downcast, Downcast};
use registry::ComponentRegistry;
use registry::ComponentGenerationRegistry;

use crate::core::air::Component;
use crate::core::backend::Backend;
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::CircleEvaluation;
Expand All @@ -15,19 +16,22 @@ impl_downcast!(ComponentGen);
// A trait to generate a a trace.
// Generates the trace given a list of inputs collects inputs for subcomponents.
pub trait TraceGenerator<B: Backend> {
type ComponentInputs;
type Component: Component;
type Inputs;

/// Add inputs for the trace generation of the component.
/// This function should be called from the caller components before calling `write_trace` of
/// this component.
fn add_inputs(&mut self, inputs: &Self::ComponentInputs);
fn add_inputs(&mut self, inputs: &Self::Inputs);

/// Allocates and returns the trace of the component and updates the
/// subcomponents with the corresponding inputs.
/// Should be called only after all the inputs are available.
// TODO(ShaharS): change `component_id` to a struct that contains the id and the component name.
fn write_trace(
component_id: &str,
registry: &mut ComponentRegistry,
&self,
registry: &mut ComponentGenerationRegistry,
) -> ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>;

fn component(&self) -> Self::Component;
}
143 changes: 119 additions & 24 deletions crates/prover/src/trace_generation/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,28 @@ use std::collections::HashMap;
use super::ComponentGen;

#[derive(Default)]
pub struct ComponentRegistry {
pub struct ComponentGenerationRegistry {
components: HashMap<String, Box<dyn ComponentGen>>,
}

impl ComponentRegistry {
pub fn register_component(&mut self, component_id: &str, component: impl ComponentGen) {
impl ComponentGenerationRegistry {
pub fn register(&mut self, component_id: &str, component: impl ComponentGen) {
self.components
.insert(component_id.to_string(), Box::new(component));
}

pub fn get_component<T: ComponentGen>(&self, component_id: &str) -> &T {
pub fn get_generator<T: ComponentGen>(&self, component_id: &str) -> &T {
self.components
.get(component_id)
.unwrap_or_else(|| panic!("Component name {} not found.", component_id))
.unwrap_or_else(|| panic!("Component ID: {} not found.", component_id))
.downcast_ref()
.unwrap()
}

pub fn get_component_mut<T: ComponentGen>(&mut self, component_id: &str) -> &mut T {
pub fn get_generator_mut<T: ComponentGen>(&mut self, component_id: &str) -> &mut T {
self.components
.get_mut(component_id)
.unwrap_or_else(|| panic!("Component name {} not found.", component_id))
.unwrap_or_else(|| panic!("Component ID: {} not found.", component_id))
.downcast_mut()
.unwrap()
}
Expand All @@ -33,42 +33,137 @@ impl ComponentRegistry {
#[cfg(test)]
mod tests {
use super::*;
use crate::core::air::accumulation::PointEvaluationAccumulator;
use crate::core::air::Component;
use crate::core::backend::simd::m31::{PackedM31, N_LANES};
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::{BaseField, M31};
use crate::core::fields::qm31::SecureField;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;
use crate::core::{ColumnVec, InteractionElements};
use crate::m31;
use crate::trace_generation::TraceGenerator;

#[derive(Default)]
struct ComponentA {
inputs: Vec<u32>,
pub struct ComponentA {
pub n_instances: usize,
}

impl ComponentGen for ComponentA {}
impl Component for ComponentA {
fn n_constraints(&self) -> usize {
todo!()
}

fn max_constraint_log_degree_bound(&self) -> u32 {
todo!()
}

fn n_interaction_phases(&self) -> u32 {
todo!()
}

fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
todo!()
}

fn mask_points(
&self,
_point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
todo!()
}

impl TraceGenerator<CpuBackend> for ComponentA {
type ComponentInputs = u32;
fn interaction_element_ids(&self) -> Vec<String> {
todo!()
}

fn add_inputs(&mut self, _inputs: &Self::ComponentInputs) {
unimplemented!("TestTraceGenerator::add_inputs")
fn evaluate_constraint_quotients_at_point(
&self,
_point: CirclePoint<SecureField>,
_mask: &ColumnVec<Vec<SecureField>>,
_evaluation_accumulator: &mut PointEvaluationAccumulator,
_interaction_elements: &InteractionElements,
) {
todo!()
}
}

type ComponentACpuInputs = Vec<(M31, M31)>;
struct ComponentACpuTraceGenerator {
inputs: ComponentACpuInputs,
}
impl ComponentGen for ComponentACpuTraceGenerator {}

impl TraceGenerator<CpuBackend> for ComponentACpuTraceGenerator {
type Component = ComponentA;
type Inputs = ComponentACpuInputs;

fn write_trace(
_component_id: &str,
_registry: &mut ComponentRegistry,
&self,
_registry: &mut ComponentGenerationRegistry,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
unimplemented!("TestTraceGenerator::write_trace")
}

fn add_inputs(&mut self, inputs: &ComponentACpuInputs) {
self.inputs.extend(inputs)
}

fn component(&self) -> ComponentA {
ComponentA {
n_instances: self.inputs.len(),
}
}
}

type ComponentASimdInputs = Vec<(PackedM31, PackedM31)>;
struct ComponentASimdTraceGenerator {
inputs: ComponentASimdInputs,
}
impl ComponentGen for ComponentASimdTraceGenerator {}

impl TraceGenerator<SimdBackend> for ComponentASimdTraceGenerator {
type Component = ComponentA;
type Inputs = ComponentASimdInputs;

fn write_trace(
&self,
_registry: &mut ComponentGenerationRegistry,
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
unimplemented!("TestTraceGenerator::write_trace")
}

fn add_inputs(&mut self, inputs: &ComponentASimdInputs) {
self.inputs.extend(inputs)
}

fn component(&self) -> ComponentA {
ComponentA {
n_instances: self.inputs.len() * N_LANES,
}
}
}

#[test]
fn test_component_registry() {
let mut registry = ComponentRegistry::default();
let component = ComponentA { inputs: vec![1] };
let mut registry = ComponentGenerationRegistry::default();
let component_id = "componentA::0";

let component_a_cpu_trace_generator = ComponentACpuTraceGenerator { inputs: vec![] };
registry.register(component_id, component_a_cpu_trace_generator);
let cpu_inputs = vec![(m31!(1), m31!(1)), (m31!(2), m31!(2))];

registry.register_component("test", component);
registry
.get_generator_mut::<ComponentACpuTraceGenerator>(component_id)
.add_inputs(&cpu_inputs);

assert_eq!(registry.get_component::<ComponentA>("test").inputs, vec![1]);
assert_eq!(
registry
.get_generator_mut::<ComponentACpuTraceGenerator>(component_id)
.inputs,
cpu_inputs
);
}
}

0 comments on commit 7d2397f

Please sign in to comment.