Skip to content

Commit

Permalink
Fall back to CPU in small constraint eval.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Sep 18, 2024
1 parent 77b7cdd commit 1a9755a
Show file tree
Hide file tree
Showing 10 changed files with 445 additions and 149 deletions.
416 changes: 281 additions & 135 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ serde = { version = "1.0", features = ["derive"] }

[dev-dependencies]
aligned = "0.4.2"
test-case = "3.3.1"
test-log = { version = "0.2.15", features = ["trace"] }
tracing-subscriber = "0.3.18"
[target.'cfg(all(target_family = "wasm", not(target_os = "wasi")))'.dev-dependencies]
Expand Down
31 changes: 30 additions & 1 deletion crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::ops::Deref;
use itertools::Itertools;
use tracing::{span, Level};

use super::cpu_domain::CpuDomainEvaluator;
use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, Trace};
Expand Down Expand Up @@ -173,7 +174,35 @@ impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
evaluation_accumulator.columns([(eval_domain.log_size(), self.n_constraints())]);
accum.random_coeff_powers.reverse();

let _span = span!(Level::INFO, "Constraint pointwise eval").entered();
let _span = span!(Level::INFO, "Constraint point-wise eval").entered();

if trace_domain.log_size() <= LOG_N_LANES {
// Fall back to CPU if the trace is too small.
let mut col = accum.col.to_cpu();

for row in 0..(1 << eval_domain.log_size()) {
let trace_cols = trace.as_cols_ref().map_cols(|c| c.to_cpu());
let trace_cols = trace_cols.as_cols_ref();

// Evaluate constrains at row.
let eval = CpuDomainEvaluator::new(
&trace_cols,
row,
&accum.random_coeff_powers,
trace_domain.log_size(),
eval_domain.log_size(),
);
let row_res = self.eval.evaluate(eval).row_res;

// Finalize row.
let denom_inv = denom_inv[row >> trace_domain.log_size()];
col.set(row, col.at(row) + row_res * denom_inv)
}
let col = col.to_simd();
*accum.col = col;
return;
}

let col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(accum.col) };

for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)) {
Expand Down
91 changes: 91 additions & 0 deletions crates/prover/src/constraint_framework/cpu_domain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use std::ops::Mul;

use num_traits::Zero;

use super::EvalAtRow;
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
use crate::core::utils::offset_bit_reversed_circle_domain_index;

/// Evaluates constraints at an evaluation domain points.
pub struct CpuDomainEvaluator<'a> {
pub trace_eval: &'a TreeVec<Vec<&'a CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
pub column_index_per_interaction: Vec<usize>,
pub row: usize,
pub random_coeff_powers: &'a [SecureField],
pub row_res: SecureField,
pub constraint_index: usize,
pub domain_log_size: u32,
pub eval_domain_log_size: u32,
}

impl<'a> CpuDomainEvaluator<'a> {
#[allow(dead_code)]
pub fn new(
trace_eval: &'a TreeVec<Vec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
row: usize,
random_coeff_powers: &'a [SecureField],
domain_log_size: u32,
eval_log_size: u32,
) -> Self {
Self {
trace_eval,
column_index_per_interaction: vec![0; trace_eval.len()],
row,
random_coeff_powers,
row_res: SecureField::zero(),
constraint_index: 0,
domain_log_size,
eval_domain_log_size: eval_log_size,
}
}
}

impl<'a> EvalAtRow for CpuDomainEvaluator<'a> {
type F = BaseField;
type EF = SecureField;

// TODO(spapini): Remove all boundary checks.
fn next_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
offsets: [isize; N],
) -> [Self::F; N] {
let col_index = self.column_index_per_interaction[interaction];
self.column_index_per_interaction[interaction] += 1;
offsets.map(|off| {
// If the offset is 0, we can just return the value directly from this row.
if off == 0 {
let col = &self.trace_eval[interaction][col_index];
return col[self.row];
}
// Otherwise, we need to look up the value at the offset.
// Since the domain is bit-reversed circle domain ordered, we need to look up the value
// at the bit-reversed natural order index at an offset.
let row = offset_bit_reversed_circle_domain_index(
self.row,
self.domain_log_size,
self.eval_domain_log_size,
off,
);
self.trace_eval[interaction][col_index][row]
})
}

fn add_constraint<G>(&mut self, constraint: G)
where
Self::EF: Mul<G, Output = Self::EF>,
{
self.row_res += self.random_coeff_powers[self.constraint_index] * constraint;
self.constraint_index += 1;
}

fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
SecureField::from_m31_array(values)
}
}
1 change: 1 addition & 0 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod assert;
mod component;
pub mod constant_columns;
mod cpu_domain;
mod info;
pub mod logup;
mod point;
Expand Down
5 changes: 4 additions & 1 deletion crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,10 @@ fn slow_eval_at_point(
// Swap content of a,c.
a.swap_with_slice(&mut c[0..n0]);
}
fold(cast_slice::<_, BaseField>(&poly.coeffs.data), &mappings)
fold(
&cast_slice::<_, BaseField>(&poly.coeffs.data)[..poly.coeffs.length],
&mappings,
)
}

