Skip to content

Commit

Permalink
Allow preprocessed column definitions out of crate
Browse files Browse the repository at this point in the history
  • Loading branch information
yoichi-nexus committed Dec 19, 2024
1 parent f24cde6 commit cec714b
Show file tree
Hide file tree
Showing 14 changed files with 333 additions and 131 deletions.
2 changes: 2 additions & 0 deletions crates/prover/src/constraint_framework/assert.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use num_traits::Zero;

use super::logup::{LogupAtRow, LogupSums};
Expand Down
27 changes: 19 additions & 8 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};
use std::iter::zip;
use std::ops::Deref;
use std::sync::Arc;

use itertools::Itertools;
#[cfg(feature = "parallel")]
Expand All @@ -11,7 +12,7 @@ use tracing::{span, Level};

use super::cpu_domain::CpuDomainEvaluator;
use super::logup::LogupSums;
use super::preprocessed_columns::PreprocessedColumn;
use super::preprocessed_columns::PreprocessedColumnOps;
use super::{
EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX,
};
Expand Down Expand Up @@ -49,7 +50,8 @@ pub struct TraceLocationAllocator {
/// Mapping of tree index to next available column offset.
next_tree_offsets: TreeVec<usize>,
/// Mapping of preprocessed columns to their index.
preprocessed_columns: HashMap<PreprocessedColumn, usize>,
/// A preprocessed column implementation is indicated by its TypeId
preprocessed_columns: HashMap<Arc<dyn PreprocessedColumnOps>, usize>,
/// Controls whether the preprocessed columns are dynamic or static (default=Dynamic).
preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode,
}
Expand Down Expand Up @@ -81,30 +83,39 @@ impl TraceLocationAllocator {
}

