Skip to content

Commit

Permalink
Support shared preproccessed columns.
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyalesokhin-starkware committed Nov 10, 2024
1 parent 8385e54 commit b36d75f
Show file tree
Hide file tree
Showing 11 changed files with 224 additions and 59 deletions.
69 changes: 63 additions & 6 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};
use std::iter::zip;
use std::ops::Deref;
Expand All @@ -9,7 +10,10 @@ use rayon::prelude::*;
use tracing::{span, Level};

use super::cpu_domain::CpuDomainEvaluator;
use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use super::preprocessed_columns::PreprocessedColumn;
use super::{
EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX,
};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, Trace};
use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords;
Expand All @@ -35,6 +39,7 @@ const CHUNK_SIZE: usize = 1;
pub struct TraceLocationAllocator {
/// Mapping of tree index to next available column offset.
next_tree_offsets: TreeVec<usize>,
preprocessed_columns: HashMap<PreprocessedColumn, usize>,
}

impl TraceLocationAllocator {
Expand Down Expand Up @@ -62,6 +67,17 @@ impl TraceLocationAllocator {
.collect(),
)
}

pub fn new_with_preproccessed_columnds(preprocessed_columns: &[PreprocessedColumn]) -> Self {
Self {
next_tree_offsets: Default::default(),
preprocessed_columns: preprocessed_columns
.iter()
.enumerate()
.map(|(i, &col)| (col, i))
.collect(),
}
}
}

/// A component defined solely in means of the constraints framework.
Expand All @@ -80,16 +96,24 @@ pub struct FrameworkComponent<C: FrameworkEval> {
eval: C,
trace_locations: TreeVec<TreeSubspan>,
info: InfoEvaluator,
preprocessed_column_indices: Vec<usize>,
}

impl<E: FrameworkEval> FrameworkComponent<E> {
pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self {
let info = eval.evaluate(InfoEvaluator::default());
let trace_locations = location_allocator.next_for_structure(&info.mask_offsets);

let preprocessed_column_indices = info
.preprocessed_columns
.iter()
.map(|col| location_allocator.preprocessed_columns[col])
.collect();
Self {
eval,
trace_locations,
info,
preprocessed_column_indices,
}
}

Expand All @@ -108,10 +132,19 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
}

fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
self.info
let mut log_degree_bounds = self
.info
.mask_offsets
.as_ref()
.map(|tree_offsets| vec![self.eval.log_size(); tree_offsets.len()])
.map(|tree_offsets| vec![self.eval.log_size(); tree_offsets.len()]);

log_degree_bounds[0] = self
.preprocessed_column_indices
.iter()
.map(|_| self.eval.log_size())
.collect();

log_degree_bounds
}

fn mask_points(
Expand All @@ -127,14 +160,27 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
})
}

fn preproccessed_column_indices(&self) -> ColumnVec<usize> {
self.preprocessed_column_indices.clone()
}

fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
) {
let preprocessed_mask = self
.preprocessed_column_indices
.iter()
.map(|idx| &mask[PREPROCESSED_TRACE_IDX][*idx])
.collect_vec();

let mut mask_points = mask.sub_tree(&self.trace_locations);
mask_points[PREPROCESSED_TRACE_IDX] = preprocessed_mask;

self.eval.evaluate(PointEvaluator::new(
mask.sub_tree(&self.trace_locations),
mask_points,
evaluation_accumulator,
coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(),
));
Expand All @@ -154,8 +200,19 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain();
let trace_domain = CanonicCoset::new(self.eval.log_size());

let component_polys = trace.polys.sub_tree(&self.trace_locations);
let component_evals = trace.evals.sub_tree(&self.trace_locations);
let mut component_polys = trace.polys.sub_tree(&self.trace_locations);
component_polys[PREPROCESSED_TRACE_IDX] = self
.preprocessed_column_indices
.iter()
.map(|idx| &trace.polys[PREPROCESSED_TRACE_IDX][*idx])
.collect();

let mut component_evals = trace.evals.sub_tree(&self.trace_locations);
component_evals[PREPROCESSED_TRACE_IDX] = self
.preprocessed_column_indices
.iter()
.map(|idx| &trace.evals[PREPROCESSED_TRACE_IDX][*idx])
.collect();

// Extend trace if necessary.
// TODO: Don't extend when eval_size < committed_size. Instead, pick a good
Expand Down
15 changes: 15 additions & 0 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::ops::Mul;

use num_traits::One;

use super::preprocessed_columns::PreprocessedColumn;
use super::EvalAtRow;
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::pcs::TreeVec;
Expand All @@ -13,6 +15,7 @@ use crate::core::pcs::TreeVec;
pub struct InfoEvaluator {
pub mask_offsets: TreeVec<Vec<Vec<isize>>>,
pub n_constraints: usize,
pub preprocessed_columns: Vec<PreprocessedColumn>,
}
impl InfoEvaluator {
pub fn new() -> Self {
Expand All @@ -22,11 +25,17 @@ impl InfoEvaluator {
impl EvalAtRow for InfoEvaluator {
type F = BaseField;
type EF = SecureField;

fn next_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
offsets: [isize; N],
) -> [Self::F; N] {
assert!(
interaction != PREPROCESSED_TRACE_IDX,
"Preprocessed should be accesses with `get_preprocessed_column`",
);

// Check if requested a mask from a new interaction
if self.mask_offsets.len() <= interaction {
// Extend `mask_offsets` so that `interaction` is the last index.
Expand All @@ -35,6 +44,12 @@ impl EvalAtRow for InfoEvaluator {
self.mask_offsets[interaction].push(offsets.into_iter().collect());
[BaseField::one(); N]
}

fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F {
self.preprocessed_columns.push(column);
BaseField::one()
}

fn add_constraint<G>(&mut self, _constraint: G)
where
Self::EF: Mul<G, Output = Self::EF>,
Expand Down
56 changes: 49 additions & 7 deletions crates/prover/src/core/air/components.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use itertools::Itertools;

use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use super::{Component, ComponentProver, Trace};
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
use crate::core::backend::Backend;
use crate::core::circle::CirclePoint;
use crate::core::fields::qm31::SecureField;
Expand All @@ -27,11 +28,22 @@ impl<'a> Components<'a> {
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
TreeVec::concat_cols(
let mut mask_points = TreeVec::concat_cols(
self.components
.iter()
.map(|component| component.mask_points(point)),
)
);

let preprocessed_mask_points = &mut mask_points[PREPROCESSED_TRACE_IDX];
*preprocessed_mask_points = vec![vec![]; self.n_preprocessed_columns];

for component in &self.components {
for idx in component.preproccessed_column_indices() {
preprocessed_mask_points[idx].resize(1, point);
}
}

mask_points
}

pub fn eval_composition_polynomial_at_point(
Expand All @@ -52,11 +64,41 @@ impl<'a> Components<'a> {
}

pub fn column_log_sizes(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::concat_cols(
self.components
.iter()
.map(|component| component.trace_log_degree_bounds()),
)
let mut preprocessed_columns = vec![0; self.n_preprocessed_columns];
let mut updated_columns = vec![false; self.n_preprocessed_columns];

let mut column_log_sizes = TreeVec::concat_cols(self.components.iter().map(|component| {
let mut component_trace_log_sizes = component.trace_log_degree_bounds();

for (offset, size) in component
.preproccessed_column_indices()
.into_iter()
.zip(std::mem::take(&mut component_trace_log_sizes[0]))
{
let column_size = &mut preprocessed_columns[offset];
if updated_columns[offset] {
assert!(
*column_size == size,
"Preprocessed column size mismatch for column {}",
offset
);
} else {
*column_size = size;
updated_columns[offset] = true;
}
}

component_trace_log_sizes
}));

assert!(
updated_columns.iter().all(|&updated| updated),
"Column size not set for all reprocessed columns"
);

column_log_sizes[PREPROCESSED_TRACE_IDX] = preprocessed_columns;

column_log_sizes
}
}

Expand Down
2 changes: 2 additions & 0 deletions crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ pub trait Component {
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>;

fn preproccessed_column_indices(&self) -> ColumnVec<usize>;

/// Evaluates the constraint quotients combination of the component at a point.
fn evaluate_constraint_quotients_at_point(
&self,
Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/core/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ pub fn verify<MC: MerkleChannel>(

// Get mask sample points relative to oods point.
let mut sample_points = components.mask_points(oods_point);

// Add the composition polynomial mask points.
sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]);

Expand Down
65 changes: 56 additions & 9 deletions crates/prover/src/examples/blake/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use tracing::{span, Level};

use super::round::{blake_round_info, BlakeRoundComponent, BlakeRoundEval};
use super::scheduler::{BlakeSchedulerComponent, BlakeSchedulerEval};
use super::xor_table::{XorTableComponent, XorTableEval};
use crate::constraint_framework::preprocessed_columns::gen_is_first;
use crate::constraint_framework::TraceLocationAllocator;
use super::xor_table::{column_bits, XorTableComponent, XorTableEval};
use crate::constraint_framework::preprocessed_columns::{gen_is_first, PreprocessedColumn};
use crate::constraint_framework::{TraceLocationAllocator, PREPROCESSED_TRACE_IDX};
use crate::core::air::{Component, ComponentProver};
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::SimdBackend;
Expand All @@ -26,6 +26,29 @@ use crate::examples::blake::{
round, xor_table, BlakeXorElements, XorAccums, N_ROUNDS, ROUND_LOG_SPLIT,
};

const PREPROCESSED_XOR_COLUMNS: [PreprocessedColumn; 20] = [
PreprocessedColumn::IsFirst(column_bits::<12, 4>()),
PreprocessedColumn::XorTable(12, 4, 0),
PreprocessedColumn::XorTable(12, 4, 1),
PreprocessedColumn::XorTable(12, 4, 2),
PreprocessedColumn::IsFirst(column_bits::<9, 2>()),
PreprocessedColumn::XorTable(9, 2, 0),
PreprocessedColumn::XorTable(9, 2, 1),
PreprocessedColumn::XorTable(9, 2, 2),
PreprocessedColumn::IsFirst(column_bits::<8, 2>()),
PreprocessedColumn::XorTable(8, 2, 0),
PreprocessedColumn::XorTable(8, 2, 1),
PreprocessedColumn::XorTable(8, 2, 2),
PreprocessedColumn::IsFirst(column_bits::<7, 2>()),
PreprocessedColumn::XorTable(7, 2, 0),
PreprocessedColumn::XorTable(7, 2, 1),
PreprocessedColumn::XorTable(7, 2, 2),
PreprocessedColumn::IsFirst(column_bits::<4, 0>()),
PreprocessedColumn::XorTable(4, 0, 0),
PreprocessedColumn::XorTable(4, 0, 1),
PreprocessedColumn::XorTable(4, 0, 2),
];

#[derive(Serialize)]
pub struct BlakeStatement0 {
log_size: u32,
Expand Down Expand Up @@ -53,7 +76,23 @@ impl BlakeStatement0 {
sizes.push(xor_table::trace_sizes::<7, 2>());
sizes.push(xor_table::trace_sizes::<4, 0>());

TreeVec::concat_cols(sizes.into_iter())
let mut log_sizes = TreeVec::concat_cols(sizes.into_iter());

let log_size = self.log_size;

log_sizes[PREPROCESSED_TRACE_IDX] = chain!(
[log_size],
ROUND_LOG_SPLIT.iter().map(|l| log_size + l),
PREPROCESSED_XOR_COLUMNS.map(|column| match column {
PreprocessedColumn::XorTable(elem_bits, expand_bits, _) =>
2 * (elem_bits - expand_bits),
PreprocessedColumn::IsFirst(log_size) => log_size,
_ => panic!("Unexpected column"),
}),
)
.collect_vec();

log_sizes
}
fn mix_into(&self, channel: &mut impl Channel) {
channel.mix_u64(self.log_size as u64);
Expand Down Expand Up @@ -120,7 +159,18 @@ pub struct BlakeComponents {
}
impl BlakeComponents {
fn new(stmt0: &BlakeStatement0, all_elements: &AllElements, stmt1: &BlakeStatement1) -> Self {
let tree_span_provider = &mut TraceLocationAllocator::default();
let log_size = stmt0.log_size;
let tree_span_provider = &mut TraceLocationAllocator::new_with_preproccessed_columnds(
&chain!(
[PreprocessedColumn::IsFirst(log_size)],
ROUND_LOG_SPLIT
.iter()
.map(|l| PreprocessedColumn::IsFirst(log_size + l)),
PREPROCESSED_XOR_COLUMNS,
)
.collect_vec()[..],
);

Self {
scheduler_component: BlakeSchedulerComponent::new(
tree_span_provider,
Expand Down Expand Up @@ -259,10 +309,7 @@ where
tree_builder.extend_evals(
chain![
vec![gen_is_first(log_size)],
ROUND_LOG_SPLIT
.iter()
.map(|l| gen_is_first(log_size + l))
.collect_vec(),
ROUND_LOG_SPLIT.iter().map(|l| gen_is_first(log_size + l)),
xor_table::generate_constant_trace::<12, 4>(),
xor_table::generate_constant_trace::<9, 2>(),
xor_table::generate_constant_trace::<8, 2>(),
Expand Down
Loading

0 comments on commit b36d75f

Please sign in to comment.