Skip to content

Commit

Permalink
integrated point vanishing in fri quotients (#619)
Browse files Browse the repository at this point in the history
<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/starkware-libs/stwo/619)
<!-- Reviewable:end -->
  • Loading branch information
ohad-starkware authored May 16, 2024
1 parent c3b650b commit a6eabdd
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 55 deletions.
1 change: 0 additions & 1 deletion crates/prover/src/core/backend/avx512/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@ mod tests {
let (avx_g, avx_lambda) = AVX512Backend::decompose(&avx_eval);

assert_eq!(avx_lambda, cpu_lambda);

for i in 0..(1 << DOMAIN_LOG_SIZE) {
assert_eq!(avx_g.values.at(i), cpu_g.values.at(i));
}
Expand Down
9 changes: 8 additions & 1 deletion crates/prover/src/core/backend/avx512/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl FromIterator<BaseField> for BaseFieldVec {
}
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, Default)]
pub struct SecureFieldVec {
pub data: Vec<PackedSecureField>,
length: usize,
Expand Down Expand Up @@ -183,6 +183,13 @@ impl Column<SecureField> for SecureFieldVec {
}
}

impl Extend<PackedSecureField> for SecureFieldVec {
fn extend<T: IntoIterator<Item = PackedSecureField>>(&mut self, iter: T) {
self.data.extend(iter);
self.length = self.data.len() * K_BLOCK_SIZE;
}
}

impl FromIterator<SecureField> for SecureFieldVec {
fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
let mut chunks = iter.into_iter().array_chunks();
Expand Down
70 changes: 34 additions & 36 deletions crates/prover/src/core/backend/avx512/quotients.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use itertools::{izip, zip_eq, Itertools};
use itertools::{izip, Itertools};
use num_traits::One;

use super::qm31::PackedSecureField;
use super::{AVX512Backend, SecureFieldVec, K_BLOCK_SIZE, VECS_LOG_SIZE};
Expand All @@ -11,7 +12,7 @@ use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fields::{ComplexConjugate, FieldOps};
use crate::core::fields::FieldOps;
use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
Expand Down Expand Up @@ -52,35 +53,27 @@ impl QuotientOps for AVX512Backend {
}
}