/// Create a new `TraceLocationAllocator` with fixed preprocessed columns setup.
pub fn new_with_preproccessed_columns(preprocessed_columns: &[PreprocessedColumn]) -> Self {
pub fn new_with_preproccessed_columns(
preprocessed_columns: &[Arc<dyn PreprocessedColumnOps>],
) -> Self {
Self {
next_tree_offsets: Default::default(),
preprocessed_columns: preprocessed_columns
.iter()
.enumerate()
.map(|(i, &col)| (col, i))
.map(|(i, col)| (col.clone(), i))
.collect(),
preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode::Static,
}
}

pub const fn preprocessed_columns(&self) -> &HashMap<PreprocessedColumn, usize> {
pub const fn preprocessed_columns(&self) -> &HashMap<Arc<dyn PreprocessedColumnOps>, usize> {
&self.preprocessed_columns
}

// validates that `self.preprocessed_columns` is consistent with
// `preprocessed_columns`.
// I.e. preprocessed_columns[i] == self.preprocessed_columns[i].
pub fn validate_preprocessed_columns(&self, preprocessed_columns: &[PreprocessedColumn]) {
// The equality comparison uses the pointer comparison of the boxes.
pub fn validate_preprocessed_columns(
&self,
preprocessed_columns: &[Arc<dyn PreprocessedColumnOps>],
) {
assert_eq!(preprocessed_columns.len(), self.preprocessed_columns.len());

for (column, idx) in self.preprocessed_columns.iter() {
assert_eq!(Some(column), preprocessed_columns.get(*idx));
assert!(match preprocessed_columns.get(*idx) {
Some(preprocessed_column) => preprocessed_column == column,
None => false,
},)
}
}
}
Expand Down Expand Up @@ -146,7 +157,7 @@ impl<E: FrameworkEval> FrameworkComponent<E> {
let next_column = location_allocator.preprocessed_columns.len();
*location_allocator
.preprocessed_columns
.entry(*col)
.entry(col.clone())
.or_insert_with(|| {
if matches!(
location_allocator.preprocessed_columns_allocation_mode,
Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/constraint_framework/cpu_domain.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::ops::Mul;
use std::sync::Arc;

use num_traits::Zero;

Expand Down
6 changes: 4 additions & 2 deletions crates/prover/src/constraint_framework/expr/evaluator.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::sync::Arc;

use num_traits::Zero;

use super::{BaseExpr, ExtExpr};
use crate::constraint_framework::expr::ColumnExpr;
use crate::constraint_framework::preprocessed_columns::PreprocessedColumn;
use crate::constraint_framework::preprocessed_columns::PreprocessedColumnOps;
use crate::constraint_framework::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use crate::core::fields::m31;
use crate::core::lookups::utils::Fraction;
Expand Down Expand Up @@ -174,7 +176,7 @@ impl EvalAtRow for ExprEvaluator {
intermediate
}

fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F {
fn get_preprocessed_column(&mut self, column: Arc<dyn PreprocessedColumnOps>) -> Self::F {
BaseExpr::Param(column.name().to_string())
}

Expand Down
9 changes: 5 additions & 4 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use std::array;
use std::cell::{RefCell, RefMut};
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};
use std::rc::Rc;
use std::sync::Arc;

use num_traits::{One, Zero};

use super::logup::{LogupAtRow, LogupSums};
use super::preprocessed_columns::PreprocessedColumn;
use super::preprocessed_columns::PreprocessedColumnOps;
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
use crate::core::fields::m31::BaseField;
Expand All @@ -22,14 +23,14 @@ use crate::core::pcs::TreeVec;
pub struct InfoEvaluator {
pub mask_offsets: TreeVec<Vec<Vec<isize>>>,
pub n_constraints: usize,
pub preprocessed_columns: Vec<PreprocessedColumn>,
pub preprocessed_columns: Vec<Arc<dyn PreprocessedColumnOps>>,
pub logup: LogupAtRow<Self>,
pub arithmetic_counts: ArithmeticCounts,
}
impl InfoEvaluator {
pub fn new(
log_size: u32,
preprocessed_columns: Vec<PreprocessedColumn>,
preprocessed_columns: Vec<Arc<dyn PreprocessedColumnOps>>,
logup_sums: LogupSums,
) -> Self {
Self {
Expand Down Expand Up @@ -70,7 +71,7 @@ impl EvalAtRow for InfoEvaluator {
array::from_fn(|_| FieldCounter::one())
}

fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F {
fn get_preprocessed_column(&mut self, column: Arc<dyn PreprocessedColumnOps>) -> Self::F {
self.preprocessed_columns.push(column);
FieldCounter::one()
}
Expand Down
18 changes: 8 additions & 10 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ mod simd_domain;
use std::array;
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Neg, Sub};
use std::sync::Arc;

pub use assert::{assert_constraints, AssertEvaluator};
pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator};
pub use info::InfoEvaluator;
use num_traits::{One, Zero};
pub use point::PointEvaluator;
use preprocessed_columns::PreprocessedColumn;
use preprocessed_columns::PreprocessedColumnOps;
pub use simd_domain::SimdDomainEvaluator;

use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -87,7 +88,7 @@ pub trait EvalAtRow {
mask_item
}

fn get_preprocessed_column(&mut self, _column: PreprocessedColumn) -> Self::F {
fn get_preprocessed_column(&mut self, _column: Arc<dyn PreprocessedColumnOps>) -> Self::F {
let [mask_item] = self.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
mask_item
}
Expand Down Expand Up @@ -165,18 +166,15 @@ pub trait EvalAtRow {
}
}

/// Default implementation for evaluators that have an element called "logup" that works like a
/// LogupAtRow, where the logup functionality can be proxied.
/// TODO(alont): Remove once LogupAtRow is no longer used.
macro_rules! logup_proxy {
() => {
fn write_logup_frac(&mut self, fraction: Fraction<Self::EF, Self::EF>) {
if self.logup.fracs.is_empty() {
self.logup.is_first = self.get_preprocessed_column(
crate::constraint_framework::preprocessed_columns::PreprocessedColumn::IsFirst(
self.logup.log_size,
),
);
self.logup.is_first = self.get_preprocessed_column(Arc::new(
crate::constraint_framework::preprocessed_columns::IsFirst {
log_size: self.logup.log_size,
},
));
self.logup.is_finalized = false;
}
self.logup.fracs.push(fraction.clone());
Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/constraint_framework/point.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::ops::Mul;
use std::sync::Arc;

use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
Expand Down
154 changes: 126 additions & 28 deletions crates/prover/src/constraint_framework/preprocessed_columns.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,133 @@
use std::any::Any;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;

use num_traits::One;

use crate::core::backend::{Backend, Col, Column};
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Backend, Col, Column, CpuBackend};
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};

// TODO(ilya): Where should this enum be placed?
/// XorTable, etc will be implementation of this trait.
pub trait PreprocessedColumnOps: Debug + Any {
fn get_type_id(&self) -> std::any::TypeId {
self.type_id()
}
fn name(&self) -> &'static str;
fn log_size(&self) -> u32;
fn gen_preprocessed_column_cpu(
&self,
) -> CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>;
fn gen_preprocessed_column_simd(
&self,
) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>;
fn as_bytes(&self) -> Vec<u8>;
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IsFirst {
pub log_size: u32,
}

impl PreprocessedColumnOps for IsFirst {
fn name(&self) -> &'static str {
"preprocessed.is_first"
}
fn log_size(&self) -> u32 {
self.log_size
}
fn gen_preprocessed_column_cpu(
&self,
) -> CircleEvaluation<CpuBackend, BaseField, BitReversedOrder> {
gen_is_first(self.log_size)
}
fn gen_preprocessed_column_simd(
&self,
) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
gen_is_first(self.log_size)
}
fn as_bytes(&self) -> Vec<u8> {
self.log_size.to_le_bytes().to_vec()
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PreprocessedColumn {
XorTable(u32, u32, usize),
IsFirst(u32),
Plonk(usize),
pub struct XorTable {
pub elem_bits: u32,
pub expand_bits: u32,
pub kind: usize,
}

impl PreprocessedColumn {
pub const fn name(&self) -> &'static str {
match self {
PreprocessedColumn::XorTable(..) => "preprocessed.xor_table",
PreprocessedColumn::IsFirst(_) => "preprocessed.is_first",
PreprocessedColumn::Plonk(_) => "preprocessed.plonk",
}
impl PreprocessedColumnOps for XorTable {
fn name(&self) -> &'static str {
"preprocessed.xor_table"
}
fn log_size(&self) -> u32 {
assert!(self.elem_bits >= self.expand_bits);
2 * (self.elem_bits - self.expand_bits)
}
fn gen_preprocessed_column_cpu(
&self,
) -> CircleEvaluation<CpuBackend, BaseField, BitReversedOrder> {
unimplemented!("XorTable is not supported.")
}
fn gen_preprocessed_column_simd(
&self,
) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
unimplemented!("XorTable is not supported.")
}
fn as_bytes(&self) -> Vec<u8> {
let mut bytes = vec![];
bytes.extend_from_slice(&self.elem_bits.to_le_bytes());
bytes.extend_from_slice(&self.expand_bits.to_le_bytes());
bytes.extend_from_slice(&self.kind.to_le_bytes());
bytes
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Plonk {
pub kind: u32,
}

impl PreprocessedColumnOps for Plonk {
fn name(&self) -> &'static str {
"preprocessed.plonk"
}
fn log_size(&self) -> u32 {
unimplemented!("Plonk is not supported.")
}
fn gen_preprocessed_column_cpu(
&self,
) -> CircleEvaluation<CpuBackend, BaseField, BitReversedOrder> {
unimplemented!("Plonk is not supported.")
}
fn gen_preprocessed_column_simd(
&self,
) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
unimplemented!("Plonk is not supported.")
}
fn as_bytes(&self) -> Vec<u8> {
self.kind.to_le_bytes().to_vec()
}
}

impl PartialEq for dyn PreprocessedColumnOps {
fn eq(&self, other: &Self) -> bool {
self.get_type_id() == other.get_type_id() && self.as_bytes() == other.as_bytes()
}
}

impl Eq for dyn PreprocessedColumnOps {}

impl Hash for dyn PreprocessedColumnOps {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.get_type_id().hash(state);
self.as_bytes().hash(state);
}
}

Expand Down Expand Up @@ -54,19 +161,10 @@ pub fn gen_is_step_with_offset<B: Backend>(
CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col)
}

pub fn gen_preprocessed_column<B: Backend>(
preprocessed_column: &PreprocessedColumn,
) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
match preprocessed_column {
PreprocessedColumn::IsFirst(log_size) => gen_is_first(*log_size),
PreprocessedColumn::Plonk(_) | PreprocessedColumn::XorTable(..) => {
unimplemented!("eval_preprocessed_column: Plonk and XorTable are not supported.")
}
}
}

pub fn gen_preprocessed_columns<'a, B: Backend>(
columns: impl Iterator<Item = &'a PreprocessedColumn>,
) -> Vec<CircleEvaluation<B, BaseField, BitReversedOrder>> {
columns.map(gen_preprocessed_column).collect()
pub fn gen_preprocessed_columns_simd<'a>(
columns: impl Iterator<Item = Arc<dyn PreprocessedColumnOps>>,
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
columns
.map(|col| col.gen_preprocessed_column_simd())
.collect()
}
1 change: 1 addition & 0 deletions crates/prover/src/constraint_framework/simd_domain.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::ops::Mul;
use std::sync::Arc;

use num_traits::Zero;

Expand Down
Loading

0 comments on commit cec714b

Please sign in to comment.