Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FRI quotients y optimization #686

Merged
merged 1 commit into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/prover/src/core/backend/cpu/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ pub fn accumulate_row_quotients(
{
let column = &columns[*column_index];
let value = column[row] * *c;
// The numerator is a line equation passing through
// (sample_point.y, sample_value), (conj(sample_point), conj(sample_value))
// evaluated at (domain_point.y, value).
// When substituting a polynomial in this line equation, we get a polynomial with a root
// at sample_point and conj(sample_point) if the original polynomial had the values
// sample_value and conj(sample_value) at these points.
let linear_term = *a * domain_point.y + *b;
numerator += value - linear_term;
}
Expand Down
162 changes: 106 additions & 56 deletions crates/prover/src/core/backend/simd/quotients.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use itertools::{izip, zip_eq, Itertools};
use num_traits::Zero;
use tracing::{span, Level};

use super::column::SecureFieldVec;
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
Expand Down Expand Up @@ -31,55 +32,28 @@ impl QuotientOps for SimdBackend {
// TODO(spapini): Move to the caller when Columns support slices.
let (subdomain, mut subdomain_shifts) = domain.split(LOG_BLOWUP_FACTOR);

assert!(subdomain.log_size() >= LOG_N_LANES);
let mut values = SecureColumn::<Self>::zeros(subdomain.size());
let quotient_constants = quotient_constants(sample_batches, random_coeff, subdomain);

// TODO(spapini): bit reverse iterator.
for vec_row in 0..1 << (subdomain.log_size() - LOG_N_LANES) {
// TODO(spapini): Optimize this, for the small number of columns case.
let points = std::array::from_fn(|i| {
subdomain.at(bit_reverse_index(
(vec_row << LOG_N_LANES) + i,
subdomain.log_size(),
))
});
let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x));
let domain_points_y = PackedBaseField::from_array(points.map(|p| p.y));
let row_accumulator = accumulate_row_quotients(
sample_batches,
columns,
&quotient_constants,
vec_row,
(domain_points_x, domain_points_y),
);
unsafe { values.set_packed(vec_row, row_accumulator) };
}

// Extend the evaluation to the full domain.
let mut extended_eval = SecureColumn::<Self>::zeros(domain.size());

let mut i = 0;
let values = values.columns;
let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset);
let subeval_polys = values.map(|c| {
i += 1;
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(subdomain, c)
.interpolate_with_twiddles(&twiddles)
});

// Bit reverse the shifts.
// Since we traverse the domain in bit-reversed order, we need bit-reverse the shifts.
// To see why, consider the index of a point in the natural order of the domain
// (least to most):
// b0 b1 b2 b3 b4 b5
// b0 adds P, b1 adds 2P, etc.. (b5 is special and flips the sign of the point).
// b0 adds G, b1 adds 2G, etc.. (b5 is special and flips the sign of the point).
// Splitting the domain to 4 parts yields:
// subdomain: b2 b3 b4 b5, shifts: b0 b1.
// b2 b3 b4 b5 is indeed a circle domain, with a bigger jump.
// Traversing the domain in bit-reversed order, after we finish with b5, b4, b3, b2,
// we need to change b1 and then b0. This is the bit reverse of the shift b0 b1.
bit_reverse(&mut subdomain_shifts);

let (span, mut extended_eval, subeval_polys) = accumulate_quotients_on_subdomain(
subdomain,
sample_batches,
random_coeff,
columns,
domain,
);

// Extend the evaluation to the full domain.
// TODO(spapini): Try to optimize out all these copies.
for (ci, &c) in subdomain_shifts.iter().enumerate() {
let subdomain = subdomain.shift(c);
Expand All @@ -93,6 +67,7 @@ impl QuotientOps for SimdBackend {
.copy_from_slice(&eval.data);
}
}
span.exit();

SecureEvaluation {
domain,
Expand All @@ -101,39 +76,114 @@ impl QuotientOps for SimdBackend {
}
}

