From 2a4fec2c1c6f7e88a6fbb8164521a38ad8bd433c Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Thu, 5 Sep 2024 07:49:56 +0300 Subject: [PATCH] Parallel fft --- .../prover/src/core/backend/simd/fft/ifft.rs | 32 ++++++++++++++-- .../prover/src/core/backend/simd/fft/mod.rs | 38 +++++++++++++++++-- .../prover/src/core/backend/simd/fft/rfft.rs | 37 +++++++++++++++--- 3 files changed, 94 insertions(+), 13 deletions(-) diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs index eb34da490..e7ad3343c 100644 --- a/crates/prover/src/core/backend/simd/fft/ifft.rs +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -3,10 +3,13 @@ use std::simd::{simd_swizzle, u32x16, u32x2, u32x4}; use itertools::Itertools; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, }; +use crate::core::backend::simd::fft::UnsafeMutI32; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::circle::Coset; use crate::core::fields::FieldExpOps; @@ -29,6 +32,7 @@ use crate::core::utils::bit_reverse; /// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) { assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize); + let log_n_vecs = log_n_elements - LOG_N_LANES as usize; if log_n_elements <= CACHED_FFT_LOG_SIZE as usize { ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements); @@ -81,7 +85,17 @@ pub unsafe fn ifft_lower_with_vecwise( assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); - for index_h in 0..1 << (log_size - fft_layers) { + let iter_range = 0..1 << (log_size - fft_layers); + + #[cfg(not(feature = "parallel"))] + let iter = iter_range; + + #[cfg(feature = "parallel")] + let iter = iter_range.into_par_iter(); + + let values = UnsafeMutI32(values); + iter.for_each(|index_h| { + let values = values.get(); ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h); for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) { match fft_layers - layer { @@ -102,7 +116,7 @@ pub unsafe fn ifft_lower_with_vecwise( } } } - } + }); } /// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of @@ -131,7 +145,17 @@ pub unsafe fn ifft_lower_without_vecwise( ) { assert!(log_size >= LOG_N_LANES as usize); - for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) { + let iter_range = 0..1 << (log_size - fft_layers - LOG_N_LANES as usize); + + #[cfg(not(feature = "parallel"))] + let iter = iter_range; + + #[cfg(feature = "parallel")] + let iter = iter_range.into_par_iter(); + + let values = UnsafeMutI32(values); + iter.for_each(|index_h| { + let values = values.get(); for layer in (0..fft_layers).step_by(3) { let fixed_layer = layer + LOG_N_LANES as usize; match fft_layers - layer { @@ -152,7 +176,7 @@ pub unsafe fn ifft_lower_without_vecwise( } } } - } + }); } /// Runs the first 5 ifft layers across the entire array. diff --git a/crates/prover/src/core/backend/simd/fft/mod.rs b/crates/prover/src/core/backend/simd/fft/mod.rs index ca44979e8..0fe51aece 100644 --- a/crates/prover/src/core/backend/simd/fft/mod.rs +++ b/crates/prover/src/core/backend/simd/fft/mod.rs @@ -1,5 +1,8 @@ use std::simd::{simd_swizzle, u32x16, u32x8}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + use super::m31::PackedBaseField; use crate::core::fields::m31::P; @@ -10,6 +13,26 @@ pub const CACHED_FFT_LOG_SIZE: u32 = 16; pub const MIN_FFT_LOG_SIZE: u32 = 5; +pub struct UnsafeMutI32(pub *mut u32); +impl UnsafeMutI32 { + pub fn get(&self) -> *mut u32 { + self.0 + } +} + +unsafe impl Send for UnsafeMutI32 {} +unsafe impl Sync for UnsafeMutI32 {} + +pub struct UnsafeConstI32(pub *const u32); +impl UnsafeConstI32 { + pub fn get(&self) -> *const u32 { + self.0 + } +} + +unsafe impl Send for UnsafeConstI32 {} +unsafe impl Sync for UnsafeConstI32 {} + // TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce // it somewhere. @@ -29,8 +52,17 @@ pub const MIN_FFT_LOG_SIZE: u32 = 5; /// Behavior is undefined if `values` does not have the same alignment as [`u32x16`]. pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) { let half = log_n_vecs / 2; - for b in 0..1 << (log_n_vecs & 1) { - for a in 0..1 << half { + + #[cfg(not(feature = "parallel"))] + let iter = 0..1 << half; + + #[cfg(feature = "parallel")] + let iter = (0..1 << half).into_par_iter(); + + let values = UnsafeMutI32(values); + iter.for_each(|a| { + let values = values.get(); + for b in 0..1 << (log_n_vecs & 1) { for c in 0..1 << half { let i = (a << (log_n_vecs - half)) | (b << half) | c; let j = (c << (log_n_vecs - half)) | (b << half) | a; @@ -43,7 +75,7 @@ pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) { store(values.add(j << 4), val0); } } - } + }); } /// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers. diff --git a/crates/prover/src/core/backend/simd/fft/rfft.rs b/crates/prover/src/core/backend/simd/fft/rfft.rs index 6d51fd09d..25f452344 100644 --- a/crates/prover/src/core/backend/simd/fft/rfft.rs +++ b/crates/prover/src/core/backend/simd/fft/rfft.rs @@ -4,10 +4,13 @@ use std::array; use std::simd::{simd_swizzle, u32x16, u32x2, u32x4, u32x8}; use itertools::Itertools; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, }; +use crate::core::backend::simd::fft::{UnsafeConstI32, UnsafeMutI32}; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::circle::Coset; use crate::core::utils::bit_reverse; @@ -86,8 +89,19 @@ pub unsafe fn fft_lower_with_vecwise( assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); - for index_h in 0..1 << (log_size - fft_layers) { - let mut src = src; + let iter_range = 0..1 << (log_size - fft_layers); + + #[cfg(not(feature = "parallel"))] + let iter = iter_range; + + #[cfg(feature = "parallel")] + let iter = iter_range.into_par_iter(); + + let src = UnsafeConstI32(src); + let dst = UnsafeMutI32(dst); + iter.for_each(|index_h| { + let mut src = src.get(); + let dst = dst.get(); for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() { match fft_layers - layer { 1 => { @@ -116,7 +130,7 @@ pub unsafe fn fft_lower_with_vecwise( fft_layers - VECWISE_FFT_BITS, index_h, ); - } + }); } /// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of @@ -147,8 +161,19 @@ pub unsafe fn fft_lower_without_vecwise( ) { assert!(log_size >= LOG_N_LANES as usize); - for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) { - let mut src = src; + let iter_range = 0..1 << (log_size - fft_layers - LOG_N_LANES as usize); + + #[cfg(not(feature = "parallel"))] + let iter = iter_range; + + #[cfg(feature = "parallel")] + let iter = iter_range.into_par_iter(); + + let src = UnsafeConstI32(src); + let dst = UnsafeMutI32(dst); + iter.for_each(|index_h| { + let mut src = src.get(); + let dst = dst.get(); for layer in (0..fft_layers).step_by(3).rev() { let fixed_layer = layer + LOG_N_LANES as usize; match fft_layers - layer { @@ -171,7 +196,7 @@ pub unsafe fn fft_lower_without_vecwise( } src = dst; } - } + }); } /// Runs the last 5 fft layers across the entire array.