From f9768909f44304a6e1e068d93c461d709b5d2142 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Wed, 21 Aug 2024 17:49:03 +0100 Subject: [PATCH] Add eval order and interpolate to SecureEvaluation (#794) --- crates/prover/src/core/backend/cpu/fri.rs | 21 +++--- .../prover/src/core/backend/cpu/quotients.rs | 4 +- crates/prover/src/core/backend/simd/fri.rs | 39 +++++------ .../prover/src/core/backend/simd/quotients.rs | 7 +- crates/prover/src/core/fri.rs | 66 ++++++------------- crates/prover/src/core/pcs/mod.rs | 2 +- crates/prover/src/core/pcs/quotients.rs | 4 +- .../src/core/poly/circle/secure_poly.rs | 61 ++++++++++++----- 8 files changed, 99 insertions(+), 105 deletions(-) diff --git a/crates/prover/src/core/backend/cpu/fri.rs b/crates/prover/src/core/backend/cpu/fri.rs index 28b0b406d..693fb9926 100644 --- a/crates/prover/src/core/backend/cpu/fri.rs +++ b/crates/prover/src/core/backend/cpu/fri.rs @@ -6,6 +6,7 @@ use crate::core::fri::{fold_circle_into_line, fold_line, FriOps}; use crate::core::poly::circle::SecureEvaluation; use crate::core::poly::line::LineEvaluation; use crate::core::poly::twiddles::TwiddleTree; +use crate::core::poly::BitReversedOrder; // TODO(spapini): Optimized these functions as well. impl FriOps for CpuBackend { @@ -18,14 +19,16 @@ impl FriOps for CpuBackend { } fn fold_circle_into_line( dst: &mut LineEvaluation, - src: &SecureEvaluation, + src: &SecureEvaluation, alpha: SecureField, _twiddles: &TwiddleTree, ) { fold_circle_into_line(dst, src, alpha) } - fn decompose(eval: &SecureEvaluation) -> (SecureEvaluation, SecureField) { + fn decompose( + eval: &SecureEvaluation, + ) -> (SecureEvaluation, SecureField) { let lambda = Self::decomposition_coefficient(eval); let mut g_values = unsafe { SecureColumnByCoords::::uninitialized(eval.len()) }; @@ -43,10 +46,7 @@ impl FriOps for CpuBackend { g_values.set(i, val); } - let g = SecureEvaluation { - domain: eval.domain, - values: g_values, - }; + let g = SecureEvaluation::new(eval.domain, g_values); (g, lambda) } } @@ -67,7 +67,7 @@ impl CpuBackend { /// This function assumes the blowupfactor is 2 /// /// [`CirclePoly`]: crate::core::poly::circle::CirclePoly - fn decomposition_coefficient(eval: &SecureEvaluation) -> SecureField { + fn decomposition_coefficient(eval: &SecureEvaluation) -> SecureField { let domain_size = 1 << eval.domain.log_size(); let half_domain_size = domain_size / 2; @@ -96,6 +96,7 @@ mod tests { use crate::core::fields::secure_column::SecureColumnByCoords; use crate::core::fri::FriOps; use crate::core::poly::circle::{CanonicCoset, SecureEvaluation}; + use crate::core::poly::BitReversedOrder; use crate::m31; #[test] @@ -121,10 +122,10 @@ mod tests { values.values.clone(), ], }; - let secure_eval = SecureEvaluation:: { + let secure_eval = SecureEvaluation::::new( domain, - values: secure_column.clone(), - }; + secure_column.clone(), + ); let (g, lambda) = CpuBackend::decompose(&secure_eval); diff --git a/crates/prover/src/core/backend/cpu/quotients.rs b/crates/prover/src/core/backend/cpu/quotients.rs index 3f0c201da..17cc007e1 100644 --- a/crates/prover/src/core/backend/cpu/quotients.rs +++ b/crates/prover/src/core/backend/cpu/quotients.rs @@ -21,7 +21,7 @@ impl QuotientOps for CpuBackend { random_coeff: SecureField, sample_batches: &[ColumnSampleBatch], _log_blowup_factor: u32, - ) -> SecureEvaluation { + ) -> SecureEvaluation { let mut values = unsafe { SecureColumnByCoords::uninitialized(domain.size()) }; let quotient_constants = quotient_constants(sample_batches, random_coeff, domain); @@ -36,7 +36,7 @@ impl QuotientOps for CpuBackend { ); values.set(row, row_value); } - SecureEvaluation { domain, values } + SecureEvaluation::new(domain, values) } } diff --git a/crates/prover/src/core/backend/simd/fri.rs b/crates/prover/src/core/backend/simd/fri.rs index e404b050a..97212497b 100644 --- a/crates/prover/src/core/backend/simd/fri.rs +++ b/crates/prover/src/core/backend/simd/fri.rs @@ -17,6 +17,7 @@ use crate::core::poly::circle::SecureEvaluation; use crate::core::poly::line::LineEvaluation; use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::utils::domain_line_twiddles_from_tree; +use crate::core::poly::BitReversedOrder; impl FriOps for SimdBackend { fn fold_line( @@ -57,7 +58,7 @@ impl FriOps for SimdBackend { fn fold_circle_into_line( dst: &mut LineEvaluation, - src: &SecureEvaluation, + src: &SecureEvaluation, alpha: SecureField, twiddles: &TwiddleTree, ) { @@ -96,7 +97,9 @@ impl FriOps for SimdBackend { } } - fn decompose(eval: &SecureEvaluation) -> (SecureEvaluation, SecureField) { + fn decompose( + eval: &SecureEvaluation, + ) -> (SecureEvaluation, SecureField) { let lambda = decomposition_coefficient(eval); let broadcasted_lambda = PackedSecureField::broadcast(lambda); let mut g_values = SecureColumnByCoords::::zeros(eval.len()); @@ -112,10 +115,7 @@ impl FriOps for SimdBackend { unsafe { g_values.set_packed(i, val) } } - let g = SecureEvaluation { - domain: eval.domain, - values: g_values, - }; + let g = SecureEvaluation::new(eval.domain, g_values); (g, lambda) } } @@ -123,7 +123,9 @@ impl FriOps for SimdBackend { /// See [`decomposition_coefficient`]. /// /// [`decomposition_coefficient`]: crate::core::backend::cpu::CpuBackend::decomposition_coefficient -fn decomposition_coefficient(eval: &SecureEvaluation) -> SecureField { +fn decomposition_coefficient( + eval: &SecureEvaluation, +) -> SecureField { let cols = &eval.values.columns; let [mut x_sum, mut y_sum, mut z_sum, mut w_sum] = [PackedBaseField::zero(); 4]; @@ -167,6 +169,7 @@ mod tests { use crate::core::fri::FriOps; use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps, SecureEvaluation}; use crate::core::poly::line::{LineDomain, LineEvaluation}; + use crate::core::poly::BitReversedOrder; use crate::qm31; #[test] @@ -206,10 +209,7 @@ mod tests { ); CpuBackend::fold_circle_into_line( &mut cpu_fold, - &SecureEvaluation { - domain: circle_domain, - values: values.iter().copied().collect(), - }, + &SecureEvaluation::new(circle_domain, values.iter().copied().collect()), alpha, &CpuBackend::precompute_twiddles(line_domain.coset()), ); @@ -220,10 +220,7 @@ mod tests { ); SimdBackend::fold_circle_into_line( &mut simd_fold, - &SecureEvaluation { - domain: circle_domain, - values: values.iter().copied().collect(), - }, + &SecureEvaluation::new(circle_domain, values.iter().copied().collect()), alpha, &SimdBackend::precompute_twiddles(line_domain.coset()), ); @@ -250,16 +247,10 @@ mod tests { values.values.clone(), ], }; - let avx_eval = SecureEvaluation { - domain, - values: avx_column.clone(), - }; - let cpu_eval = SecureEvaluation:: { - domain, - values: avx_eval.to_cpu(), - }; + let avx_eval = SecureEvaluation::new(domain, avx_column.clone()); + let cpu_eval = + SecureEvaluation::::new(domain, avx_eval.to_cpu()); let (cpu_g, cpu_lambda) = CpuBackend::decompose(&cpu_eval); - let (avx_g, avx_lambda) = SimdBackend::decompose(&avx_eval); assert_eq!(avx_lambda, cpu_lambda); diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index 2568cff29..3cb664aeb 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -32,7 +32,7 @@ impl QuotientOps for SimdBackend { random_coeff: SecureField, sample_batches: &[ColumnSampleBatch], log_blowup_factor: u32, - ) -> SecureEvaluation { + ) -> SecureEvaluation { // Split the domain into a subdomain and a shift coset. // TODO(spapini): Move to the caller when Columns support slices. let (subdomain, mut subdomain_shifts) = domain.split(log_blowup_factor); @@ -74,10 +74,7 @@ impl QuotientOps for SimdBackend { } span.exit(); - SecureEvaluation { - domain, - values: extended_eval, - } + SecureEvaluation::new(domain, extended_eval) } } diff --git a/crates/prover/src/core/fri.rs b/crates/prover/src/core/fri.rs index 1dbc89764..9f13de253 100644 --- a/crates/prover/src/core/fri.rs +++ b/crates/prover/src/core/fri.rs @@ -109,7 +109,7 @@ pub trait FriOps: FieldOps + PolyOps + Sized + FieldOps // TODO(andrew): Fold directly into FRI layer to prevent allocation. fn fold_circle_into_line( dst: &mut LineEvaluation, - src: &SecureEvaluation, + src: &SecureEvaluation, alpha: SecureField, twiddles: &TwiddleTree, ); @@ -119,7 +119,9 @@ pub trait FriOps: FieldOps + PolyOps + Sized + FieldOps /// FRI-space: polynomials of total degree n/2. /// Based on lemma #12 from the CircleStark paper: f(P) = g(P)+ lambda * alternating(P), /// where lambda is the cosset diff of eval, and g is a polynomial in the fft-space. - fn decompose(eval: &SecureEvaluation) -> (SecureEvaluation, SecureField); + fn decompose( + eval: &SecureEvaluation, + ) -> (SecureEvaluation, SecureField); } /// A FRI prover that applies the FRI protocol to prove a set of polynomials are of low degree. pub struct FriProver, MC: MerkleChannel> { @@ -154,7 +156,7 @@ impl, MC: MerkleChannel> FriProver { pub fn commit( channel: &mut MC::C, config: FriConfig, - columns: &[SecureEvaluation], + columns: &[SecureEvaluation], twiddles: &TwiddleTree, ) -> Self { let _span = span!(Level::INFO, "FRI commitment").entered(); @@ -186,11 +188,12 @@ impl, MC: MerkleChannel> FriProver { fn commit_inner_layers( channel: &mut MC::C, config: FriConfig, - columns: &[SecureEvaluation], + columns: &[SecureEvaluation], twiddles: &TwiddleTree, ) -> (Vec>, LineEvaluation) { // Returns the length of the [LineEvaluation] a [CircleEvaluation] gets folded into. - let folded_len = |e: &SecureEvaluation| e.len() >> CIRCLE_TO_LINE_FOLD_STEP; + let folded_len = + |e: &SecureEvaluation| e.len() >> CIRCLE_TO_LINE_FOLD_STEP; let first_layer_size = folded_len(&columns[0]); let first_layer_domain = LineDomain::new(Coset::half_odds(first_layer_size.ilog2())); @@ -538,21 +541,6 @@ fn get_opening_positions( positions } -pub trait FriChannel { - type Digest; - - type Field; - - /// Reseeds the channel with a commitment to an inner FRI layer. - fn reseed_with_inner_layer(&mut self, commitment: &Self::Digest); - - /// Reseeds the channel with the FRI last layer polynomial. - fn reseed_with_last_layer(&mut self, last_layer: &LinePoly); - - /// Draws a random field element. - fn draw(&mut self) -> Self::Field; -} - #[derive(Clone, Copy, Debug, Error)] pub enum FriVerificationError { #[error("proof contains an invalid number of FRI layers")] @@ -876,10 +864,7 @@ impl SparseCircleEvaluation { let mut buffer = LineEvaluation::new_zero(buffer_domain); fold_circle_into_line( &mut buffer, - &SecureEvaluation { - domain: e.domain, - values: e.values.into_iter().collect(), - }, + &SecureEvaluation::new(e.domain, e.values.into_iter().collect()), alpha, ); buffer.values.at(0) @@ -957,7 +942,7 @@ pub fn fold_line( /// See [`FriOps::fold_circle_into_line`]. pub fn fold_circle_into_line( dst: &mut LineEvaluation, - src: &SecureEvaluation, + src: &SecureEvaluation, alpha: SecureField, ) { assert_eq!(src.len() >> CIRCLE_TO_LINE_FOLD_STEP, dst.len()); @@ -993,11 +978,10 @@ mod tests { use super::{get_opening_positions, FriVerificationError, SparseCircleEvaluation}; use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly}; - use crate::core::backend::{Col, Column, ColumnOps, CpuBackend}; + use crate::core::backend::{ColumnOps, CpuBackend}; use crate::core::circle::{CirclePointIndex, Coset}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; - use crate::core::fields::secure_column::SecureColumnByCoords; use crate::core::fields::Field; use crate::core::fri::{ fold_circle_into_line, fold_line, CirclePolyDegreeBound, FriConfig, @@ -1005,7 +989,7 @@ mod tests { }; use crate::core::poly::circle::{CircleDomain, PolyOps, SecureEvaluation}; use crate::core::poly::line::{LineDomain, LineEvaluation, LinePoly}; - use crate::core::poly::NaturalOrder; + use crate::core::poly::{BitReversedOrder, NaturalOrder}; use crate::core::queries::{Queries, SparseSubCircleDomain}; use crate::core::test_utils::test_channel; use crate::core::utils::bit_reverse_index; @@ -1089,10 +1073,10 @@ mod tests { fn committing_evaluation_from_invalid_domain_fails() { let invalid_domain = CircleDomain::new(Coset::new(CirclePointIndex::generator(), 3)); assert!(!invalid_domain.is_canonic(), "must be an invalid domain"); - let evaluation = SecureEvaluation { - domain: invalid_domain, - values: vec![SecureField::one(); 1 << 4].into_iter().collect(), - }; + let evaluation = SecureEvaluation::new( + invalid_domain, + vec![SecureField::one(); 1 << 4].into_iter().collect(), + ); FriProver::commit( &mut test_channel(), @@ -1388,22 +1372,12 @@ mod tests { fn polynomial_evaluation( log_degree: u32, log_blowup_factor: u32, - ) -> SecureEvaluation { + ) -> SecureEvaluation { let poly = CpuCirclePoly::new(vec![BaseField::one(); 1 << log_degree]); let coset = Coset::half_odds(log_degree + log_blowup_factor - 1); let domain = CircleDomain::new(coset); let values = poly.evaluate(domain); - SecureEvaluation { - domain, - values: SecureColumnByCoords { - columns: [ - values.values, - Col::::zeros(1 << (log_degree + log_blowup_factor)), - Col::::zeros(1 << (log_degree + log_blowup_factor)), - Col::::zeros(1 << (log_degree + log_blowup_factor)), - ], - }, - } + SecureEvaluation::new(domain, values.into_iter().map(SecureField::from).collect()) } /// Returns the log degree bound of a polynomial. @@ -1415,7 +1389,7 @@ mod tests { // TODO: Remove after SubcircleDomain integration. fn query_polynomial( - polynomial: &SecureEvaluation, + polynomial: &SecureEvaluation, queries: &Queries, ) -> SparseCircleEvaluation { let polynomial_log_size = polynomial.domain.log_size(); @@ -1425,7 +1399,7 @@ mod tests { } fn open_polynomial( - polynomial: &SecureEvaluation, + polynomial: &SecureEvaluation, positions: &SparseSubCircleDomain, ) -> SparseCircleEvaluation { let coset_evals = positions diff --git a/crates/prover/src/core/pcs/mod.rs b/crates/prover/src/core/pcs/mod.rs index 8d09abce7..07ddfd2d3 100644 --- a/crates/prover/src/core/pcs/mod.rs +++ b/crates/prover/src/core/pcs/mod.rs @@ -18,7 +18,7 @@ pub use self::utils::TreeVec; pub use self::verifier::CommitmentSchemeVerifier; use super::fri::FriConfig; -#[derive(Copy, Debug, Clone)] +#[derive(Copy, Debug, Clone, PartialEq, Eq)] pub struct TreeColumnSpan { pub tree_index: usize, pub col_start: usize, diff --git a/crates/prover/src/core/pcs/quotients.rs b/crates/prover/src/core/pcs/quotients.rs index 2327150f1..1a41e8303 100644 --- a/crates/prover/src/core/pcs/quotients.rs +++ b/crates/prover/src/core/pcs/quotients.rs @@ -31,7 +31,7 @@ pub trait QuotientOps: PolyOps { random_coeff: SecureField, sample_batches: &[ColumnSampleBatch], log_blowup_factor: u32, - ) -> SecureEvaluation; + ) -> SecureEvaluation; } /// A batch of column samplings at a point. @@ -78,7 +78,7 @@ pub fn compute_fri_quotients( samples: &[Vec], random_coeff: SecureField, log_blowup_factor: u32, -) -> Vec> { +) -> Vec> { let _span = span!(Level::INFO, "Compute FRI quotients").entered(); zip(columns, samples) .sorted_by_key(|(c, _)| Reverse(c.domain.log_size())) diff --git a/crates/prover/src/core/poly/circle/secure_poly.rs b/crates/prover/src/core/poly/circle/secure_poly.rs index 294f9ac77..8482e2971 100644 --- a/crates/prover/src/core/poly/circle/secure_poly.rs +++ b/crates/prover/src/core/poly/circle/secure_poly.rs @@ -1,13 +1,14 @@ +use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; use super::{CircleDomain, CircleEvaluation, CirclePoly, PolyOps}; -use crate::core::backend::cpu::CpuCircleEvaluation; use crate::core::backend::CpuBackend; use crate::core::circle::CirclePoint; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; use crate::core::fields::FieldOps; +use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::BitReversedOrder; pub struct SecureCirclePoly>(pub [CirclePoly; SECURE_EXTENSION_DEGREE]); @@ -32,6 +33,16 @@ impl SecureCirclePoly { pub fn log_size(&self) -> u32 { self[0].log_size() } + + pub fn evaluate_with_twiddles( + &self, + domain: CircleDomain, + twiddles: &TwiddleTree, + ) -> SecureEvaluation { + let polys = self.0.each_ref(); + let columns = polys.map(|poly| poly.evaluate_with_twiddles(domain, twiddles).values); + SecureEvaluation::new(domain, SecureColumnByCoords { columns }) + } } impl> Deref for SecureCirclePoly { @@ -42,12 +53,29 @@ impl> Deref for SecureCirclePoly { } } +/// A [`SecureField`] evaluation defined on a [CircleDomain]. +/// +/// The evaluation is stored as a column major array of [`SECURE_EXTENSION_DEGREE`] many base field +/// evaluations. The evaluations are ordered according to the [CircleDomain] ordering. #[derive(Clone)] -pub struct SecureEvaluation> { +pub struct SecureEvaluation, EvalOrder> { pub domain: CircleDomain, pub values: SecureColumnByCoords, + _eval_order: PhantomData, } -impl> Deref for SecureEvaluation { + +impl, EvalOrder> SecureEvaluation { + pub fn new(domain: CircleDomain, values: SecureColumnByCoords) -> Self { + assert_eq!(domain.size(), values.len()); + Self { + domain, + values, + _eval_order: PhantomData, + } + } +} + +impl, EvalOrder> Deref for SecureEvaluation { type Target = SecureColumnByCoords; fn deref(&self) -> &Self::Target { @@ -55,26 +83,29 @@ impl> Deref for SecureEvaluation { } } -impl> DerefMut for SecureEvaluation { +impl, EvalOrder> DerefMut for SecureEvaluation { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.values } } -impl SecureEvaluation { - // TODO(spapini): Remove when we no longer use CircleEvaluation. - pub fn to_cpu(self) -> CpuCircleEvaluation { - CpuCircleEvaluation::new(self.domain, self.values.to_vec()) +impl SecureEvaluation { + /// Computes a minimal [`SecureCirclePoly`] that evaluates to the same values as this + /// evaluation, using precomputed twiddles. + pub fn interpolate_with_twiddles(self, twiddles: &TwiddleTree) -> SecureCirclePoly { + let domain = self.domain; + let cols = self.values.columns; + SecureCirclePoly(cols.map(|c| { + CircleEvaluation::::new(domain, c) + .interpolate_with_twiddles(twiddles) + })) } } -impl From> - for SecureEvaluation +impl From> + for SecureEvaluation { - fn from(evaluation: CircleEvaluation) -> Self { - Self { - domain: evaluation.domain, - values: evaluation.values.into_iter().collect(), - } + fn from(evaluation: CircleEvaluation) -> Self { + Self::new(evaluation.domain, evaluation.values.into_iter().collect()) } }