Skip to content

Commit

Permalink
Create ComponentTraceWriter trait.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed May 26, 2024
1 parent ab4f4c1 commit e7f0a82
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 23 deletions.
24 changes: 22 additions & 2 deletions crates/prover/src/core/air/air_ext.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::iter::zip;

use itertools::{zip_eq, Itertools};
Expand All @@ -12,7 +13,7 @@ use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, SecureCirclePoly};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::LOG_BLOWUP_FACTOR;
use crate::core::{ComponentVec, InteractionElements};
use crate::core::{ColumnVec, ComponentVec, InteractionElements};

pub trait AirExt: Air {
fn composition_log_degree_bound(&self) -> u32 {
Expand Down Expand Up @@ -51,7 +52,7 @@ pub trait AirExt: Air {
.dedup()
.collect_vec();
let elements = channel.draw_felts(ids.len()).into_iter().map(|e| e.0 .0);
InteractionElements(zip_eq(ids, elements).collect_vec())
InteractionElements(BTreeMap::from_iter(zip_eq(ids, elements)))
}

fn eval_composition_polynomial_at_point(
Expand Down Expand Up @@ -99,6 +100,24 @@ pub trait AirExt: Air {
impl<A: Air + ?Sized> AirExt for A {}

pub trait AirProverExt<B: Backend>: AirProver<B> {
fn interact(
&self,
trace: &ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ComponentVec<CircleEvaluation<B, BaseField, BitReversedOrder>> {
let trace_iter = &mut trace.iter();
ComponentVec(
self.prover_components()
.iter()
.map(|component| {
let n_columns = component.trace_log_degree_bounds().len();
let trace_columns = trace_iter.take(n_columns).collect_vec();
component.write_interaction_trace(&trace_columns, elements)
})
.collect(),
)
}

fn compute_composition_polynomial(
&self,
random_coeff: SecureField,
Expand All @@ -120,4 +139,5 @@ pub trait AirProverExt<B: Backend>: AirProver<B> {
accumulator.finalize()
}
}

impl<B: Backend, A: AirProver<B>> AirProverExt<B> for A {}
12 changes: 10 additions & 2 deletions crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::poly::circle::{CircleEvaluation, CirclePoly};
use super::poly::BitReversedOrder;
use super::ColumnVec;
use super::{ColumnVec, InteractionElements};

pub mod accumulation;
mod air_ext;
Expand Down Expand Up @@ -53,7 +53,15 @@ pub trait Component {
);
}

pub trait ComponentProver<B: Backend>: Component {
pub trait ComponentTraceWriter<B: Backend> {
fn write_interaction_trace(
&self,
trace: &ColumnVec<&CircleEvaluation<B, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<B, BaseField, BitReversedOrder>>;
}

pub trait ComponentProver<B: Backend>: Component + ComponentTraceWriter<B> {
/// Evaluates the constraint quotients of the component on the evaluation domain.
/// Accumulates quotients in `evaluation_accumulator`.
fn evaluate_constraint_quotients_on_domain(
Expand Down
14 changes: 12 additions & 2 deletions crates/prover/src/core/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::ops::{Deref, DerefMut};
use std::collections::BTreeMap;
use std::ops::{Deref, DerefMut, Index};

use self::fields::m31::BaseField;

Expand Down Expand Up @@ -60,4 +61,13 @@ impl<T> DerefMut for ComponentVec<T> {
}
}

pub struct InteractionElements(Vec<(String, BaseField)>);
pub struct InteractionElements(BTreeMap<String, BaseField>);

impl Index<&str> for InteractionElements {
type Output = BaseField;

fn index(&self, index: &str) -> &Self::Output {
// TODO(AlonH): Return an error if the key is not found.
&self.0[index]
}
}
20 changes: 18 additions & 2 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,21 @@ mod tests {
use num_traits::Zero;

use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace};
use crate::core::air::{
Air, AirProver, Component, ComponentProver, ComponentTrace, ComponentTraceWriter,
};
use crate::core::backend::cpu::CpuCircleEvaluation;
use crate::core::backend::CpuBackend;
use crate::core::circle::{CirclePoint, CirclePointIndex, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::{CanonicCoset, CircleDomain, MAX_CIRCLE_DOMAIN_LOG_SIZE};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, MAX_CIRCLE_DOMAIN_LOG_SIZE,
};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{prove, ProvingError};
use crate::core::test_utils::test_channel;
use crate::core::{ColumnVec, InteractionElements};
use crate::qm31;

struct TestAir<C: ComponentProver<CpuBackend>> {
Expand Down Expand Up @@ -335,6 +341,16 @@ mod tests {
}
}

impl ComponentTraceWriter<CpuBackend> for TestComponent {
fn write_interaction_trace(
&self,
_trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}
}

impl ComponentProver<CpuBackend> for TestComponent {
fn evaluate_constraint_quotients_on_domain(
&self,
Expand Down
17 changes: 14 additions & 3 deletions crates/prover/src/examples/fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ use num_traits::One;

use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::mask::shifted_mask_points;
use crate::core::air::{Component, ComponentProver, ComponentTrace};
use crate::core::air::{Component, ComponentProver, ComponentTrace, ComponentTraceWriter};
use crate::core::backend::CpuBackend;
use crate::core::circle::{CirclePoint, Coset};
use crate::core::constraints::{coset_vanishing, pair_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, FieldExpOps};
use crate::core::poly::circle::CanonicCoset;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse_index;
use crate::core::ColumnVec;
use crate::core::{ColumnVec, InteractionElements};

pub struct FibonacciComponent {
pub log_size: u32,
Expand Down Expand Up @@ -116,6 +117,16 @@ impl Component for FibonacciComponent {
}
}

impl ComponentTraceWriter<CpuBackend> for FibonacciComponent {
fn write_interaction_trace(
&self,
_trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
vec![]
}
}

impl ComponentProver<CpuBackend> for FibonacciComponent {
fn evaluate_constraint_quotients_on_domain(
&self,
Expand Down
34 changes: 27 additions & 7 deletions crates/prover/src/examples/wide_fibonacci/component.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
use itertools::Itertools;

use crate::core::air::accumulation::PointEvaluationAccumulator;
use crate::core::air::mask::fixed_mask_points;
use crate::core::air::{Air, Component};
use crate::core::air::{Air, Component, ComponentTraceWriter};
use crate::core::backend::CpuBackend;
use crate::core::circle::CirclePoint;
use crate::core::constraints::coset_vanishing;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::CanonicCoset;
use crate::core::ColumnVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::{ColumnVec, InteractionElements};
use crate::examples::wide_fibonacci::trace_gen::write_lookup_column;

pub const LOG_N_COLUMNS: usize = 8;
pub const N_COLUMNS: usize = 1 << LOG_N_COLUMNS;

const ALPHA_ID: &str = "wide_fibonacci_alpha";
const Z_ID: &str = "wide_fibonacci_z";

/// Component that computes 2^`self.log_n_instances` instances of fibonacci sequences of size
/// 2^`self.log_fibonacci_size`. The numbers are computes over [N_COLUMNS] trace columns. The
/// number of rows (i.e the size of the columns) is determined by the parameters above (see
Expand Down Expand Up @@ -68,10 +76,7 @@ impl Component for WideFibComponent {
}

fn interaction_element_ids(&self) -> Vec<String> {
vec![
"wide_fibonacci_alpha".to_string(),
"wide_fibonacci_z".to_string(),
]
vec![ALPHA_ID.to_string(), Z_ID.to_string()]
}

fn evaluate_constraint_quotients_at_point(
Expand All @@ -90,6 +95,21 @@ impl Component for WideFibComponent {
}
}

impl ComponentTraceWriter<CpuBackend> for WideFibComponent {
fn write_interaction_trace(
&self,
trace: &ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>,
elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>> {
let interaction_trace_domain = trace[0].domain;
let trace_values = trace.iter().map(|eval| &eval.values[..]).collect_vec();
let (alpha, z) = (elements[ALPHA_ID], elements[Z_ID]);
let values = write_lookup_column(&trace_values, alpha, z);
let eval = CircleEvaluation::new(interaction_trace_domain, values);
vec![eval]
}
}

// Input for the fibonacci claim.
#[derive(Debug, Clone, Copy)]
pub struct Input {
Expand Down
3 changes: 2 additions & 1 deletion crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ mod tests {
let alpha = m31!(7);
let z = m31!(11);
let trace = gen_trace(&wide_fib, vec![input]);
let lookup_column = write_lookup_column(&trace, alpha, z);
let input_trace = trace.iter().map(|values| &values[..]).collect_vec();
let lookup_column = write_lookup_column(&input_trace, alpha, z);

assert_constraints_on_lookup_column(&lookup_column, &trace, alpha, z)
}
Expand Down
16 changes: 14 additions & 2 deletions crates/prover/src/examples/wide_fibonacci/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use tracing::{span, Level};

use super::component::{WideFibAir, WideFibComponent};
use crate::core::air::accumulation::DomainEvaluationAccumulator;
use crate::core::air::{AirProver, Component, ComponentProver, ComponentTrace};
use crate::core::air::{
AirProver, Component, ComponentProver, ComponentTrace, ComponentTraceWriter,
};
use crate::core::backend::simd::column::BaseFieldVec;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
Expand All @@ -15,7 +17,7 @@ use crate::core::fields::m31::BaseField;
use crate::core::fields::{FieldExpOps, FieldOps};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;
use crate::core::{ColumnVec, InteractionElements};
use crate::examples::wide_fibonacci::component::N_COLUMNS;

impl AirProver<SimdBackend> for WideFibAir {
Expand Down Expand Up @@ -50,6 +52,16 @@ pub fn gen_trace(
.collect_vec()
}

impl ComponentTraceWriter<SimdBackend> for WideFibComponent {
fn write_interaction_trace(
&self,
_trace: &ColumnVec<&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
_elements: &InteractionElements,
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
vec![]
}
}

impl ComponentProver<SimdBackend> for WideFibComponent {
fn evaluate_constraint_quotients_on_domain(
&self,
Expand Down
7 changes: 5 additions & 2 deletions crates/prover/src/examples/wide_fibonacci/trace_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::core::fields::m31::BaseField;
use crate::core::fields::FieldExpOps;
use crate::core::utils::shifted_secure_combination;

/// Given a private input, write the trace row for the wide Fibonacci example to dst. Returns the
/// Writes the trace row for the wide Fibonacci example to dst, given a private input. Returns the
/// last two elements of the row in case the sequence is continued.
pub fn write_trace_row(
dst: &mut [Vec<BaseField>],
Expand All @@ -22,8 +22,11 @@ pub fn write_trace_row(
(dst[n_columns - 2][row_index], dst[n_columns - 1][row_index])
}

/// Writes and returns the lookup column for the wide Fibonacci example, which is the partial
/// product of the shifted secure combination of the first two elements in each row divided by the
/// the shifted secure combination of the last two elements in each row.
pub fn write_lookup_column(
input_trace: &[Vec<BaseField>],
input_trace: &[&[BaseField]],
// TODO(AlonH): Change alpha and z to SecureField.
alpha: BaseField,
z: BaseField,
Expand Down

0 comments on commit e7f0a82

Please sign in to comment.