Skip to content

Commit

Permalink
FRI quotients y optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Jul 1, 2024
1 parent d8f7e32 commit 0d2b522
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 35 deletions.
6 changes: 6 additions & 0 deletions crates/prover/src/core/backend/cpu/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ pub fn accumulate_row_quotients(
{
let column = &columns[*column_index];
let value = column[row] * *c;
// The numerator is a line equation passing through
// (sample_point.y, sample_value), (conj(sample_point.y), conj(sample_value))
// evaluated at (domain_point.y, value).
// When substituting a polynomial in this line equation, we get a polynomial with a root
// at sample_point and conj(sample_point) if the original polynomial had the values
// sample_value and conj(sample_value) at these points.
let linear_term = *a * domain_point.y + *b;
numerator += value - linear_term;
}
Expand Down
89 changes: 55 additions & 34 deletions crates/prover/src/core/backend/simd/quotients.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use itertools::{izip, zip_eq, Itertools};
use num_traits::Zero;
use tracing::{span, Level};

use super::column::SecureFieldVec;
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
Expand All @@ -15,11 +16,10 @@ use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE};
use crate::core::fields::{ComplexConjugate, FieldOps};
use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::LOG_BLOWUP_FACTOR;
use crate::core::utils::{bit_reverse, bit_reverse_index};
use crate::core::poly::circle::PolyOps;

impl QuotientOps for SimdBackend {
fn accumulate_quotients(
Expand All @@ -32,30 +32,36 @@ impl QuotientOps for SimdBackend {
// TODO(spapini): Move to the caller when Columns support slices.
let (domain, mut shifts) = outer_domain.split(LOG_BLOWUP_FACTOR);

assert!(domain.log_size() >= LOG_N_LANES);
assert!(domain.log_size() >= LOG_N_LANES + 2);
let mut values = SecureColumn::<Self>::zeros(domain.size());
let quotient_constants = quotient_constants(sample_batches, random_coeff, domain);

let span = span!(Level::INFO, "Quotient accumulation").entered();
// TODO(spapini): bit reverse iterator.
for vec_row in 0..1 << (domain.log_size() - LOG_N_LANES) {
// TODO(spapini): Optimize this, for the small number of columns case.
let points = std::array::from_fn(|i| {
domain.at(bit_reverse_index(
(vec_row << LOG_N_LANES) + i,
domain.log_size(),
))
});
let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x));
let domain_points_y = PackedBaseField::from_array(points.map(|p| p.y));
for quad_row in 0..1 << (domain.log_size() - LOG_N_LANES - 2) {
// TODO(spapini): Use optimized domain iteration.
let spaced_ys = PackedBaseField::from_array(std::array::from_fn(|i| {
domain
.at(bit_reverse_index(
(quad_row << (LOG_N_LANES + 2)) + (i << 2),
domain.log_size(),
))
.y
}));
let row_accumulator = accumulate_row_quotients(
sample_batches,
columns,
&quotient_constants,
vec_row,
(domain_points_x, domain_points_y),
quad_row,
spaced_ys,
);
unsafe { values.set_packed(vec_row, row_accumulator) };
#[allow(clippy::needless_range_loop)]
for i in 0..4 {
unsafe { values.set_packed((quad_row << 2) + i, row_accumulator[i]) };
}
}
span.exit();
let span = span!(Level::INFO, "Quotient extension").entered();

// Extend the evaluation to the full domain.
let mut extended_eval = SecureColumn::<Self>::zeros(outer_domain.size());
Expand Down Expand Up @@ -84,6 +90,7 @@ impl QuotientOps for SimdBackend {
.copy_from_slice(&eval.data);
}
}
span.exit();

SecureEvaluation {
domain: outer_domain,
Expand All @@ -92,39 +99,53 @@ impl QuotientOps for SimdBackend {
}
}

/// Accumulates the quotients for 4 * N_LANES rows at a time.
/// spaced_ys - y values for N_LANES points in the domain, in jumps of 4.
pub fn accumulate_row_quotients(
sample_batches: &[ColumnSampleBatch],
columns: &[&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>],
quotient_constants: &QuotientConstants<SimdBackend>,
vec_row: usize,
domain_point_vec: (PackedBaseField, PackedBaseField),
) -> PackedSecureField {
let mut row_accumulator = PackedSecureField::zero();
quad_row: usize,
spaced_ys: PackedBaseField,
) -> [PackedSecureField; 4] {
let mut row_accumulator = [PackedSecureField::zero(); 4];
for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!(
sample_batches,
&quotient_constants.line_coeffs,
&quotient_constants.batch_random_coeffs,
&quotient_constants.denominator_inverses
) {
let mut numerator = PackedSecureField::zero();
let mut numerator = [PackedSecureField::zero(); 4];
for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs)
{
let column = &columns[*column_index];
let value = PackedSecureField::broadcast(*c) * column.data[vec_row];
// The numerator is a line equation passing through
// (sample_point.y, sample_value), (conj(sample_point.y), conj(sample_value))
// evaluated at (domain_point.y, value).
// When substituting a polynomial in this line equation, we get a polynomial with a root
// at sample_point and conj(sample_point) if the original polynomial had the values
// sample_value and conj(sample_value) at these points.
// TODO(AlonH): Use single point vanishing to save a multiplication.
let linear_term = PackedSecureField::broadcast(*a) * domain_point_vec.1
+ PackedSecureField::broadcast(*b);
numerator += value - linear_term;
let values: [_; 4] = std::array::from_fn(|i| {
PackedSecureField::broadcast(*c) * column.data[(quad_row << 2) + i]
});

// The numerator is the line equation:
// value - a * point.y - b;
// The y values for 4 consecutive points in the domain (bit reversed order) are
// y, -y, -y, y.
// We use this fact to save multiplications.
// spaced_ys are the y values in jumps of 4:
// y0, y1, y2, ...
let spaced_ay = PackedSecureField::broadcast(*a) * spaced_ys;
// t0:t1 = ay0, -ay0, ay1, -ay1, ...
let (t0, t1) = spaced_ay.interleave(-spaced_ay);
// t2:t3:t4:t5 = ay0, -ay0, -ay0, ay0, ay1, -ay1, ...
let (t2, t3) = t0.interleave(-t0);
let (t4, t5) = t1.interleave(-t1);
let ay = [t2, t3, t4, t5];
for i in 0..4 {
numerator[i] += values[i] - ay[i] - PackedSecureField::broadcast(*b);
}
}

row_accumulator = row_accumulator * PackedSecureField::broadcast(*batch_coeff)
+ numerator * denominator_inverses.data[vec_row];
for i in 0..4 {
row_accumulator[i] = row_accumulator[i] * PackedSecureField::broadcast(*batch_coeff)
+ numerator[i] * denominator_inverses.data[(quad_row << 2) + i];
}
}
row_accumulator
}
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ mod tests {

// Get from environment variable:
let log_n_instances = env::var("LOG_N_INSTANCES")
.unwrap_or_else(|_| "8".to_string())
.unwrap_or_else(|_| "10".to_string())
.parse::<u32>()
.unwrap();
let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32;
Expand Down

0 comments on commit 0d2b522

Please sign in to comment.