Skip to content

Commit

Permalink
Extension FRI optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Jul 2, 2024
1 parent e33378a commit 21492c9
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 20 deletions.
92 changes: 73 additions & 19 deletions crates/prover/src/core/backend/simd/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,36 @@ use crate::core::backend::{Col, Column};
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE};
use crate::core::fields::{ComplexConjugate, FieldOps};
use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse_index;
use crate::core::prover::LOG_BLOWUP_FACTOR;
use crate::core::utils::{bit_reverse, bit_reverse_index};

impl QuotientOps for SimdBackend {
fn accumulate_quotients(
domain: CircleDomain,
outer_domain: CircleDomain,
columns: &[&CircleEvaluation<Self, BaseField, BitReversedOrder>],
random_coeff: SecureField,
sample_batches: &[ColumnSampleBatch],
) -> SecureEvaluation<Self> {
assert!(domain.log_size() >= LOG_N_LANES);
let mut values = SecureColumn::<Self>::zeros(domain.size());
let quotient_constants = quotient_constants(sample_batches, random_coeff, domain);
// Split the domain into a subdomain and a shift coset.
// TODO(spapini): Move to the caller when Columns support slices.
let (subdomain, mut domain_shifts) = outer_domain.split(LOG_BLOWUP_FACTOR);

assert!(subdomain.log_size() >= LOG_N_LANES);
let mut values = SecureColumn::<Self>::zeros(subdomain.size());
let quotient_constants = quotient_constants(sample_batches, random_coeff, subdomain);

// TODO(spapini): bit reverse iterator.
for vec_row in 0..1 << (domain.log_size() - LOG_N_LANES) {
for vec_row in 0..1 << (subdomain.log_size() - LOG_N_LANES) {
// TODO(spapini): Optimize this, for the small number of columns case.
let points = std::array::from_fn(|i| {
domain.at(bit_reverse_index(
subdomain.at(bit_reverse_index(
(vec_row << LOG_N_LANES) + i,
domain.log_size(),
subdomain.log_size(),
))
});
let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x));
Expand All @@ -50,7 +55,49 @@ impl QuotientOps for SimdBackend {
);
unsafe { values.set_packed(vec_row, row_accumulator) };
}
SecureEvaluation { domain, values }

// Extend the evaluation to the full domain.
let mut extended_eval = SecureColumn::<Self>::zeros(outer_domain.size());

let mut i = 0;
let values = values.columns;
let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset);
let subeval_polys = values.map(|c| {
i += 1;
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(subdomain, c)
.interpolate_with_twiddles(&twiddles)
});

// Since we traverse the domain in bit-reversed order, we need bit-reverse the shifts.
// To see why, consider the index of a point in the natural order of the domain
// (least to most):
// b0 b1 b2 b3 b4 b5
// b0 adds P, b1 adds 2P, etc.. (b5 is special and flips the sign of the point).
// Splitting the coset to 4 parts yields:
// subdomain: b2 b3 b4 b5, shifts: b0 b1.
// b2 b3 b4 b5 is indeed a circle domain, with a bigger jump.
// Traversing the domain in bit-reversed order, after we finish with b5, b4, b3, b2,
// we need to change b1 and then b0. This is the bit reverse of the shift b0 b1.
bit_reverse(&mut domain_shifts);

// TODO(spapini): Try to optimize out all these copies.
for (ci, &c) in domain_shifts.iter().enumerate() {
let subdomain = subdomain.shift(c);

let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset);
#[allow(clippy::needless_range_loop)]
for i in 0..SECURE_EXTENSION_DEGREE {
// Sanity check.
let eval = subeval_polys[i].evaluate_with_twiddles(subdomain, &twiddles);
extended_eval.columns[i].data[(ci * eval.data.len())..((ci + 1) * eval.data.len())]
.copy_from_slice(&eval.data);
}
}

SecureEvaluation {
domain: outer_domain,
values: extended_eval,
}
}
}

