diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 3c9de2969..2b2cbe4c0 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -4,7 +4,7 @@ version.workspace = true edition.workspace = true [features] -parallel = ["rayon"] +parallel = [] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -23,7 +23,7 @@ starknet-crypto = "0.6.2" starknet-ff = "0.3.7" thiserror.workspace = true tracing.workspace = true -rayon = { version = "1.10.0", optional = true } +rayon = { version = "1.10.0" } [dev-dependencies] aligned = "0.4.2" diff --git a/crates/prover/src/core/backend/simd/column.rs b/crates/prover/src/core/backend/simd/column.rs index 19954d0bd..9287e273c 100644 --- a/crates/prover/src/core/backend/simd/column.rs +++ b/crates/prover/src/core/backend/simd/column.rs @@ -226,8 +226,18 @@ impl<'a> SecureColumnMutSlice<'a> { *self.0[2].get_unchecked_mut(vec_index) = c; *self.0[3].get_unchecked_mut(vec_index) = d; } + pub fn chunks_mut(&mut self, chunk_pack_size: usize) -> Vec<SecureColumnMutSlice<'_>> { + let mut_refs = self.0.get_many_mut([0, 1, 2, 3]).unwrap(); + let [a, b, c, d] = mut_refs.map(|c| c.chunks_mut(chunk_pack_size).collect_vec()); + izip!(a, b, c, d) + .map(|(a, b, c, d)| SecureColumnMutSlice([a, b, c, d])) + .collect() + } + pub fn as_ref(&self) -> SecureColumnSlice<'_> { + let refs = std::array::from_fn(|i| &self.0[i]); + SecureColumnSlice(refs.map(|c| &**c)) + } } - impl SecureColumn<SimdBackend> { pub fn as_ref(&self) -> SecureColumnSlice<'_> { assert_eq!(self.columns[0].length, self.columns[0].data.len() * N_LANES); diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index b011bd098..668e55c29 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -2,12 +2,15 @@ use std::iter::zip; use itertools::izip; use num_traits::{One, Zero}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use 
super::column::SecureFieldVec; use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; use super::qm31::PackedSecureField; use super::SimdBackend; use crate::core::backend::cpu::quotients::{batch_random_coeffs, QuotientConstants}; +use crate::core::backend::simd::utils::MaybeParIter; use crate::core::backend::{Col, Column}; use crate::core::circle::CirclePoint; use crate::core::fields::m31::BaseField; @@ -31,25 +34,34 @@ impl QuotientOps for SimdBackend { let quotient_constants = quotient_constants(sample_batches, random_coeff, domain); // TODO(spapini): bit reverse iterator. - for vec_row in 0..1 << (domain.log_size() - LOG_N_LANES) { - // TODO(spapini): Optimize this, for the small number of columns case. - let points = std::array::from_fn(|i| { - domain.at(bit_reverse_index( - (vec_row << LOG_N_LANES) + i, - domain.log_size(), - )) + const PACK_CHUNK_SIZE: usize = 1 << 5; + values + .as_mut() + .chunks_mut(PACK_CHUNK_SIZE) + .maybe_par_iter() + .enumerate() + .for_each(|(chunk_i, mut chunk)| { + for i in 0..chunk.as_ref().packed_len() { + let vec_row = chunk_i * PACK_CHUNK_SIZE + i; + // TODO(spapini): Optimize this, for the small number of columns case. 
+ let points = std::array::from_fn(|i| { + domain.at(bit_reverse_index( + (vec_row << LOG_N_LANES) + i, + domain.log_size(), + )) + }); + let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x)); + let domain_points_y = PackedBaseField::from_array(points.map(|p| p.y)); + let row_accumulator = accumulate_row_quotients( + sample_batches, + columns, + &quotient_constants, + vec_row, + (domain_points_x, domain_points_y), + ); + unsafe { chunk.set_packed(i, row_accumulator) }; + } }); - let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x)); - let domain_points_y = PackedBaseField::from_array(points.map(|p| p.y)); - let row_accumulator = accumulate_row_quotients( - sample_batches, - columns, - &quotient_constants, - vec_row, - (domain_points_x, domain_points_y), - ); - unsafe { values.set_packed(vec_row, row_accumulator) }; - } SecureEvaluation { domain, values } } } diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs index 87dfd2246..96bb382c9 100644 --- a/crates/prover/src/core/backend/simd/utils.rs +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -1,5 +1,7 @@ use std::simd::Swizzle; +use rayon::prelude::*; + /// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. pub struct InterleaveEvens; @@ -24,6 +26,24 @@ const fn parity_interleave<const N: usize>(odd: bool) -> [usize; N] { res } +pub trait MaybeParIter: IntoParallelIterator + IntoIterator { + type Iter; + fn maybe_par_iter(self) -> <Self as MaybeParIter>::Iter; +} +impl<I: IntoParallelIterator + IntoIterator> MaybeParIter for I { + #[cfg(feature = "parallel")] + type Iter = <I as IntoParallelIterator>::Iter; + #[cfg(not(feature = "parallel"))] + type Iter = <I as IntoIterator>::IntoIter; + + fn maybe_par_iter(self) -> <Self as MaybeParIter>::Iter { + #[cfg(feature = "parallel")] + return self.into_par_iter(); + #[cfg(not(feature = "parallel"))] + self.into_iter() + } +} + #[cfg(test)] mod tests { use std::simd::{u32x4, Swizzle};