Skip to content

Commit

Permalink
Create MLE eval component
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Aug 24, 2024
1 parent af52728 commit c1ac882
Show file tree
Hide file tree
Showing 4 changed files with 512 additions and 90 deletions.
31 changes: 18 additions & 13 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ pub struct TreeColumnSpanProvider {
}

impl TreeColumnSpanProvider {
fn next_for_structure<T>(&mut self, structure: &TreeVec<ColumnVec<T>>) -> Vec<TreeColumnSpan> {
pub fn next_for_structure<T>(
&mut self,
structure: &TreeVec<ColumnVec<T>>,
) -> Vec<TreeColumnSpan> {
structure
.iter()
.enumerate()
Expand Down Expand Up @@ -82,6 +85,10 @@ impl<E: FrameworkEval> FrameworkComponent<E> {
trace_locations,
}
}

pub fn trace_locations(&self) -> &[TreeColumnSpan] {
&self.trace_locations
}
}

impl<E: FrameworkEval> Component for FrameworkComponent<E> {
Expand All @@ -94,26 +101,20 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
}

fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::new(
self.eval
.evaluate(InfoEvaluator::default())
.mask_offsets
.iter()
.map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()])
.collect(),
)
let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default());
mask_offsets.map(|tree_offsets| vec![self.eval.log_size(); tree_offsets.len()])
}

fn mask_points(
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let info = self.eval.evaluate(InfoEvaluator::default());
let trace_step = CanonicCoset::new(self.eval.log_size()).step();
info.mask_offsets.map_cols(|col_mask| {
col_mask
let InfoEvaluator { mask_offsets, .. } = self.eval.evaluate(InfoEvaluator::default());
mask_offsets.map_cols(|col_offsets| {
col_offsets
.iter()
.map(|off| point + trace_step.mul_signed(*off).into_ef())
.map(|offset| point + trace_step.mul_signed(*offset).into_ef())
.collect()
})
}
Expand All @@ -138,6 +139,10 @@ impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
trace: &Trace<'_, SimdBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<SimdBackend>,
) {
if self.n_constraints() == 0 {
return;
}

let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain();
let trace_domain = CanonicCoset::new(self.eval.log_size());

Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/core/air/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::core::utils::generate_secure_powers;
/// Accumulates N evaluations of u_i(P0) at a single point.
/// Computes f(P0), the combined polynomial at that point.
/// For n accumulated evaluations, the i'th evaluation is multiplied by alpha^(N-1-i).
#[derive(Debug, Clone, Copy)]
pub struct PointEvaluationAccumulator {
random_coeff: SecureField,
accumulation: SecureField,
Expand Down
13 changes: 8 additions & 5 deletions crates/prover/src/examples/xor/gkr_lookups/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1;
/// IOP for multilinear eval at point.
pub const MAX_MLE_N_VARIABLES: u32 = M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR;

/// Accumulates [`Mle`]s grouped by their number of variables.
/// Collection of [`Mle`]s grouped by their number of variables.
pub struct MleCollection<B: Backend> {
mles_by_n_variables: Vec<Option<Vec<DynMle<B>>>>,
}

impl<B: Backend> MleCollection<B> {
/// Appends an [`Mle`] to the collection.
/// Appends an [`Mle`] to the back of the collection.
pub fn push(&mut self, mle: impl Into<DynMle<B>>) {
let mle = mle.into();
let mles = self.mles_by_n_variables[mle.n_variables()].get_or_insert(Vec::new());
Expand All @@ -35,6 +35,7 @@ impl<B: Backend> MleCollection<B> {
impl MleCollection<SimdBackend> {
/// Performs a random linear combination of all MLEs, grouped by their number of variables.
///
/// For `n` accumulated MLEs in a group, the `i`th MLE is multiplied by `alpha^(n-1-i)`.
/// MLEs are returned in ascending order by number of variables.
pub fn random_linear_combine_by_n_variables(
self,
Expand All @@ -53,13 +54,15 @@ impl MleCollection<SimdBackend> {
/// Panics if `mles` is empty or all MLEs don't have the same number of variables.
fn mle_random_linear_combination(
mles: Vec<DynMle<SimdBackend>>,
alpha: SecureField,
random_coeff: SecureField,
) -> Mle<SimdBackend, SecureField> {
assert!(!mles.is_empty());
let n_variables = mles[0].n_variables();
assert!(mles.iter().all(|mle| mle.n_variables() == n_variables));
let alpha_powers = generate_secure_powers(alpha, mles.len()).into_iter().rev();
let mut mle_and_coeff = zip(mles, alpha_powers);
let coeff_powers = generate_secure_powers(random_coeff, mles.len())
.into_iter()
.rev();
let mut mle_and_coeff = zip(mles, coeff_powers);

// The last value can initialize the accumulator.
let (mle, coeff) = mle_and_coeff.next_back().unwrap();
Expand Down
Loading

0 comments on commit c1ac882

Please sign in to comment.