Skip to content

Commit

Permalink
Create ComponentTraceWriter trait.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed May 20, 2024
1 parent b3619fe commit 68d7b88
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 20 deletions.
21 changes: 20 additions & 1 deletion crates/prover/src/core/air/air_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, SecureCirclePoly};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::LOG_BLOWUP_FACTOR;
use crate::core::{ComponentVec, InteractionElements};
use crate::core::{ColumnVec, ComponentVec, InteractionElements};

pub trait AirExt: Air {
fn composition_log_degree_bound(&self) -> u32 {
Expand Down Expand Up @@ -102,6 +102,24 @@ pub trait AirExt: Air {
impl<A: Air + ?Sized> AirExt for A {}

pub trait AirProverExt<B: Backend>: AirProver<B> {
fn interact(
&self,
trace: &ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<B, BaseField, BitReversedOrder>> {
let trace_iter = &mut trace.iter();
ComponentVec(
self.prover_components()
.iter()
.map(|component| {
let n_columns = component.trace_log_degree_bounds().len();
let trace_columns = trace_iter.take(n_columns).collect_vec();
component.write_interaction_trace(&trace_columns, elements)
})
.collect(),
)
}

fn compute_composition_polynomial(
&self,
random_coeff: SecureField,
Expand All @@ -123,4 +141,5 @@ pub trait AirProverExt<B: Backend>: AirProver<B> {
accumulator.finalize()
}
}

impl<B: Backend, A: AirProver<B>> AirProverExt<B> for A {}
12 changes: 10 additions & 2 deletions crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::poly::circle::{CircleEvaluation, CirclePoly};
use super::poly::BitReversedOrder;
use super::ColumnVec;
use super::{ColumnVec, InteractionElements};

pub mod accumulation;
mod air_ext;
Expand Down Expand Up @@ -53,7 +53,15 @@ pub trait Component {
);
}

pub trait ComponentProver<B: Backend>: Component {
pub trait ComponentTraceWriter<B: Backend> {
fn write_interaction_trace(
&self,
trace: &ColumnVec<&CircleEvaluation<B, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>;
}

pub trait ComponentProver<B: Backend>: Component + ComponentTraceWriter<B> {
/// Evaluates the constraint quotients of the component on the evaluation domain.
/// Accumulates quotients in `evaluation_accumulator`.
fn evaluate_constraint_quotients_on_domain(
Expand Down
10 changes: 9 additions & 1 deletion crates/prover/src/core/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::{Deref, DerefMut};
use std::ops::{Deref, DerefMut, Index};

use self::fields::m31::BaseField;

Expand Down Expand Up @@ -61,3 +61,11 @@ impl<T> DerefMut for ComponentVec<T> {
}

pub struct InteractionElements(Vec<(String, BaseField)>);

impl Index<&str> for InteractionElements {
type Output = BaseField;

fn index(&self, index: &str) -> &Self::Output {
&self.0.iter().find(|(id, _)| id == index).unwrap().1
}
}
20 changes: 18 additions & 2 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,21 @@ mod tests {
use num_traits::Zero;

use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace};
use crate::core::air::{
Air, AirProver, Component, ComponentProver, ComponentTrace, ComponentTraceWriter,
};
use crate::core::backend::cpu::CPUCircleEvaluation;
use crate::core::backend::CPUBackend;
use crate::core::circle::{CirclePoint, CirclePointIndex, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::{CanonicCoset, CircleDomain, MAX_CIRCLE_DOMAIN_LOG_SIZE};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, MAX_CIRCLE_DOMAIN_LOG_SIZE,
};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{prove, ProvingError};
use crate::core::test_utils::test_channel;
use crate::core::{ColumnVec, InteractionElements};
use crate::qm31;

struct TestAir<C: ComponentProver<CPUBackend>> {
Expand Down Expand Up @@ -335,6 +341,16 @@ mod tests {
}
}

impl ComponentTraceWriter<CPUBackend> for TestComponent {
fn write_interaction_trace(
&self,
_trace: &ColumnVec<&CircleEvaluation<CPUBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CPUBackend, BaseField, BitReversedOrder>> {
vec![]
}
}

impl ComponentProver<CPUBackend> for TestComponent {
fn evaluate_constraint_quotients_on_domain(
&self,
Expand Down
17 changes: 14 additions & 3 deletions crates/prover/src/examples/fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ use num_traits::One;

use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::mask::shifted_mask_points;
use crate::core::air::{Component, ComponentProver, ComponentTrace};
use crate::core::air::{Component, ComponentProver, ComponentTrace, ComponentTraceWriter};
use crate::core::backend::CPUBackend;
use crate::core::circle::{CirclePoint, Coset};
use crate::core::constraints::{coset_vanishing, pair_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, FieldExpOps};
use crate::core::poly::circle::CanonicCoset;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse_index;
use crate::core::ColumnVec;
use crate::core::{ColumnVec, InteractionElements};

pub struct FibonacciComponent {
pub log_size: u32,
Expand Down Expand Up @@ -116,6 +117,16 @@ impl Component for FibonacciComponent {
}
}

impl ComponentTraceWriter<CPUBackend> for FibonacciComponent {
fn write_interaction_trace(
&self,
_trace: &ColumnVec<&CircleEvaluation<CPUBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CPUBackend, BaseField, BitReversedOrder>> {
vec![]
}
}

impl ComponentProver<CPUBackend> for FibonacciComponent {
fn evaluate_constraint_quotients_on_domain(
&self,
Expand Down
16 changes: 14 additions & 2 deletions crates/prover/src/examples/wide_fibonacci/avx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use tracing::{span, Level};

use super::component::{WideFibAir, WideFibComponent};
use crate::core::air::accumulation::DomainEvaluationAccumulator;
use crate::core::air::{AirProver, Component, ComponentProver, ComponentTrace};
use crate::core::air::{
AirProver, Component, ComponentProver, ComponentTrace, ComponentTraceWriter,
};
use crate::core::backend::avx512::qm31::PackedSecureField;
use crate::core::backend::avx512::{AVX512Backend, BaseFieldVec, PackedBaseField, VECS_LOG_SIZE};
use crate::core::backend::{Col, Column, ColumnOps};
Expand All @@ -13,7 +15,7 @@ use crate::core::fields::m31::BaseField;
use crate::core::fields::{FieldExpOps, FieldOps};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;
use crate::core::{ColumnVec, InteractionElements};
use crate::examples::wide_fibonacci::component::N_COLUMNS;

impl AirProver<AVX512Backend> for WideFibAir {
Expand Down Expand Up @@ -48,6 +50,16 @@ pub fn gen_trace(
.collect_vec()
}

impl ComponentTraceWriter<AVX512Backend> for WideFibComponent {
fn write_interaction_trace(
&self,
_trace: &ColumnVec<&CircleEvaluation<AVX512Backend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<AVX512Backend, BaseField, BitReversedOrder>> {
vec![]
}
}

impl ComponentProver<AVX512Backend> for WideFibComponent {
fn evaluate_constraint_quotients_on_domain(
&self,
Expand Down
34 changes: 27 additions & 7 deletions crates/prover/src/examples/wide_fibonacci/component.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
use itertools::Itertools;

use crate::core::air::accumulation::PointEvaluationAccumulator;
use crate::core::air::mask::fixed_mask_points;
use crate::core::air::{Air, Component};
use crate::core::air::{Air, Component, ComponentTraceWriter};
use crate::core::backend::CPUBackend;
use crate::core::circle::CirclePoint;
use crate::core::constraints::coset_vanishing;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::CanonicCoset;
use crate::core::ColumnVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::{ColumnVec, InteractionElements};
use crate::examples::wide_fibonacci::trace_gen::write_lookup_column;

pub const LOG_N_COLUMNS: usize = 8;
pub const N_COLUMNS: usize = 1 << LOG_N_COLUMNS;

const ALPHA_ID: &str = "wide_fibonacci_alpha";
const Z_ID: &str = "wide_fibonacci_z";

/// Component that computes 2^`self.log_n_instances` instances of fibonacci sequences of size
/// 2^`self.log_fibonacci_size`. The numbers are computes over [N_COLUMNS] trace columns. The
/// number of rows (i.e the size of the columns) is determined by the parameters above (see
Expand Down Expand Up @@ -68,10 +76,7 @@ impl Component for WideFibComponent {
}

fn interaction_element_ids(&self) -> Vec<String> {
vec![
"wide_fibonacci_alpha".to_string(),
"wide_fibonacci_z".to_string(),
]
vec![ALPHA_ID.to_string(), Z_ID.to_string()]
}

fn evaluate_constraint_quotients_at_point(
Expand All @@ -90,6 +95,21 @@ impl Component for WideFibComponent {
}
}

impl ComponentTraceWriter<CPUBackend> for WideFibComponent {
fn write_interaction_trace(
&self,
trace: &ColumnVec<&CircleEvaluation<CPUBackend, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CPUBackend, BaseField, BitReversedOrder>> {
let domain = trace[0].domain;
let input_trace = trace.iter().map(|eval| &eval.values).collect_vec();
let (alpha, z) = (elements[ALPHA_ID], elements[Z_ID]);
let values = write_lookup_column(&input_trace, alpha, z);
let eval = CircleEvaluation::new(domain, values);
vec![eval]
}
}

// Input for the fibonacci claim.
#[derive(Debug, Clone, Copy)]
pub struct Input {
Expand Down
3 changes: 2 additions & 1 deletion crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ mod tests {
let alpha = m31!(7);
let z = m31!(11);
let trace = gen_trace(&wide_fib, vec![input]);
let lookup_column = write_lookup_column(&trace, alpha, z);
let input_trace = trace.iter().collect_vec();
let lookup_column = write_lookup_column(&input_trace, alpha, z);

assert_constraints_on_lookup_column(&lookup_column, &trace, alpha, z)
}
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/wide_fibonacci/trace_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub fn write_trace_row(
}

pub fn write_lookup_column(
input_trace: &[Vec<BaseField>],
input_trace: &[&Vec<BaseField>],
// TODO(AlonH): Change alpha and z to SecureField.
alpha: BaseField,
z: BaseField,
Expand Down

0 comments on commit 68d7b88

Please sign in to comment.