fn accumulate_quotients_on_subdomain(
subdomain: CircleDomain,
sample_batches: &[ColumnSampleBatch],
random_coeff: SecureField,
columns: &[&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>],
domain: CircleDomain,
) -> (
span::EnteredSpan,
SecureColumn<SimdBackend>,
[crate::core::poly::circle::CirclePoly<SimdBackend>; 4],
) {
assert!(subdomain.log_size() >= LOG_N_LANES + 2);
let mut values = SecureColumn::<SimdBackend>::zeros(subdomain.size());
let quotient_constants = quotient_constants(sample_batches, random_coeff, subdomain);

let span = span!(Level::INFO, "Quotient accumulation").entered();
// TODO(spapini): bit reverse iterator.
for quad_row in 0..1 << (subdomain.log_size() - LOG_N_LANES - 2) {
// TODO(spapini): Use optimized domain iteration.
let spaced_ys = PackedBaseField::from_array(std::array::from_fn(|i| {
subdomain
.at(bit_reverse_index(
(quad_row << (LOG_N_LANES + 2)) + (i << 2),
subdomain.log_size(),
))
.y
}));
let row_accumulator = accumulate_row_quotients(
sample_batches,
columns,
&quotient_constants,
quad_row,
spaced_ys,
);
#[allow(clippy::needless_range_loop)]
for i in 0..4 {
unsafe { values.set_packed((quad_row << 2) + i, row_accumulator[i]) };
}
}
span.exit();
let span = span!(Level::INFO, "Quotient extension").entered();

// Extend the evaluation to the full domain.
let extended_eval = SecureColumn::<SimdBackend>::zeros(domain.size());

let mut i = 0;
let values = values.columns;
let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset);
let subeval_polys = values.map(|c| {
i += 1;
CircleEvaluation::<SimdBackend, BaseField, BitReversedOrder>::new(subdomain, c)
.interpolate_with_twiddles(&twiddles)
});
(span, extended_eval, subeval_polys)
}

/// Accumulates the quotients for 4 * N_LANES rows at a time.
/// spaced_ys - y values for N_LANES points in the domain, in jumps of 4.
pub fn accumulate_row_quotients(
sample_batches: &[ColumnSampleBatch],
columns: &[&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>],
quotient_constants: &QuotientConstants<SimdBackend>,
vec_row: usize,
domain_point_vec: (PackedBaseField, PackedBaseField),
) -> PackedSecureField {
let mut row_accumulator = PackedSecureField::zero();
quad_row: usize,
spaced_ys: PackedBaseField,
) -> [PackedSecureField; 4] {
let mut row_accumulator = [PackedSecureField::zero(); 4];
for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!(
sample_batches,
&quotient_constants.line_coeffs,
&quotient_constants.batch_random_coeffs,
&quotient_constants.denominator_inverses
) {
let mut numerator = PackedSecureField::zero();
let mut numerator = [PackedSecureField::zero(); 4];
for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs)
{
let column = &columns[*column_index];
let value = PackedSecureField::broadcast(*c) * column.data[vec_row];
// The numerator is a line equation passing through
// (sample_point.y, sample_value), (conj(sample_point), conj(sample_value))
// evaluated at (domain_point.y, value).
// When substituting a polynomial in this line equation, we get a polynomial with a root
// at sample_point and conj(sample_point) if the original polynomial had the values
// sample_value and conj(sample_value) at these points.
// TODO(AlonH): Use single point vanishing to save a multiplication.
let linear_term = PackedSecureField::broadcast(*a) * domain_point_vec.1
+ PackedSecureField::broadcast(*b);
numerator += value - linear_term;
let cvalues: [_; 4] = std::array::from_fn(|i| {
PackedSecureField::broadcast(*c) * column.data[(quad_row << 2) + i]
});

// The numerator is the line equation:
// c * value - a * point.y - b;
// Note that a, b, c were already multilpied by random_coeff^i.
// See [column_line_coeffs()] for more details.
// This is why we only add here.
// 4 consecutive point in the domain in bit reversed order are:
// P, -P, P + H, -P + H.
// H being the half point (-1,0). The y values for these are
// P.y, -P.y, -P.y, P.y.
// We use this fact to save multiplications.
// spaced_ys are the y value in jumps of 4:
// P0.y, P1.y, P2.y, ...
let spaced_ay = PackedSecureField::broadcast(*a) * spaced_ys;
// t0:t1 = a*P0.y, -a*P0.y, a*P1.y, -a*P1.y, ...
let (t0, t1) = spaced_ay.interleave(-spaced_ay);
// t2:t3:t4:t5 = a*P0.y, -a*P0.y, -a*P0.y, a*P0.y, a*P1.y, -a*P1.y, ...
let (t2, t3) = t0.interleave(-t0);
let (t4, t5) = t1.interleave(-t1);
let ay = [t2, t3, t4, t5];
for i in 0..4 {
numerator[i] += cvalues[i] - ay[i] - PackedSecureField::broadcast(*b);
}
}

row_accumulator = row_accumulator * PackedSecureField::broadcast(*batch_coeff)
+ numerator * denominator_inverses.data[vec_row];
for i in 0..4 {
row_accumulator[i] = row_accumulator[i] * PackedSecureField::broadcast(*batch_coeff)
+ numerator[i] * denominator_inverses.data[(quad_row << 2) + i];
}
}
row_accumulator
}
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ mod tests {

// Get from environment variable:
let log_n_instances = env::var("LOG_N_INSTANCES")
.unwrap_or_else(|_| "8".to_string())
.unwrap_or_else(|_| "10".to_string())
.parse::<u32>()
.unwrap();
let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32;
Expand Down
Loading