Commit 2a4fec2: Parallel fft

spapinistarkware committed Sep 5, 2024
1 parent 1ee6a70
Showing 3 changed files with 94 additions and 13 deletions.
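All three files get the same treatment: each outer `for index_h in …` loop is rewritten as an iterator that runs sequentially by default and through rayon when the `parallel` feature is enabled, with the raw buffer pointers wrapped in newtypes that are unsafely declared `Send + Sync` so the `for_each` closure can capture them. A minimal, self-contained sketch of that pattern (`UnsafeMut` and `double_all_chunks` are illustrative stand-ins, not code from this commit):

    #[cfg(feature = "parallel")]
    use rayon::prelude::*;

    // Stand-in for the commit's UnsafeMutI32: a raw pointer the compiler
    // would otherwise refuse to share across threads.
    struct UnsafeMut(*mut u32);
    impl UnsafeMut {
        fn get(&self) -> *mut u32 {
            self.0
        }
    }
    // SAFETY: sound only because each parallel iteration below writes to a
    // disjoint region of the buffer.
    unsafe impl Send for UnsafeMut {}
    unsafe impl Sync for UnsafeMut {}

    fn double_all_chunks(buf: &mut [u32], log_chunk: usize) {
        let chunk = 1usize << log_chunk;
        assert_eq!(buf.len() % chunk, 0);
        let iter_range = 0..buf.len() / chunk;

        // Sequential by default, data-parallel when the `parallel` feature
        // is on -- the same cfg selection the commit uses.
        #[cfg(not(feature = "parallel"))]
        let iter = iter_range;
        #[cfg(feature = "parallel")]
        let iter = iter_range.into_par_iter();

        let ptr = UnsafeMut(buf.as_mut_ptr());
        iter.for_each(|h| {
            let p = ptr.get();
            for i in 0..chunk {
                // Iteration `h` only touches [h * chunk, (h + 1) * chunk).
                unsafe { *p.add(h * chunk + i) *= 2 };
            }
        });
    }

The wrapper performs no synchronization; correctness rests entirely on each iteration writing a disjoint chunk.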
32 changes: 28 additions & 4 deletions crates/prover/src/core/backend/simd/fft/ifft.rs
@@ -3,10 +3,13 @@
use std::simd::{simd_swizzle, u32x16, u32x2, u32x4};

use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::{
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
};
use crate::core::backend::simd::fft::UnsafeMutI32;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::circle::Coset;
use crate::core::fields::FieldExpOps;
@@ -29,6 +32,7 @@ use crate::core::utils::bit_reverse;
/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) {
assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize);

let log_n_vecs = log_n_elements - LOG_N_LANES as usize;
if log_n_elements <= CACHED_FFT_LOG_SIZE as usize {
ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements);
@@ -81,7 +85,17 @@ pub unsafe fn ifft_lower_with_vecwise(

assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));

for index_h in 0..1 << (log_size - fft_layers) {
let iter_range = 0..1 << (log_size - fft_layers);

#[cfg(not(feature = "parallel"))]
let iter = iter_range;

#[cfg(feature = "parallel")]
let iter = iter_range.into_par_iter();

let values = UnsafeMutI32(values);
iter.for_each(|index_h| {
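// Each `index_h` worker transforms its own `2^fft_layers`-element chunk of
// `values`, so the concurrent writes below never overlap.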
let values = values.get();
ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h);
for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) {
match fft_layers - layer {
@@ -102,7 +116,7 @@ pub unsafe fn ifft_lower_with_vecwise(
}
}
}
}
});
}

/// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of
@@ -131,7 +145,17 @@ pub unsafe fn ifft_lower_without_vecwise(
) {
assert!(log_size >= LOG_N_LANES as usize);

for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
let iter_range = 0..1 << (log_size - fft_layers - LOG_N_LANES as usize);

#[cfg(not(feature = "parallel"))]
let iter = iter_range;

#[cfg(feature = "parallel")]
let iter = iter_range.into_par_iter();

let values = UnsafeMutI32(values);
iter.for_each(|index_h| {
let values = values.get();
for layer in (0..fft_layers).step_by(3) {
let fixed_layer = layer + LOG_N_LANES as usize;
match fft_layers - layer {
@@ -152,7 +176,7 @@ pub unsafe fn ifft_lower_without_vecwise(
}
}
}
}
});
}

/// Runs the first 5 ifft layers across the entire array.
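In both `ifft` entry points above, the parallel dimension is `index_h`: each worker runs the full inner layer loop on its own contiguous chunk of `values`, so no synchronization is needed beyond the implicit join at the end of `for_each`. The disjointness of those chunks is the informal proof obligation behind the `unsafe impl Send`/`Sync` on the pointer wrappers defined in `mod.rs` below.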
38 changes: 35 additions & 3 deletions crates/prover/src/core/backend/simd/fft/mod.rs
@@ -1,5 +1,8 @@
use std::simd::{simd_swizzle, u32x16, u32x8};

#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::m31::PackedBaseField;
use crate::core::fields::m31::P;

@@ -10,6 +13,26 @@ pub const CACHED_FFT_LOG_SIZE: u32 = 16;

pub const MIN_FFT_LOG_SIZE: u32 = 5;

pub struct UnsafeMutI32(pub *mut u32);
impl UnsafeMutI32 {
pub fn get(&self) -> *mut u32 {
self.0
}
}

unsafe impl Send for UnsafeMutI32 {}
unsafe impl Sync for UnsafeMutI32 {}

pub struct UnsafeConstI32(pub *const u32);
impl UnsafeConstI32 {
pub fn get(&self) -> *const u32 {
self.0
}
}

unsafe impl Send for UnsafeConstI32 {}
unsafe impl Sync for UnsafeConstI32 {}
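// SAFETY note for both wrappers: `Send`/`Sync` only assert that the pointer
// value itself may be shared across threads. Every caller must still guarantee
// that concurrent accesses through the pointer hit disjoint memory, as the fft
// loops in this commit do.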

// TODO(spapini): FFTs return a redundant representation that can reach the value P; it needs to be
// reduced somewhere.

@@ -29,8 +52,17 @@ pub const MIN_FFT_LOG_SIZE: u32 = 5;
/// Behavior is undefined if `values` does not have the same alignment as [`u32x16`].
pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
let half = log_n_vecs / 2;
for b in 0..1 << (log_n_vecs & 1) {
for a in 0..1 << half {

#[cfg(not(feature = "parallel"))]
let iter = 0..1 << half;

#[cfg(feature = "parallel")]
let iter = (0..1 << half).into_par_iter();

let values = UnsafeMutI32(values);
iter.for_each(|a| {
let values = values.get();
for b in 0..1 << (log_n_vecs & 1) {
for c in 0..1 << half {
let i = (a << (log_n_vecs - half)) | (b << half) | c;
let j = (c << (log_n_vecs - half)) | (b << half) | a;
@@ -43,7 +75,7 @@ pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
store(values.add(j << 4), val0);
}
}
}
});
}

/// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers.
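`transpose_vecs` is the subtler case: the parallel index `a` appears in both swap indices (`i` carries `a` in its high bits, `j` in its low bits), so two workers' writes are disjoint only if each unordered pair of vector indices is swapped by exactly one iteration. That invariant, not the type system, is what the `Send`/`Sync` wrappers lean on here.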
37 changes: 31 additions & 6 deletions crates/prover/src/core/backend/simd/fft/rfft.rs
@@ -4,10 +4,13 @@ use std::array;
use std::simd::{simd_swizzle, u32x16, u32x2, u32x4, u32x8};

use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::{
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
};
use crate::core::backend::simd::fft::{UnsafeConstI32, UnsafeMutI32};
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::circle::Coset;
use crate::core::utils::bit_reverse;
@@ -86,8 +89,19 @@ pub unsafe fn fft_lower_with_vecwise(

assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));

for index_h in 0..1 << (log_size - fft_layers) {
let mut src = src;
let iter_range = 0..1 << (log_size - fft_layers);

#[cfg(not(feature = "parallel"))]
let iter = iter_range;

#[cfg(feature = "parallel")]
let iter = iter_range.into_par_iter();

let src = UnsafeConstI32(src);
let dst = UnsafeMutI32(dst);
iter.for_each(|index_h| {
let mut src = src.get();
let dst = dst.get();
for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() {
match fft_layers - layer {
1 => {
@@ -116,7 +130,7 @@
fft_layers - VECWISE_FFT_BITS,
index_h,
);
}
});
}

/// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of
@@ -147,8 +161,19 @@ pub unsafe fn fft_lower_without_vecwise(
) {
assert!(log_size >= LOG_N_LANES as usize);

for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
let mut src = src;
let iter_range = 0..1 << (log_size - fft_layers - LOG_N_LANES as usize);

#[cfg(not(feature = "parallel"))]
let iter = iter_range;

#[cfg(feature = "parallel")]
let iter = iter_range.into_par_iter();

let src = UnsafeConstI32(src);
let dst = UnsafeMutI32(dst);
iter.for_each(|index_h| {
let mut src = src.get();
let dst = dst.get();
for layer in (0..fft_layers).step_by(3).rev() {
let fixed_layer = layer + LOG_N_LANES as usize;
match fft_layers - layer {
@@ -171,7 +196,7 @@
}
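// After the first group of layers reads the original `src`, every later group
// reads back what was just written to `dst`: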
src = dst;
}
}
});
}

/// Runs the last 5 fft layers across the entire array.
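A closing observation: the same three-line cfg selection is repeated six times across these files. A small macro could factor it out — a hypothetical refactor sketched below, not something this commit introduces:

    // Hypothetical helper, not part of this commit: expands to a sequential
    // or a rayon iterator depending on the `parallel` feature.
    #[cfg(feature = "parallel")]
    macro_rules! maybe_par_iter {
        ($range:expr) => {{
            use rayon::prelude::*;
            ($range).into_par_iter()
        }};
    }

    #[cfg(not(feature = "parallel"))]
    macro_rules! maybe_par_iter {
        ($range:expr) => {
            ($range).into_iter()
        };
    }

    // Call sites would then collapse to:
    // maybe_par_iter!(0..1 << (log_size - fft_layers)).for_each(|index_h| { ... });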