Expand Down Expand Up @@ -172,21 +219,28 @@ mod tests {
use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::LOG_BLOWUP_FACTOR;
use crate::qm31;

#[test]
fn test_accumulate_quotients() {
const LOG_SIZE: u32 = 8;
let domain = CanonicCoset::new(LOG_SIZE).circle_domain();
let e0: BaseFieldVec = (0..domain.size()).map(BaseField::from).collect();
let e1: BaseFieldVec = (0..domain.size()).map(|i| BaseField::from(2 * i)).collect();
let columns = vec![
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(domain, e0),
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(domain, e1),
let small_domain = CanonicCoset::new(LOG_SIZE).circle_domain();
let domain = CanonicCoset::new(LOG_SIZE + LOG_BLOWUP_FACTOR).circle_domain();
let e0: BaseFieldVec = (0..small_domain.size()).map(BaseField::from).collect();
let e1: BaseFieldVec = (0..small_domain.size())
.map(|i| BaseField::from(2 * i))
.collect();
let polys = vec![
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(small_domain, e0)
.interpolate(),
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(small_domain, e1)
.interpolate(),
];
let columns = vec![polys[0].evaluate(domain), polys[1].evaluate(domain)];
let random_coeff = qm31!(1, 2, 3, 4);
let a = qm31!(3, 6, 9, 12);
let b = qm31!(4, 8, 12, 16);
let a = polys[0].eval_at_point(SECURE_FIELD_CIRCLE_GEN);
let b = polys[1].eval_at_point(SECURE_FIELD_CIRCLE_GEN);
let samples = vec![ColumnSampleBatch {
point: SECURE_FIELD_CIRCLE_GEN,
columns_and_values: vec![(0, a), (1, b)],
Expand Down
44 changes: 43 additions & 1 deletion crates/prover/src/core/poly/circle/domain.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::iter::Chain;

use itertools::Itertools;

use crate::core::circle::{
CirclePoint, CirclePointIndex, Coset, CosetIterator, M31_CIRCLE_LOG_ORDER,
};
Expand Down Expand Up @@ -77,6 +79,23 @@ impl CircleDomain {
pub fn is_canonic(&self) -> bool {
self.half_coset.initial_index * 4 == self.half_coset.step_size
}

/// Splits a circle domain into a smaller [CircleDomain]s, shifted by offsets.
pub fn split(&self, log_parts: u32) -> (CircleDomain, Vec<CirclePointIndex>) {
assert!(log_parts <= self.half_coset.log_size);
let subdomain = CircleDomain::new(Coset::new(
self.half_coset.initial_index,
self.half_coset.log_size - log_parts,
));
let shifts = (0..1 << log_parts)
.map(|i| self.half_coset.step_size * i)
.collect_vec();
(subdomain, shifts)
}

pub fn shift(&self, shift: CirclePointIndex) -> CircleDomain {
CircleDomain::new(self.half_coset.shift(shift))
}
}

impl IntoIterator for CircleDomain {
Expand All @@ -101,6 +120,8 @@ type CircleDomainIndexIterator =

#[cfg(test)]
mod tests {
use itertools::Itertools;

use super::CircleDomain;
use crate::core::circle::{CirclePointIndex, Coset};
use crate::core::poly::circle::CanonicCoset;
Expand Down Expand Up @@ -134,7 +155,7 @@ mod tests {
}

#[test]
pub fn test_at_circle_domain() {
fn test_at_circle_domain() {
let domain = CanonicCoset::new(7).circle_domain();
let half_domain_size = domain.size() / 2;

Expand All @@ -143,4 +164,25 @@ mod tests {
assert_eq!(domain.at(i), domain.at(i + half_domain_size).conjugate());
}
}

#[test]
fn test_domain_split() {
let domain = CanonicCoset::new(5).circle_domain();
let (subdomain, shifts) = domain.split(2);

let domain_points = domain.iter().collect::<Vec<_>>();
let points_for_each_domain = shifts
.iter()
.map(|&shift| (subdomain.shift(shift)).iter().collect_vec())
.collect::<Vec<_>>();
// Interleave the points from each subdomain.
let extended_points = (0..(1 << 3))
.flat_map(|point_ind| {
(0..(1 << 2))
.map(|shift_ind| points_for_each_domain[shift_ind][point_ind])
.collect_vec()
})
.collect_vec();
assert_eq!(domain_points, extended_points);
}
}

0 comments on commit 21492c9

Please sign in to comment.