From af8430357a5d754620ffef0b2169dd17186ee55c Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Sun, 30 Jun 2024 09:26:53 +0300 Subject: [PATCH] Revert "integrated point vanishing in fri quotients (#619)" This reverts commit a6eabddaf88abee1e0c051327baada989cc81c49. --- .../prover/src/core/backend/cpu/quotients.rs | 37 +++--- .../prover/src/core/backend/simd/quotients.rs | 116 +++++++++--------- crates/prover/src/core/constraints.rs | 12 -- 3 files changed, 72 insertions(+), 93 deletions(-) diff --git a/crates/prover/src/core/backend/cpu/quotients.rs b/crates/prover/src/core/backend/cpu/quotients.rs index d6beb4716..b9c62da85 100644 --- a/crates/prover/src/core/backend/cpu/quotients.rs +++ b/crates/prover/src/core/backend/cpu/quotients.rs @@ -1,14 +1,14 @@ -use itertools::izip; +use itertools::{izip, zip_eq}; use num_traits::{One, Zero}; use super::CpuBackend; use crate::core::backend::{Backend, Col}; use crate::core::circle::CirclePoint; -use crate::core::constraints::{complex_conjugate_line_coeffs, point_vanishing_fraction}; +use crate::core::constraints::{complex_conjugate_line_coeffs, pair_vanishing}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumn; -use crate::core::fields::FieldExpOps; +use crate::core::fields::{ComplexConjugate, FieldExpOps}; use crate::core::pcs::quotients::{ColumnSampleBatch, PointSample, QuotientOps}; use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation}; use crate::core::poly::BitReversedOrder; @@ -41,27 +41,27 @@ impl QuotientOps for CpuBackend { } } -// TODO(Ohad): no longer using pair_vanishing, remove domain_point_vec and line_coeffs, or write a -// function that deals with quotients over pair_vanishing polynomials. pub fn accumulate_row_quotients( sample_batches: &[ColumnSampleBatch], columns: &[&CircleEvaluation], quotient_constants: &QuotientConstants, row: usize, - _domain_point: CirclePoint, + domain_point: CirclePoint, ) -> SecureField { let mut row_accumulator = SecureField::zero(); - for (sample_batch, _line_coeffs, batch_coeff, denominator_inverses) in izip!( + for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!( sample_batches, "ient_constants.line_coeffs, "ient_constants.batch_random_coeffs, "ient_constants.denominator_inverses ) { let mut numerator = SecureField::zero(); - for (column_index, sampled_value) in sample_batch.columns_and_values.iter() { + for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs) + { let column = &columns[*column_index]; - let value = column[row]; - numerator += value - *sampled_value; + let value = column[row] * *c; + let linear_term = *a * domain_point.y + *b; + numerator += value - linear_term; } row_accumulator = row_accumulator * *batch_coeff + numerator * denominator_inverses[row]; @@ -114,24 +114,21 @@ fn denominator_inverses( sample_batches: &[ColumnSampleBatch], domain: CircleDomain, ) -> Vec> { - let n_fracions = sample_batches.len() * domain.size(); - let mut flat_denominators = Vec::with_capacity(n_fracions); - let mut numerator_terms = Vec::with_capacity(n_fracions); + let mut flat_denominators = Vec::with_capacity(sample_batches.len() * domain.size()); for sample_batch in sample_batches { for row in 0..domain.size() { let domain_point = domain.at(row); - let (num, denom) = point_vanishing_fraction(sample_batch.point, domain_point); - flat_denominators.push(num); - numerator_terms.push(denom); + let denominator = pair_vanishing( + sample_batch.point, + sample_batch.point.complex_conjugate(), + domain_point.into_ef(), + ); + flat_denominators.push(denominator); } } let mut flat_denominator_inverses = vec![SecureField::zero(); flat_denominators.len()]; SecureField::batch_inverse(&flat_denominators, &mut flat_denominator_inverses); - flat_denominator_inverses - .iter_mut() - .zip(&numerator_terms) - .for_each(|(inv, num_term)| *inv *= *num_term); flat_denominator_inverses .chunks_mut(domain.size()) diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index 28f4d3716..ddbd387d9 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -1,7 +1,5 @@ -use std::iter::zip; - -use itertools::izip; -use num_traits::{One, Zero}; +use itertools::{izip, zip_eq, Itertools}; +use num_traits::Zero; use super::column::SecureFieldVec; use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; @@ -15,7 +13,7 @@ use crate::core::circle::CirclePoint; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumn; -use crate::core::fields::FieldOps; +use crate::core::fields::{ComplexConjugate, FieldOps}; use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation}; use crate::core::poly::BitReversedOrder; @@ -56,27 +54,35 @@ impl QuotientOps for SimdBackend { } } -// TODO(Ohad): no longer using pair_vanishing, remove domain_point_vec and line_coeffs, or write a -// function that deals with quotients over pair_vanishing polynomials. pub fn accumulate_row_quotients( sample_batches: &[ColumnSampleBatch], columns: &[&CircleEvaluation], quotient_constants: &QuotientConstants, vec_row: usize, - _domain_point_vec: (PackedBaseField, PackedBaseField), + domain_point_vec: (PackedBaseField, PackedBaseField), ) -> PackedSecureField { let mut row_accumulator = PackedSecureField::zero(); - for (sample_batch, _, batch_coeff, denominator_inverses) in izip!( + for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!( sample_batches, "ient_constants.line_coeffs, "ient_constants.batch_random_coeffs, "ient_constants.denominator_inverses ) { let mut numerator = PackedSecureField::zero(); - for (column_index, sampled_value) in sample_batch.columns_and_values.iter() { + for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs) + { let column = &columns[*column_index]; - let value = column.data[vec_row]; - numerator += PackedSecureField::broadcast(-*sampled_value) + value; + let value = PackedSecureField::broadcast(*c) * column.data[vec_row]; + // The numerator is a line equation passing through + // (sample_point.y, sample_value), (conj(sample_point), conj(sample_value)) + // evaluated at (domain_point.y, value). + // When substituting a polynomial in this line equation, we get a polynomial with a root + // at sample_point and conj(sample_point) if the original polynomial had the values + // sample_value and conj(sample_value) at these points. + // TODO(AlonH): Use single point vanishing to save a multiplication. + let linear_term = PackedSecureField::broadcast(*a) * domain_point_vec.1 + + PackedSecureField::broadcast(*b); + numerator += value - linear_term; } row_accumulator = row_accumulator * PackedSecureField::broadcast(*batch_coeff) @@ -85,65 +91,53 @@ pub fn accumulate_row_quotients( row_accumulator } -/// Point vanishing for the packed representation of the points. skips the division. -/// See [crate::core::constraints::point_vanishing_fraction] for more details. -fn packed_point_vanishing_fraction( - excluded: CirclePoint, - p: (PackedBaseField, PackedBaseField), -) -> (PackedSecureField, PackedSecureField) { - let e_conjugate = excluded.conjugate(); - let h_x = PackedSecureField::broadcast(e_conjugate.x) * p.0 - - PackedSecureField::broadcast(e_conjugate.y) * p.1; - let h_y = PackedSecureField::broadcast(e_conjugate.y) * p.0 - + PackedSecureField::broadcast(e_conjugate.x) * p.1; - (h_y, PackedSecureField::one() + h_x) +/// Pair vanishing for the packed representation of the points. See +/// [crate::core::constraints::pair_vanishing] for more details. +fn packed_pair_vanishing( + excluded0: CirclePoint, + excluded1: CirclePoint, + packed_p: (PackedBaseField, PackedBaseField), +) -> PackedSecureField { + PackedSecureField::broadcast(excluded0.y - excluded1.y) * packed_p.0 + + PackedSecureField::broadcast(excluded1.x - excluded0.x) * packed_p.1 + + PackedSecureField::broadcast(excluded0.x * excluded1.y - excluded0.y * excluded1.x) } fn denominator_inverses( sample_batches: &[ColumnSampleBatch], domain: CircleDomain, ) -> Vec> { - let mut numerators = Vec::new(); - let mut denominators = Vec::new(); - - for sample_batch in sample_batches { - for vec_row in 0..1 << (domain.log_size() - LOG_N_LANES) { - // TODO(spapini): Optimize this, for the small number of columns case. - let points = std::array::from_fn(|i| { - domain.at(bit_reverse_index( - (vec_row << LOG_N_LANES) + i, - domain.log_size(), - )) - }); - let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x)); - let domain_points_y = PackedBaseField::from_array(points.map(|p| p.y)); - let domain_point_vec = (domain_points_x, domain_points_y); - let (denominator, numerator) = - packed_point_vanishing_fraction(sample_batch.point, domain_point_vec); - denominators.push(denominator); - numerators.push(numerator); - } - } - - let denominators = SecureFieldVec { - length: denominators.len() * N_LANES, - data: denominators, - }; - - let numerators = SecureFieldVec { - length: numerators.len() * N_LANES, - data: numerators, - }; - - let mut flat_denominator_inverses = SecureFieldVec::zeros(denominators.len()); + let flat_denominators: SecureFieldVec = sample_batches + .iter() + .flat_map(|sample_batch| { + (0..(1 << (domain.log_size() - LOG_N_LANES))) + .map(|vec_row| { + // TODO(spapini): Optimize this, for the small number of columns case. + let points = std::array::from_fn(|i| { + domain.at(bit_reverse_index( + (vec_row << LOG_N_LANES) + i, + domain.log_size(), + )) + }); + let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x)); + let domain_points_y = PackedBaseField::from_array(points.map(|p| p.y)); + let domain_point_vec = (domain_points_x, domain_points_y); + packed_pair_vanishing( + sample_batch.point, + sample_batch.point.complex_conjugate(), + domain_point_vec, + ) + }) + .collect_vec() + }) + .collect(); + + let mut flat_denominator_inverses = SecureFieldVec::zeros(flat_denominators.len()); >::batch_inverse( - &denominators, + &flat_denominators, &mut flat_denominator_inverses, ); - zip(&mut flat_denominator_inverses.data, &numerators.data) - .for_each(|(inv, denom_denom)| *inv *= *denom_denom); - flat_denominator_inverses .data .chunks(domain.size() / N_LANES) diff --git a/crates/prover/src/core/constraints.rs b/crates/prover/src/core/constraints.rs index 0d412b5d1..31711d98e 100644 --- a/crates/prover/src/core/constraints.rs +++ b/crates/prover/src/core/constraints.rs @@ -71,18 +71,6 @@ pub fn point_vanishing, EF: ExtensionOf>( h.y / (EF::one() + h.x) } -/// Evaluates a vanishing polynomial of the vanish_point at a point. -/// Note that this function has a pole on the antipode of the vanish_point. -/// Returns the result in a fraction form: (numerator, denominator). -// TODO(Ohad): reorganize these functions. -pub fn point_vanishing_fraction, EF: ExtensionOf>( - vanish_point: CirclePoint, - p: CirclePoint, -) -> (EF, EF) { - let h = p.into_ef() - vanish_point; - (h.y, (EF::one() + h.x)) -} - /// Evaluates a point on a line between a point and its complex conjugate. /// Relies on the fact that every polynomial F over the base field holds: /// F(p*) == F(p)* (* being the complex conjugate).