diff --git a/commit/src/pcs.rs b/commit/src/pcs.rs index 5bdb6df06..8d2b95f3c 100644 --- a/commit/src/pcs.rs +++ b/commit/src/pcs.rs @@ -77,6 +77,8 @@ where where Self: 'a; + type LdeOwned: MatrixRows + MatrixGet + Sync; + fn coset_shift(&self) -> Val; fn log_blowup(&self) -> usize; @@ -85,6 +87,17 @@ where where 'a: 'b; + // Compute the (shifted) low-degree extensions only without computing the commitment. + fn compute_ldes_batches( + &self, + polynomials: Vec, + coset_shifts: &[Val], + ) -> Vec; + + fn compute_lde_batch(&self, polynomials: In, coset_shift: Val) -> Vec { + self.compute_ldes_batches(vec![polynomials], &[coset_shift]) + } + // Commit to polys that are already defined over a coset. fn commit_shifted_batches( &self, diff --git a/fri/src/two_adic_pcs.rs b/fri/src/two_adic_pcs.rs index 7d9e1a83f..3137898c6 100644 --- a/fri/src/two_adic_pcs.rs +++ b/fri/src/two_adic_pcs.rs @@ -13,7 +13,7 @@ use p3_field::{ }; use p3_interpolation::interpolate_coset; use p3_matrix::bitrev::{BitReversableMatrix, BitReversedMatrixView}; -use p3_matrix::dense::RowMajorMatrixView; +use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView}; use p3_matrix::{Dimensions, Matrix, MatrixRows}; use p3_maybe_rayon::prelude::*; use p3_util::linear_map::LinearMap; @@ -141,6 +141,7 @@ where >::ProverData: Send + Sync + Sized, { type Lde<'a> = BitReversedMatrixView<>::Mat<'a>> where Self: 'a; + type LdeOwned = BitReversedMatrixView>; fn coset_shift(&self) -> C::Val { C::Val::generator() @@ -162,26 +163,38 @@ where .collect() } - fn commit_shifted_batches( + fn compute_ldes_batches( &self, polynomials: Vec, coset_shifts: &[C::Val], - ) -> (Self::Commitment, Self::ProverData) { - let ldes = info_span!("compute all coset LDEs").in_scope(|| { + ) -> Vec { + info_span!("compute all coset LDEs").in_scope(|| { polynomials .par_iter() .zip_eq(coset_shifts) .map(|(poly, coset_shift)| { let shift = C::Val::generator() / *coset_shift; let input = ((*poly).clone()).to_row_major_matrix(); - // Commit to the bit-reversed LDE. self.dft .coset_lde_batch(input, self.fri.log_blowup, shift) .bit_reverse_rows() .to_row_major_matrix() }) + .map(BitReversedMatrixView::new) .collect() - }); + }) + } + + fn commit_shifted_batches( + &self, + polynomials: Vec, + coset_shifts: &[C::Val], + ) -> (Self::Commitment, Self::ProverData) { + let ldes = self + .compute_ldes_batches(polynomials, coset_shifts) + .into_iter() + .map(|x| x.to_row_major_matrix()) + .collect(); self.mmcs.commit(ldes) }