Skip to content

Commit

Permalink
Merge trace writer and trace generator.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Jul 8, 2024
1 parent 91a2833 commit 2ce2998
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 65 deletions.
23 changes: 0 additions & 23 deletions crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use self::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use super::backend::Backend;
use super::channel::Blake2sChannel;
use super::circle::CirclePoint;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
Expand All @@ -25,20 +24,6 @@ pub trait Air {
fn components(&self) -> Vec<&dyn Component>;
}

pub trait AirTraceVerifier {
fn interaction_elements(&self, channel: &mut Blake2sChannel) -> InteractionElements;
}

pub trait AirTraceWriter<B: Backend>: AirTraceVerifier {
fn interact(
&self,
trace: &ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> Vec<CircleEvaluation<B, BaseField, BitReversedOrder>>;

fn to_air_prover(&self) -> &impl AirProver<B>;
}

pub trait AirProver<B: Backend>: Air {
fn prover_components(&self) -> Vec<&dyn ComponentProver<B>>;
}
Expand Down Expand Up @@ -78,14 +63,6 @@ pub trait Component {
);
}

pub trait ComponentTraceWriter<B: Backend> {
fn write_interaction_trace(
&self,
trace: &ColumnVec<&CircleEvaluation<B, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>;
}

pub trait ComponentProver<B: Backend>: Component {
/// Evaluates the constraint quotients of the component on the evaluation domain.
/// Accumulates quotients in `evaluation_accumulator`.
Expand Down
35 changes: 26 additions & 9 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use itertools::Itertools;
use thiserror::Error;
use tracing::{span, Level};

use super::air::{AirProver, AirTraceVerifier, AirTraceWriter};
use super::air::AirProver;
use super::backend::Backend;
use super::fields::secure_column::SECURE_EXTENSION_DEGREE;
use super::fri::FriVerificationError;
Expand All @@ -25,6 +25,7 @@ use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::hasher::Hasher;
use crate::core::vcs::ops::MerkleOps;
use crate::core::vcs::verifier::MerkleVerificationError;
use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier};

type Channel = Blake2sChannel;
type ChannelHasher = Blake2sHasher;
Expand Down Expand Up @@ -54,7 +55,7 @@ pub struct AdditionalProofData {
}

pub fn evaluate_and_commit_on_trace<B: Backend + MerkleOps<MerkleHasher>>(
air: &impl AirTraceWriter<B>,
air: &impl AirTraceGenerator<B>,
channel: &mut Channel,
twiddles: &TwiddleTree<B>,
trace: ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
Expand Down Expand Up @@ -155,7 +156,7 @@ pub fn generate_proof<B: Backend + MerkleOps<MerkleHasher>>(
}

pub fn prove<B: Backend + MerkleOps<MerkleHasher>>(
air: &impl AirTraceWriter<B>,
air: &impl AirTraceGenerator<B>,
channel: &mut Channel,
trace: ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
) -> Result<StarkProof, ProvingError> {
Expand Down Expand Up @@ -348,10 +349,7 @@ mod tests {
use num_traits::Zero;

use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{
Air, AirProver, AirTraceVerifier, AirTraceWriter, Component, ComponentProver,
ComponentTrace, ComponentTraceWriter,
};
use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace};
use crate::core::backend::cpu::CpuCircleEvaluation;
use crate::core::backend::CpuBackend;
use crate::core::channel::Blake2sChannel;
Expand All @@ -367,6 +365,8 @@ mod tests {
use crate::core::test_utils::test_channel;
use crate::core::{ColumnVec, InteractionElements, LookupValues};
use crate::qm31;
use crate::trace_generation::registry::ComponentGenerationRegistry;
use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator};

struct TestAir<C: ComponentProver<CpuBackend>> {
component: C,
Expand All @@ -384,7 +384,7 @@ mod tests {
}
}

impl AirTraceWriter<CpuBackend> for TestAir<TestComponent> {
impl AirTraceGenerator<CpuBackend> for TestAir<TestComponent> {
fn interact(
&self,
_trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
Expand All @@ -404,6 +404,7 @@ mod tests {
}
}

#[derive(Clone)]
struct TestComponent {
log_size: u32,
max_constraint_log_degree_bound: u32,
Expand Down Expand Up @@ -449,14 +450,30 @@ mod tests {
}
}

impl ComponentTraceWriter<CpuBackend> for TestComponent {
impl ComponentTraceGenerator<CpuBackend> for TestComponent {
type Component = Self;
type Inputs = ();

fn add_inputs(&mut self, _inputs: &Self::Inputs) {}

fn write_trace(
_component_id: &str,
_registry: &mut ComponentGenerationRegistry,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn write_interaction_trace(
&self,
_trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn component(&self) -> Self::Component {
self.clone()
}
}

impl ComponentProver<CpuBackend> for TestComponent {
Expand Down
9 changes: 4 additions & 5 deletions crates/prover/src/examples/fibonacci/air.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use itertools::{zip_eq, Itertools};

use super::component::FibonacciComponent;
use crate::core::air::{
Air, AirProver, AirTraceVerifier, AirTraceWriter, Component, ComponentProver,
};
use crate::core::air::{Air, AirProver, Component, ComponentProver};
use crate::core::backend::CpuBackend;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
use crate::core::{ColumnVec, InteractionElements};
use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier};

pub struct FibonacciAir {
pub component: FibonacciComponent,
Expand All @@ -33,7 +32,7 @@ impl AirTraceVerifier for FibonacciAir {
}
}

impl AirTraceWriter<CpuBackend> for FibonacciAir {
impl AirTraceGenerator<CpuBackend> for FibonacciAir {
fn interact(
&self,
_trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
Expand Down Expand Up @@ -82,7 +81,7 @@ impl AirTraceVerifier for MultiFibonacciAir {
}
}

impl AirTraceWriter<CpuBackend> for MultiFibonacciAir {
impl AirTraceGenerator<CpuBackend> for MultiFibonacciAir {
fn interact(
&self,
_trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
Expand Down
23 changes: 21 additions & 2 deletions crates/prover/src/examples/fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use num_traits::One;

use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::mask::shifted_mask_points;
use crate::core::air::{Component, ComponentProver, ComponentTrace, ComponentTraceWriter};
use crate::core::air::{Component, ComponentProver, ComponentTrace};
use crate::core::backend::CpuBackend;
use crate::core::circle::{CirclePoint, Coset};
use crate::core::constraints::{coset_vanishing, pair_vanishing};
Expand All @@ -17,7 +17,10 @@ use crate::core::poly::BitReversedOrder;
use crate::core::prover::BASE_TRACE;
use crate::core::utils::bit_reverse_index;
use crate::core::{ColumnVec, InteractionElements, LookupValues};
use crate::trace_generation::registry::ComponentGenerationRegistry;
use crate::trace_generation::ComponentTraceGenerator;

#[derive(Clone)]
pub struct FibonacciComponent {
pub log_size: u32,
pub claim: BaseField,
Expand Down Expand Up @@ -123,14 +126,30 @@ impl Component for FibonacciComponent {
}
}

impl ComponentTraceWriter<CpuBackend> for FibonacciComponent {
impl ComponentTraceGenerator<CpuBackend> for FibonacciComponent {
type Component = Self;
type Inputs = ();

fn add_inputs(&mut self, _inputs: &Self::Inputs) {}

fn write_trace(
_component_id: &str,
_registry: &mut ComponentGenerationRegistry,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn write_interaction_trace(
&self,
_trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn component(&self) -> Self::Component {
self.clone()
}
}

impl ComponentProver<CpuBackend> for FibonacciComponent {
Expand Down
28 changes: 22 additions & 6 deletions crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@ use tracing::{span, Level};

use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::mask::fixed_mask_points;
use crate::core::air::{
Air, AirProver, AirTraceVerifier, AirTraceWriter, Component, ComponentProver, ComponentTrace,
ComponentTraceWriter,
};
use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace};
use crate::core::backend::simd::column::BaseFieldVec;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
Expand All @@ -27,6 +24,7 @@ use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::BitReversedOrder;
use crate::core::{ColumnVec, InteractionElements, LookupValues};
use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator};

const N_LOG_INSTANCES_PER_ROW: usize = 3;
const N_INSTANCES_PER_ROW: usize = 1 << N_LOG_INSTANCES_PER_ROW;
Expand Down Expand Up @@ -73,7 +71,7 @@ impl AirTraceVerifier for PoseidonAir {
}
}

impl AirTraceWriter<SimdBackend> for PoseidonAir {
impl AirTraceGenerator<SimdBackend> for PoseidonAir {
fn interact(
&self,
_trace: &ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
Expand Down Expand Up @@ -358,14 +356,32 @@ pub fn gen_trace(
.collect_vec()
}

impl ComponentTraceWriter<SimdBackend> for PoseidonComponent {
impl ComponentTraceGenerator<SimdBackend> for PoseidonComponent {
type Component = Self;
type Inputs = ();

fn add_inputs(&mut self, _inputs: &Self::Inputs) {
todo!()
}

fn write_trace(
_component_id: &str,
_registry: &mut crate::trace_generation::registry::ComponentGenerationRegistry,
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
todo!()
}

fn write_interaction_trace(
&self,
_trace: &ColumnVec<&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn component(&self) -> Self::Component {
todo!()
}
}

struct PoseidonEvalAtDomain<'a> {
Expand Down
23 changes: 21 additions & 2 deletions crates/prover/src/examples/wide_fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use itertools::Itertools;

use crate::core::air::accumulation::PointEvaluationAccumulator;
use crate::core::air::mask::fixed_mask_points;
use crate::core::air::{Air, Component, ComponentTraceWriter};
use crate::core::air::{Air, Component};
use crate::core::backend::cpu::CpuCircleEvaluation;
use crate::core::backend::CpuBackend;
use crate::core::circle::{CirclePoint, Coset};
Expand All @@ -17,6 +17,8 @@ use crate::core::poly::BitReversedOrder;
use crate::core::utils::shifted_secure_combination;
use crate::core::{ColumnVec, InteractionElements, LookupValues};
use crate::examples::wide_fibonacci::trace_gen::write_lookup_column;
use crate::trace_generation::registry::ComponentGenerationRegistry;
use crate::trace_generation::ComponentTraceGenerator;

pub const LOG_N_COLUMNS: usize = 8;
pub const N_COLUMNS: usize = 1 << LOG_N_COLUMNS;
Expand All @@ -32,6 +34,7 @@ pub const LOOKUP_VALUE_N_MINUS_1_ID: &str = "wide_fibonacci_n-1";
/// 2^`self.log_fibonacci_size`. The numbers are computes over [N_COLUMNS] trace columns. The
/// number of rows (i.e the size of the columns) is determined by the parameters above (see
/// [WideFibComponent::log_column_size()]).
#[derive(Clone)]
pub struct WideFibComponent {
pub log_fibonacci_size: u32,
pub log_n_instances: u32,
Expand Down Expand Up @@ -251,7 +254,19 @@ impl Component for WideFibComponent {
}
}

impl ComponentTraceWriter<CpuBackend> for WideFibComponent {
impl ComponentTraceGenerator<CpuBackend> for WideFibComponent {
type Component = Self;
type Inputs = ();

fn add_inputs(&mut self, _inputs: &Self::Inputs) {}

fn write_trace(
_component_id: &str,
_registry: &mut ComponentGenerationRegistry,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}

fn write_interaction_trace(
&self,
trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
Expand All @@ -271,6 +286,10 @@ impl ComponentTraceWriter<CpuBackend> for WideFibComponent {
})
.collect_vec()
}

fn component(&self) -> Self::Component {
self.clone()
}
}

// Input for the fibonacci claim.
Expand Down
8 changes: 3 additions & 5 deletions crates/prover/src/examples/wide_fibonacci/constraint_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@ use super::component::{
};
use super::trace_gen::write_trace_row;
use crate::core::air::accumulation::{ColumnAccumulator, DomainEvaluationAccumulator};
use crate::core::air::{
AirProver, AirTraceVerifier, AirTraceWriter, Component, ComponentProver, ComponentTrace,
ComponentTraceWriter,
};
use crate::core::air::{AirProver, Component, ComponentProver, ComponentTrace};
use crate::core::backend::CpuBackend;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::circle::Coset;
Expand All @@ -29,6 +26,7 @@ use crate::core::utils::{
};
use crate::core::{ColumnVec, InteractionElements, LookupValues};
use crate::examples::wide_fibonacci::component::LOG_N_COLUMNS;
use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator};

// TODO(AlonH): Rename file to `cpu.rs`.

Expand All @@ -40,7 +38,7 @@ impl AirTraceVerifier for WideFibAir {
}
}

impl AirTraceWriter<CpuBackend> for WideFibAir {
impl AirTraceGenerator<CpuBackend> for WideFibAir {
fn interact(
&self,
trace: &ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
Expand Down
Loading

0 comments on commit 2ce2998

Please sign in to comment.