From 64af9abb5026be68094ebcdd09852ea92c5e90c2 Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Sun, 30 Jun 2024 12:48:58 +0300 Subject: [PATCH] FRI quotients y optimization --- .../prover/src/core/backend/simd/quotients.rs | 50 +++++++++++-------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index 8dccc9d1c..2dd35f651 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -31,16 +31,16 @@ impl QuotientOps for SimdBackend { // TODO(spapini): Move to the caller when Columns support slices. let (domain, mut shifts) = outer_domain.split(LOG_BLOWUP_FACTOR); - assert!(domain.log_size() >= LOG_N_LANES); + assert!(domain.log_size() >= LOG_N_LANES + 2); let mut values = SecureColumn::::zeros(domain.size()); let quotient_constants = quotient_constants(sample_batches, random_coeff, domain); // TODO(spapini): bit reverse iterator. - for vec_row in 0..1 << (domain.log_size() - LOG_N_LANES) { + for quad_row in 0..1 << (domain.log_size() - LOG_N_LANES - 2) { // TODO(spapini): Optimize this, for the small number of columns case. let points = std::array::from_fn(|i| { domain.at(bit_reverse_index( - (vec_row << LOG_N_LANES) + i, + (quad_row << (LOG_N_LANES + 2)) + (i << 2), domain.log_size(), )) }); @@ -50,10 +50,13 @@ impl QuotientOps for SimdBackend { sample_batches, columns, "ient_constants, - vec_row, + quad_row, (domain_points_x, domain_points_y), ); - unsafe { values.set_packed(vec_row, row_accumulator) }; + #[allow(clippy::needless_range_loop)] + for i in 0..4 { + unsafe { values.set_packed((quad_row << 2) + i, row_accumulator[i]) }; + } } // Extend the evaluation to the full domain. @@ -93,35 +96,38 @@ pub fn accumulate_row_quotients( sample_batches: &[ColumnSampleBatch], columns: &[&CircleEvaluation], quotient_constants: &QuotientConstants, - vec_row: usize, + quad_row: usize, domain_point_vec: (PackedBaseField, PackedBaseField), -) -> PackedSecureField { - let mut row_accumulator = PackedSecureField::zero(); +) -> [PackedSecureField; 4] { + let mut row_accumulator = [PackedSecureField::zero(); 4]; for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!( sample_batches, "ient_constants.line_coeffs, "ient_constants.batch_random_coeffs, "ient_constants.denominator_inverses ) { - let mut numerator = PackedSecureField::zero(); + let mut numerator = [PackedSecureField::zero(); 4]; for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs) { let column = &columns[*column_index]; - let value = PackedSecureField::broadcast(*c) * column.data[vec_row]; - // The numerator is a line equation passing through - // (sample_point.y, sample_value), (conj(sample_point), conj(sample_value)) - // evaluated at (domain_point.y, value). - // When substituting a polynomial in this line equation, we get a polynomial with a root - // at sample_point and conj(sample_point) if the original polynomial had the values - // sample_value and conj(sample_value) at these points. - // TODO(AlonH): Use single point vanishing to save a multiplication. - let linear_term = PackedSecureField::broadcast(*a) * domain_point_vec.1 - + PackedSecureField::broadcast(*b); - numerator += value - linear_term; + let values: [_; 4] = std::array::from_fn(|i| { + PackedSecureField::broadcast(*c) * column.data[(quad_row << 2) + i] + }); + // y values are y,-y,-y,y. + let spaced_linear_term = PackedSecureField::broadcast(*a) * domain_point_vec.1; + let (t0, t1) = spaced_linear_term.interleave(-spaced_linear_term); + let (t2, t3) = t0.interleave(-t0); + let (t4, t5) = t1.interleave(-t1); + let linear_term = [t2, t3, t4, t5]; + for i in 0..4 { + numerator[i] += values[i] - linear_term[i] - PackedSecureField::broadcast(*b); + } } - row_accumulator = row_accumulator * PackedSecureField::broadcast(*batch_coeff) - + numerator * denominator_inverses.data[vec_row]; + for i in 0..4 { + row_accumulator[i] = row_accumulator[i] * PackedSecureField::broadcast(*batch_coeff) + + numerator[i] * denominator_inverses.data[(quad_row << 2) + i]; + } } row_accumulator }