Skip to content

Commit

Permalink
Parallel quotients
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Jul 1, 2024
1 parent 3db6f2f commit 50f7600
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 21 deletions.
4 changes: 2 additions & 2 deletions crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version.workspace = true
edition.workspace = true

[features]
parallel = ["rayon"]
parallel = []

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand All @@ -23,7 +23,7 @@ starknet-crypto = "0.6.2"
starknet-ff = "0.3.7"
thiserror.workspace = true
tracing.workspace = true
rayon = { version = "1.10.0", optional = true }
rayon = { version = "1.10.0" }

[dev-dependencies]
aligned = "0.4.2"
Expand Down
12 changes: 11 additions & 1 deletion crates/prover/src/core/backend/simd/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,18 @@ impl<'a> SecureColumnMutSlice<'a> {
*self.0[2].get_unchecked_mut(vec_index) = c;
*self.0[3].get_unchecked_mut(vec_index) = d;
}
pub fn chunks_mut(&mut self, chunk_pack_size: usize) -> Vec<SecureColumnMutSlice<'_>> {
let mut_refs = self.0.get_many_mut([0, 1, 2, 3]).unwrap();
let [a, b, c, d] = mut_refs.map(|c| c.chunks_mut(chunk_pack_size).collect_vec());
izip!(a, b, c, d)
.map(|(a, b, c, d)| SecureColumnMutSlice([a, b, c, d]))
.collect()
}
pub fn as_ref(&self) -> SecureColumnSlice<'_> {
let refs = std::array::from_fn(|i| &self.0[i]);
SecureColumnSlice(refs.map(|c| &**c))
}
}

impl SecureColumn<SimdBackend> {
pub fn as_ref(&self) -> SecureColumnSlice<'_> {
assert_eq!(self.columns[0].length, self.columns[0].data.len() * N_LANES);
Expand Down
48 changes: 30 additions & 18 deletions crates/prover/src/core/backend/simd/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ use std::iter::zip;

use itertools::izip;
use num_traits::{One, Zero};
#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::column::SecureFieldVec;
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use super::qm31::PackedSecureField;
use super::SimdBackend;
use crate::core::backend::cpu::quotients::{batch_random_coeffs, QuotientConstants};
use crate::core::backend::simd::utils::MaybeParIter;
use crate::core::backend::{Col, Column};
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
Expand All @@ -31,25 +34,34 @@ impl QuotientOps for SimdBackend {
let quotient_constants = quotient_constants(sample_batches, random_coeff, domain);

// TODO(spapini): bit reverse iterator.
for vec_row in 0..1 << (domain.log_size() - LOG_N_LANES) {
// TODO(spapini): Optimize this, for the small number of columns case.
let points = std::array::from_fn(|i| {
domain.at(bit_reverse_index(
(vec_row << LOG_N_LANES) + i,
domain.log_size(),
))
const PACK_CHUNK_SIZE: usize = 1 << 5;
values
.as_mut()
.chunks_mut(PACK_CHUNK_SIZE)
.maybe_par_iter()
.enumerate()
.for_each(|(chunk_i, mut chunk)| {
for i in 0..chunk.as_ref().packed_len() {
let vec_row = chunk_i * PACK_CHUNK_SIZE + i;
// TODO(spapini): Optimize this, for the small number of columns case.
let points = std::array::from_fn(|i| {
domain.at(bit_reverse_index(
(vec_row << LOG_N_LANES) + i,
domain.log_size(),
))
});
let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x));
let domain_points_y = PackedBaseField::from_array(points.map(|p| p.y));
let row_accumulator = accumulate_row_quotients(
sample_batches,
columns,
&quotient_constants,
vec_row,
(domain_points_x, domain_points_y),
);
unsafe { chunk.set_packed(i, row_accumulator) };
}
});
let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x));
let domain_points_y = PackedBaseField::from_array(points.map(|p| p.y));
let row_accumulator = accumulate_row_quotients(
sample_batches,
columns,
&quotient_constants,
vec_row,
(domain_points_x, domain_points_y),
);
unsafe { values.set_packed(vec_row, row_accumulator) };
}
SecureEvaluation { domain, values }
}
}
Expand Down
20 changes: 20 additions & 0 deletions crates/prover/src/core/backend/simd/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::simd::Swizzle;

use rayon::prelude::*;

/// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors.
pub struct InterleaveEvens;

Expand All @@ -24,6 +26,24 @@ const fn parity_interleave<const N: usize>(odd: bool) -> [usize; N] {
res
}

pub trait MaybeParIter: IntoParallelIterator + IntoIterator {
type Iter;
fn maybe_par_iter(self) -> <Self as MaybeParIter>::Iter;
}
impl<I: IntoParallelIterator + IntoIterator> MaybeParIter for I {
#[cfg(feature = "parallel")]
type Iter = <Self as IntoParallelIterator>::Iter;
#[cfg(not(feature = "parallel"))]
type Iter = <Self as IntoIterator>::IntoIter;

fn maybe_par_iter(self) -> <Self as MaybeParIter>::Iter {
#[cfg(feature = "parallel")]
return self.into_par_iter();
#[cfg(not(feature = "parallel"))]
self.into_iter()
}
}

#[cfg(test)]
mod tests {
use std::simd::{u32x4, Swizzle};
Expand Down

0 comments on commit 50f7600

Please sign in to comment.