#[cfg(test)]
Expand Down
9 changes: 9 additions & 0 deletions crates/prover/src/core/backend/simd/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ impl BaseColumn {
res
}

pub fn from_cpu_vec(values: Vec<BaseField>) -> Self {
let length = values.len();
let data = values
.chunks(N_LANES)
.map(PackedBaseField::from_slice)
.collect();
Self { data, length }
}

/// Returns a vector of `BaseColumnMutSlice`s, each mutably owning
/// `chunk_size` `PackedBaseField`s (i.e, `chuck_size` * `N_LANES` elements).
pub fn chunks_mut(&mut self, chunk_size: usize) -> Vec<BaseColumnMutSlice<'_>> {
Expand Down
7 changes: 7 additions & 0 deletions crates/prover/src/core/backend/simd/m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ impl PackedM31 {
Self(Simd::from_array(values.map(|M31(v)| v)))
}

pub fn from_slice(values: &[M31]) -> PackedM31 {
assert!(values.len() <= N_LANES);
let mut res = [M31::zero(); N_LANES];
res[..values.len()].copy_from_slice(values);
Self::from_array(res)
}

pub fn to_array(self) -> [M31; N_LANES] {
self.reduce().0.to_array().map(M31)
}
Expand Down
8 changes: 8 additions & 0 deletions crates/prover/src/core/fields/secure_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::iter::zip;
use super::m31::BaseField;
use super::qm31::SecureField;
use super::{ExtensionOf, FieldOps};
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column, CpuBackend};

pub const SECURE_EXTENSION_DEGREE: usize =
Expand All @@ -20,6 +22,12 @@ impl SecureColumnByCoords<CpuBackend> {
pub fn to_vec(&self) -> Vec<SecureField> {
(0..self.len()).map(|i| self.at(i)).collect()
}

pub fn to_simd(self) -> SecureColumnByCoords<SimdBackend> {
SecureColumnByCoords {
columns: self.columns.map(BaseColumn::from_cpu_vec),
}
}
}
impl<B: FieldOps<BaseField>> SecureColumnByCoords<B> {
pub fn at(&self, index: usize) -> SecureField {
Expand Down
25 changes: 13 additions & 12 deletions crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub fn generate_trace<const N: usize>(
mod tests {
use itertools::Itertools;
use num_traits::One;
use test_case::test_case;

use super::WideFibonacciEval;
use crate::constraint_framework::{
Expand Down Expand Up @@ -149,13 +150,14 @@ mod tests {
);
}

#[test_case(6; "SIMD")]
#[test_case(4; "CPU fall back")]
#[test_log::test]
fn test_wide_fib_prove() {
const LOG_N_INSTANCES: u32 = 6;
fn test_wide_fib_prove_with_blake(log_n_instances: u32) {
let config = PcsConfig::default();
// Precompute twiddles.
let twiddles = SimdBackend::precompute_twiddles(
CanonicCoset::new(LOG_N_INSTANCES + 1 + config.fri_config.log_blowup_factor)
CanonicCoset::new(log_n_instances + 1 + config.fri_config.log_blowup_factor)
.circle_domain()
.half_coset,
);
Expand All @@ -168,7 +170,7 @@ mod tests {
);

// Trace.
let trace = generate_test_trace(LOG_N_INSTANCES);
let trace = generate_test_trace(log_n_instances);
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(trace);
tree_builder.commit(prover_channel);
Expand All @@ -177,7 +179,7 @@ mod tests {
let component = WideFibonacciComponent::new(
&mut TraceLocationAllocator::default(),
WideFibonacciEval::<FIB_SEQUENCE_LENGTH> {
log_n_rows: LOG_N_INSTANCES,
log_n_rows: log_n_instances,
},
);

Expand All @@ -198,15 +200,14 @@ mod tests {
verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap();
}

#[test]
#[test_case(6; "SIMD")]
#[test_case(4; "CPU fall back")]
#[cfg(not(target_arch = "wasm32"))]
fn test_wide_fib_prove_with_poseidon() {
const LOG_N_INSTANCES: u32 = 6;

fn test_wide_fib_prove_with_poseidon(log_n_instances: u32) {
let config = PcsConfig::default();
// Precompute twiddles.
let twiddles = SimdBackend::precompute_twiddles(
CanonicCoset::new(LOG_N_INSTANCES + 1 + config.fri_config.log_blowup_factor)
CanonicCoset::new(log_n_instances + 1 + config.fri_config.log_blowup_factor)
.circle_domain()
.half_coset,
);
Expand All @@ -219,7 +220,7 @@ mod tests {
);

// Trace.
let trace = generate_test_trace(LOG_N_INSTANCES);
let trace = generate_test_trace(log_n_instances);
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(trace);
tree_builder.commit(prover_channel);
Expand All @@ -228,7 +229,7 @@ mod tests {
let component = WideFibonacciComponent::new(
&mut TraceLocationAllocator::default(),
WideFibonacciEval::<FIB_SEQUENCE_LENGTH> {
log_n_rows: LOG_N_INSTANCES,
log_n_rows: log_n_instances,
},
);
let proof = prove::<SimdBackend, Poseidon252MerkleChannel>(
Expand Down

0 comments on commit 1a9755a

Please sign in to comment.