// TODO(Ohad): no longer using pair_vanishing, remove domain_point_vec and line_coeffs, or write a
// function that deals with quotients over pair_vanishing polynomials.
pub fn accumulate_row_quotients(
sample_batches: &[ColumnSampleBatch],
columns: &[&CircleEvaluation<AVX512Backend, BaseField, BitReversedOrder>],
quotient_constants: &QuotientConstants<AVX512Backend>,
vec_row: usize,
domain_point_vec: (PackedBaseField, PackedBaseField),
_domain_point_vec: (PackedBaseField, PackedBaseField),
) -> PackedSecureField {
let mut row_accumulator = PackedSecureField::zero();
for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!(
for (sample_batch, _, batch_coeff, denominator_inverses) in izip!(
sample_batches,
&quotient_constants.line_coeffs,
&quotient_constants.batch_random_coeffs,
&quotient_constants.denominator_inverses
) {
let mut numerator = PackedSecureField::zero();
for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs)
{
for (column_index, sampled_value) in sample_batch.columns_and_values.iter() {
let column = &columns[*column_index];
let value = PackedSecureField::broadcast(*c) * column.data[vec_row];
// The numerator is a line equation passing through
// (sample_point.y, sample_value), (conj(sample_point), conj(sample_value))
// evaluated at (domain_point.y, value).
// When substituting a polynomial in this line equation, we get a polynomial with a root
// at sample_point and conj(sample_point) if the original polynomial had the values
// sample_value and conj(sample_value) at these points.
// TODO(AlonH): Use single point vanishing to save a multiplication.
let linear_term = PackedSecureField::broadcast(*a) * domain_point_vec.1
+ PackedSecureField::broadcast(*b);
numerator += value - linear_term;
let value = column.data[vec_row];
numerator += PackedSecureField::broadcast(-*sampled_value) + value;
}

row_accumulator = row_accumulator * PackedSecureField::broadcast(*batch_coeff)
Expand All @@ -89,23 +82,25 @@ pub fn accumulate_row_quotients(
row_accumulator
}

/// Pair vanishing for the packed representation of the points. See
/// [crate::core::constraints::pair_vanishing] for more details.
fn packed_pair_vanishing(
excluded0: CirclePoint<SecureField>,
excluded1: CirclePoint<SecureField>,
packed_p: (PackedBaseField, PackedBaseField),
) -> PackedSecureField {
PackedSecureField::broadcast(excluded0.y - excluded1.y) * packed_p.0
+ PackedSecureField::broadcast(excluded1.x - excluded0.x) * packed_p.1
+ PackedSecureField::broadcast(excluded0.x * excluded1.y - excluded0.y * excluded1.x)
/// Point vanishing for the packed representation of the points. skips the division.
/// See [crate::core::constraints::point_vanishing_fraction] for more details.
fn packed_point_vanishing_fraction(
excluded: CirclePoint<SecureField>,
p: (PackedBaseField, PackedBaseField),
) -> (PackedSecureField, PackedSecureField) {
let e_conjugate = excluded.conjugate();
let h_x = PackedSecureField::broadcast(e_conjugate.x) * p.0
- PackedSecureField::broadcast(e_conjugate.y) * p.1;
let h_y = PackedSecureField::broadcast(e_conjugate.y) * p.0
+ PackedSecureField::broadcast(e_conjugate.x) * p.1;
(h_y, (PackedSecureField::one() + h_x))
}

fn denominator_inverses(
sample_batches: &[ColumnSampleBatch],
domain: CircleDomain,
) -> Vec<Col<AVX512Backend, SecureField>> {
let flat_denominators: SecureFieldVec = sample_batches
let (denominators, numerators): (SecureFieldVec, SecureFieldVec) = sample_batches
.iter()
.flat_map(|sample_batch| {
(0..(1 << (domain.log_size() - VECS_LOG_SIZE as u32)))
Expand All @@ -120,22 +115,25 @@ fn denominator_inverses(
let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x));
let domain_points_y = PackedBaseField::from_array(points.map(|p| p.y));
let domain_point_vec = (domain_points_x, domain_points_y);
packed_pair_vanishing(
sample_batch.point,
sample_batch.point.complex_conjugate(),
domain_point_vec,
)

packed_point_vanishing_fraction(sample_batch.point, domain_point_vec)
})
.collect_vec()
})
.collect();
.unzip();

let mut flat_denominator_inverses = SecureFieldVec::zeros(flat_denominators.len());
let mut flat_denominator_inverses = SecureFieldVec::zeros(denominators.len());
<AVX512Backend as FieldOps<SecureField>>::batch_inverse(
&flat_denominators,
&denominators,
&mut flat_denominator_inverses,
);

flat_denominator_inverses
.data
.iter_mut()
.zip(&numerators.data)
.for_each(|(inv, denom_denom)| *inv *= *denom_denom);

flat_denominator_inverses
.data
.chunks(domain.size() / K_BLOCK_SIZE)
Expand Down
39 changes: 22 additions & 17 deletions crates/prover/src/core/backend/cpu/quotients.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use itertools::{izip, zip_eq};
use itertools::izip;
use num_traits::{One, Zero};

use super::CPUBackend;
use crate::core::backend::{Backend, Col};
use crate::core::circle::CirclePoint;
use crate::core::constraints::{complex_conjugate_line_coeffs, pair_vanishing};
use crate::core::constraints::{complex_conjugate_line_coeffs, point_vanishing_fraction};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fields::{ComplexConjugate, FieldExpOps};
use crate::core::fields::FieldExpOps;
use crate::core::pcs::quotients::{ColumnSampleBatch, PointSample, QuotientOps};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
Expand Down Expand Up @@ -41,27 +41,27 @@ impl QuotientOps for CPUBackend {
}
}

// TODO(Ohad): no longer using pair_vanishing, remove domain_point_vec and line_coeffs, or write a
// function that deals with quotients over pair_vanishing polynomials.
pub fn accumulate_row_quotients(
sample_batches: &[ColumnSampleBatch],
columns: &[&CircleEvaluation<CPUBackend, BaseField, BitReversedOrder>],
quotient_constants: &QuotientConstants<CPUBackend>,
row: usize,
domain_point: CirclePoint<BaseField>,
_domain_point: CirclePoint<BaseField>,
) -> SecureField {
let mut row_accumulator = SecureField::zero();
for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!(
for (sample_batch, _line_coeffs, batch_coeff, denominator_inverses) in izip!(
sample_batches,
&quotient_constants.line_coeffs,
&quotient_constants.batch_random_coeffs,
&quotient_constants.denominator_inverses
) {
let mut numerator = SecureField::zero();
for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs)
{
for (column_index, sampled_value) in sample_batch.columns_and_values.iter() {
let column = &columns[*column_index];
let value = column[row] * *c;
let linear_term = *a * domain_point.y + *b;
numerator += value - linear_term;
let value = column[row];
numerator += value - *sampled_value;
}

row_accumulator = row_accumulator * *batch_coeff + numerator * denominator_inverses[row];
Expand Down Expand Up @@ -114,21 +114,24 @@ fn denominator_inverses(
sample_batches: &[ColumnSampleBatch],
domain: CircleDomain,
) -> Vec<Col<CPUBackend, SecureField>> {
let mut flat_denominators = Vec::with_capacity(sample_batches.len() * domain.size());
let n_fracions = sample_batches.len() * domain.size();
let mut flat_denominators = Vec::with_capacity(n_fracions);
let mut numerator_terms = Vec::with_capacity(n_fracions);
for sample_batch in sample_batches {
for row in 0..domain.size() {
let domain_point = domain.at(row);
let denominator = pair_vanishing(
sample_batch.point,
sample_batch.point.complex_conjugate(),
domain_point.into_ef(),
);
flat_denominators.push(denominator);
let (num, denom) = point_vanishing_fraction(sample_batch.point, domain_point);
flat_denominators.push(num);
numerator_terms.push(denom);
}
}

let mut flat_denominator_inverses = vec![SecureField::zero(); flat_denominators.len()];
SecureField::batch_inverse(&flat_denominators, &mut flat_denominator_inverses);
flat_denominator_inverses
.iter_mut()
.zip(&numerator_terms)
.for_each(|(inv, num_term)| *inv *= *num_term);

flat_denominator_inverses
.chunks_mut(domain.size())
Expand Down Expand Up @@ -176,6 +179,8 @@ mod tests {
use crate::{m31, qm31};

#[test]
// Ignored because we allow polynomials outside of fft-space, hence it is not a bug anymore.
#[ignore]
fn test_quotients_are_low_degree() {
const LOG_SIZE: u32 = 7;
let polynomial = CPUCirclePoly::new((0..1 << LOG_SIZE).map(|i| m31!(i)).collect());
Expand Down
12 changes: 12 additions & 0 deletions crates/prover/src/core/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ pub fn point_vanishing<F: ExtensionOf<BaseField>, EF: ExtensionOf<F>>(
h.y / (EF::one() + h.x)
}

/// Evaluates a vanishing polynomial of the vanish_point at a point.
/// Note that this function has a pole on the antipode of the vanish_point.
/// Returns the result in a fraction form: (numerator, denominator).
// TODO(Ohad): reorganize these functions.
pub fn point_vanishing_fraction<F: ExtensionOf<BaseField>, EF: ExtensionOf<F>>(
vanish_point: CirclePoint<EF>,
p: CirclePoint<F>,
) -> (EF, EF) {
let h = p.into_ef() - vanish_point;
(h.y, (EF::one() + h.x))
}

/// Evaluates a point on a line between a point and its complex conjugate.
/// Relies on the fact that every polynomial F over the base field holds:
/// F(p*) == F(p)* (* being the complex conjugate).
Expand Down
4 changes: 4 additions & 0 deletions crates/prover/src/core/pcs/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ mod tests {
use crate::{m31, qm31};

#[test]
// Ignored because we allow polynomials outside of fft-space, hence it is not a bug anymore.
// TODO(Ohad): consider adding a test for the new behavior,
// i.e. polynomials inside the r-h space.
#[ignore]
fn test_quotients_are_low_degree() {
const LOG_SIZE: u32 = 7;
let polynomial = CPUCirclePoly::new((0..1 << LOG_SIZE).map(|i| m31!(i)).collect());
Expand Down

0 comments on commit a6eabdd

Please sign in to comment.