Skip to content

Commit

Permalink
Separate cpu and avx wide fib.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Apr 17, 2024
1 parent 288122c commit 6bee5d8
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
13 changes: 8 additions & 5 deletions crates/prover/src/examples/wide_fibonacci/avx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,11 @@ impl Component<AVX512Backend> for WideFibComponent {

let _span = span!(Level::INFO, "Constraint pointwise eval").entered();

let constraint_log_degree_bound = self.log_column_size() + 1;
let constraint_log_degree_bound =
Component::<AVX512Backend>::max_constraint_log_degree_bound(self);
let n_constraints = Component::<AVX512Backend>::n_constraints(self);
let [accum] =
evaluation_accumulator.columns([(constraint_log_degree_bound, N_COLUMNS - 1)]);
evaluation_accumulator.columns([(constraint_log_degree_bound, n_constraints)]);

for vec_row in 0..(1 << (eval_domain.log_size() - VECS_LOG_SIZE as u32)) {
// Numerator.
Expand Down Expand Up @@ -152,6 +154,7 @@ mod tests {

use crate::commitment_scheme::blake2_hash::Blake2sHasher;
use crate::commitment_scheme::hasher::Hasher;
use crate::core::backend::avx512::AVX512Backend;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::fields::m31::BaseField;
use crate::core::fields::IntoSlice;
Expand All @@ -172,12 +175,12 @@ mod tests {
log_fibonacci_size: LOG_N_COLUMNS as u32,
log_n_instances: LOG_N_ROWS,
};
let air = WideFibAir { component };
let span = span!(Level::INFO, "Trace generation").entered();
let trace = gen_trace(LOG_N_ROWS as usize);
let trace = gen_trace(component.log_column_size() as usize);
span.exit();
let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[])));
let proof = prove(&air, channel, trace).unwrap();
let air = WideFibAir { component };
let proof = prove::<AVX512Backend>(&air, channel, trace).unwrap();

let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[])));
verify(proof, &air, channel).unwrap();
Expand Down
10 changes: 9 additions & 1 deletion crates/prover/src/examples/wide_fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@ pub struct WideFibComponent {
impl WideFibComponent {
pub fn fill_initial_trace(&self, private_input: Vec<Input>) -> ColumnVec<Vec<BaseField>> {
let n_instances = 1 << self.log_n_instances;
assert_eq!(private_input.len(), n_instances);
assert_eq!(
private_input.len(),
n_instances,
"The number of inputs must match the number of instances"
);
assert!(
self.log_fibonacci_size >= LOG_N_COLUMNS as u32,
"The fibonacci size must be at least equal to the length of a row"
);
let n_rows_per_instance = (1 << self.log_fibonacci_size) / N_COLUMNS;
let n_rows = n_instances * n_rows_per_instance;
let zero_vec = vec![BaseField::zero(); n_rows];
Expand Down
8 changes: 4 additions & 4 deletions crates/prover/src/examples/wide_fibonacci/constraint_eval.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use num_traits::Zero;
use num_traits::{One, Zero};

use super::component::WideFibComponent;
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
Expand Down Expand Up @@ -60,8 +60,8 @@ impl Component<CPUBackend> for WideFibComponent {
#[allow(clippy::needless_range_loop)]
for i in 0..trace_eval_domain.size() {
// Boundary constraint.
numerators[i] += accum.random_coeff_powers[254]
* (trace_evals[0].values.at(i) - BaseField::from_u32_unchecked(1));
numerators[i] += accum.random_coeff_powers[N_COLUMNS - 2]
* (trace_evals[0].values.at(i) - BaseField::one());

// Step constraints.
for j in 0..N_COLUMNS - 2 {
Expand All @@ -85,7 +85,7 @@ impl Component<CPUBackend> for WideFibComponent {
let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset;
let denom = coset_vanishing(constraint_zero_domain, point);
let denom_inverse = denom.inverse();
let numerator = mask[0][0] - BaseField::from_u32_unchecked(1);
let numerator = mask[0][0] - BaseField::one();
evaluation_accumulator.accumulate(numerator * denom_inverse);

for i in 0..N_COLUMNS - 2 {
Expand Down

0 comments on commit 6bee5d8

Please sign in to comment.