From 2a4fec2c1c6f7e88a6fbb8164521a38ad8bd433c Mon Sep 17 00:00:00 2001 From: Shahar Papini Date: Thu, 5 Sep 2024 07:49:56 +0300 Subject: [PATCH 1/2] Parallel fft --- .../prover/src/core/backend/simd/fft/ifft.rs | 32 ++++++++++++++-- .../prover/src/core/backend/simd/fft/mod.rs | 38 +++++++++++++++++-- .../prover/src/core/backend/simd/fft/rfft.rs | 37 +++++++++++++++--- 3 files changed, 94 insertions(+), 13 deletions(-) diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs index eb34da490..e7ad3343c 100644 --- a/crates/prover/src/core/backend/simd/fft/ifft.rs +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -3,10 +3,13 @@ use std::simd::{simd_swizzle, u32x16, u32x2, u32x4}; use itertools::Itertools; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, }; +use crate::core::backend::simd::fft::UnsafeMutI32; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::circle::Coset; use crate::core::fields::FieldExpOps; @@ -29,6 +32,7 @@ use crate::core::utils::bit_reverse; /// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) { assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize); + let log_n_vecs = log_n_elements - LOG_N_LANES as usize; if log_n_elements <= CACHED_FFT_LOG_SIZE as usize { ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements); @@ -81,7 +85,17 @@ pub unsafe fn ifft_lower_with_vecwise( assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); - for index_h in 0..1 << (log_size - fft_layers) { + let iter_range = 0..1 << (log_size - fft_layers); + + #[cfg(not(feature = "parallel"))] + let iter = iter_range; + + #[cfg(feature = "parallel")] + let iter = iter_range.into_par_iter(); + + let values = UnsafeMutI32(values); + iter.for_each(|index_h| { + let values = values.get(); ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h); for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) { match fft_layers - layer { @@ -102,7 +116,7 @@ pub unsafe fn ifft_lower_with_vecwise( } } } - } + }); } /// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of @@ -131,7 +145,17 @@ pub unsafe fn ifft_lower_without_vecwise( ) { assert!(log_size >= LOG_N_LANES as usize); - for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) { + let iter_range = 0..1 << (log_size - fft_layers - LOG_N_LANES as usize); + + #[cfg(not(feature = "parallel"))] + let iter = iter_range; + + #[cfg(feature = "parallel")] + let iter = iter_range.into_par_iter(); + + let values = UnsafeMutI32(values); + iter.for_each(|index_h| { + let values = values.get(); for layer in (0..fft_layers).step_by(3) { let fixed_layer = layer + LOG_N_LANES as usize; match fft_layers - layer { @@ -152,7 +176,7 @@ pub unsafe fn ifft_lower_without_vecwise( } } } - } + }); } /// Runs the first 5 ifft layers across the entire array. 
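A note on the pattern this patch applies in each of these loops: the iteration range is built once, then bound either to a plain serial `Range` or to a rayon parallel iterator depending on a cargo feature, and the raw buffer pointer crosses into the `for_each` closure through a `Send + Sync` newtype (`UnsafeMutI32`, defined in `mod.rs` below). The following is a minimal self-contained sketch of the same pattern; the names `UnsafeMut` and `process_chunks` are illustrative, not from the patch, and it assumes a cargo feature named `parallel` that enables rayon, as the `#[cfg]` gates imply.

    #[cfg(feature = "parallel")]
    use rayon::prelude::*;

    struct UnsafeMut(*mut u32);

    impl UnsafeMut {
        fn get(&self) -> *mut u32 {
            self.0
        }
    }

    // SAFETY: sound only because every iteration below writes to a
    // disjoint window of the buffer, exactly as `index_h` partitions the
    // FFT work in the patch.
    unsafe impl Send for UnsafeMut {}
    unsafe impl Sync for UnsafeMut {}

    fn process_chunks(values: &mut [u32], chunk_len: usize) {
        let n_chunks = values.len() / chunk_len;
        let ptr = UnsafeMut(values.as_mut_ptr());

        #[cfg(not(feature = "parallel"))]
        let iter = 0..n_chunks;

        #[cfg(feature = "parallel")]
        let iter = (0..n_chunks).into_par_iter();

        iter.for_each(|h| {
            // Re-derive the raw pointer inside the closure, mirroring the
            // patch's `let values = values.get();` lines.
            let values = ptr.get();
            for k in 0..chunk_len {
                // SAFETY: task `h` owns indices [h * chunk_len, (h + 1) * chunk_len).
                unsafe { *values.add(h * chunk_len + k) += 1 };
            }
        });
    }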
diff --git a/crates/prover/src/core/backend/simd/fft/mod.rs b/crates/prover/src/core/backend/simd/fft/mod.rs index ca44979e8..0fe51aece 100644 --- a/crates/prover/src/core/backend/simd/fft/mod.rs +++ b/crates/prover/src/core/backend/simd/fft/mod.rs @@ -1,5 +1,8 @@ use std::simd::{simd_swizzle, u32x16, u32x8}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + use super::m31::PackedBaseField; use crate::core::fields::m31::P; @@ -10,6 +13,26 @@ pub const CACHED_FFT_LOG_SIZE: u32 = 16; pub const MIN_FFT_LOG_SIZE: u32 = 5; +pub struct UnsafeMutI32(pub *mut u32); +impl UnsafeMutI32 { + pub fn get(&self) -> *mut u32 { + self.0 + } +} + +unsafe impl Send for UnsafeMutI32 {} +unsafe impl Sync for UnsafeMutI32 {} + +pub struct UnsafeConstI32(pub *const u32); +impl UnsafeConstI32 { + pub fn get(&self) -> *const u32 { + self.0 + } +} + +unsafe impl Send for UnsafeConstI32 {} +unsafe impl Sync for UnsafeConstI32 {} + // TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce // it somewhere. @@ -29,8 +52,17 @@ pub const MIN_FFT_LOG_SIZE: u32 = 5; /// Behavior is undefined if `values` does not have the same alignment as [`u32x16`]. pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) { let half = log_n_vecs / 2; - for b in 0..1 << (log_n_vecs & 1) { - for a in 0..1 << half { + + #[cfg(not(feature = "parallel"))] + let iter = 0..1 << half; + + #[cfg(feature = "parallel")] + let iter = (0..1 << half).into_par_iter(); + + let values = UnsafeMutI32(values); + iter.for_each(|a| { + let values = values.get(); + for b in 0..1 << (log_n_vecs & 1) { for c in 0..1 << half { let i = (a << (log_n_vecs - half)) | (b << half) | c; let j = (c << (log_n_vecs - half)) | (b << half) | a; @@ -43,7 +75,7 @@ pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) { store(values.add(j << 4), val0); } } - } + }); } /// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers. 
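For context on the two newtypes added here: `*mut u32` and `*const u32` are neither `Send` nor `Sync`, so rayon will not accept a closure that captures a raw pointer directly. Wrapping the pointer and hand-implementing the marker traits moves the aliasing obligation to the call sites, which partition the work (by `index_h`, or by `a` in `transpose_vecs`) so that concurrent tasks never touch the same vector. A generic version of the same shim, as a sketch rather than part of the patch:

    /// Illustrative generic analogue of `UnsafeMutI32`; the patch
    /// hardcodes `u32` pointers instead.
    pub struct UnsafeMut<T>(pub *mut T);

    impl<T> UnsafeMut<T> {
        pub fn get(&self) -> *mut T {
            self.0
        }
    }

    // SAFETY: these impls only assert that the pointer *value* may be
    // shared across threads. Every call site must still guarantee that
    // concurrent tasks dereference disjoint offsets.
    unsafe impl<T> Send for UnsafeMut<T> {}
    unsafe impl<T> Sync for UnsafeMut<T> {}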
diff --git a/crates/prover/src/core/backend/simd/fft/rfft.rs b/crates/prover/src/core/backend/simd/fft/rfft.rs index 6d51fd09d..25f452344 100644 --- a/crates/prover/src/core/backend/simd/fft/rfft.rs +++ b/crates/prover/src/core/backend/simd/fft/rfft.rs @@ -4,10 +4,13 @@ use std::array; use std::simd::{simd_swizzle, u32x16, u32x2, u32x4, u32x8}; use itertools::Itertools; +#[cfg(feature = "parallel")] +use rayon::prelude::*; use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, }; +use crate::core::backend::simd::fft::{UnsafeConstI32, UnsafeMutI32}; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::circle::Coset; use crate::core::utils::bit_reverse; @@ -86,8 +89,19 @@ pub unsafe fn fft_lower_with_vecwise( assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); - for index_h in 0..1 << (log_size - fft_layers) { - let mut src = src; + let iter_range = 0..1 << (log_size - fft_layers); + + #[cfg(not(feature = "parallel"))] + let iter = iter_range; + + #[cfg(feature = "parallel")] + let iter = iter_range.into_par_iter(); + + let src = UnsafeConstI32(src); + let dst = UnsafeMutI32(dst); + iter.for_each(|index_h| { + let mut src = src.get(); + let dst = dst.get(); for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() { match fft_layers - layer { 1 => { @@ -116,7 +130,7 @@ pub unsafe fn fft_lower_with_vecwise( fft_layers - VECWISE_FFT_BITS, index_h, ); - } + }); } /// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of @@ -147,8 +161,19 @@ pub unsafe fn fft_lower_without_vecwise( ) { assert!(log_size >= LOG_N_LANES as usize); - for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) { - let mut src = src; + let iter_range = 0..1 << (log_size - fft_layers - LOG_N_LANES as usize); + + #[cfg(not(feature = "parallel"))] + let iter = iter_range; + + #[cfg(feature = "parallel")] + let iter = iter_range.into_par_iter(); + + let src = UnsafeConstI32(src); + let dst = UnsafeMutI32(dst); + iter.for_each(|index_h| { + let mut src = src.get(); + let dst = dst.get(); for layer in (0..fft_layers).step_by(3).rev() { let fixed_layer = layer + LOG_N_LANES as usize; match fft_layers - layer { @@ -171,7 +196,7 @@ pub unsafe fn fft_lower_without_vecwise( } src = dst; } - } + }); } /// Runs the last 5 fft layers across the entire array. 
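One subtlety in these rfft variants: the first layer group reads `src` and writes `dst`, after which `src = dst` makes the remaining groups run in place, so each parallel task re-derives its own cursors (`src.get()`, `dst.get()`) instead of capturing the pointers themselves. A simplified scalar sketch of that out-of-place-then-in-place flow, with a hypothetical `butterfly` standing in for the patch's `fft1`/`fft2`/`fft3` calls:

    /// Hypothetical stand-in for one fft1/fft2/fft3 layer group.
    fn butterfly(x: u32, layer: usize) -> u32 {
        x.rotate_left(layer as u32 + 1)
    }

    /// Sketch of the pointer flow inside `fft_lower_without_vecwise`.
    ///
    /// # Safety
    /// `src` and `dst` must each be valid for `n` elements; after the
    /// first group they intentionally alias.
    unsafe fn layer_groups(src: *const u32, dst: *mut u32, n: usize, groups: usize) {
        // Shadowed cursor, mirroring `let mut src = src.get();` in the patch.
        let mut src = src;
        for layer in 0..groups {
            for i in 0..n {
                unsafe { *dst.add(i) = butterfly(*src.add(i), layer) };
            }
            // From here on the live data is in `dst`; later groups
            // transform it in place.
            src = dst;
        }
    }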
From 755db7c561caa6aff26d322f98cf953f5e4d4833 Mon Sep 17 00:00:00 2001
From: Shahar Papini
Date: Thu, 5 Sep 2024 10:04:36 +0300
Subject: [PATCH 2/2] Simd twiddles

---
 crates/prover/src/core/backend/simd/circle.rs | 102 ++++++++++++++----
 1 file changed, 84 insertions(+), 18 deletions(-)

diff --git a/crates/prover/src/core/backend/simd/circle.rs b/crates/prover/src/core/backend/simd/circle.rs
index e930f77b2..66f67d52b 100644
--- a/crates/prover/src/core/backend/simd/circle.rs
+++ b/crates/prover/src/core/backend/simd/circle.rs
@@ -2,6 +2,7 @@ use std::iter::zip;
 use std::mem::transmute;
 
 use bytemuck::{cast_slice, Zeroable};
+use itertools::Itertools;
 use num_traits::One;
 
 use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE};
@@ -9,9 +10,10 @@ use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
 use super::qm31::PackedSecureField;
 use super::SimdBackend;
 use crate::core::backend::simd::column::BaseColumn;
-use crate::core::backend::{Col, CpuBackend};
-use crate::core::circle::{CirclePoint, Coset};
-use crate::core::fields::m31::BaseField;
+use crate::core::backend::simd::m31::PackedM31;
+use crate::core::backend::{Col, Column, CpuBackend};
+use crate::core::circle::{CirclePoint, Coset, M31_CIRCLE_LOG_ORDER};
+use crate::core::fields::m31::{BaseField, M31};
 use crate::core::fields::qm31::SecureField;
 use crate::core::fields::{Field, FieldExpOps};
 use crate::core::poly::circle::{
@@ -20,6 +22,7 @@ use crate::core::poly::circle::{
 use crate::core::poly::twiddles::TwiddleTree;
 use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
 use crate::core::poly::BitReversedOrder;
+use crate::core::utils::{bit_reverse, bit_reverse_index};
 
 impl SimdBackend {
     // TODO(Ohad): optimize.
@@ -275,32 +278,95 @@ impl PolyOps for SimdBackend {
         )
     }
 
-    fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
-        let mut twiddles = Vec::with_capacity(coset.size());
-        let mut itwiddles = Vec::with_capacity(coset.size());
+    #[allow(clippy::int_plus_one)]
+    fn precompute_twiddles(mut coset: Coset) -> TwiddleTree<Self> {
+        let root_coset = coset;
 
-        // TODO(spapini): Optimize.
-        for layer in &rfft::get_twiddle_dbls(coset) {
-            twiddles.extend(layer);
+        // Generate xs for descending cosets, each bit reversed.
+        let mut xs = Vec::with_capacity(coset.size() / N_LANES);
+        while coset.log_size() - 1 >= LOG_N_LANES {
+            gen_coset_xs(coset, &mut xs);
+            coset = coset.double();
+        }
+
+        let mut extra = Vec::with_capacity(N_LANES);
+        while coset.log_size() > 0 {
+            let start = extra.len();
+            extra.extend(
+                coset
+                    .iter()
+                    .take(coset.size() / 2)
+                    .map(|p| p.x)
+                    .collect_vec(),
+            );
+            bit_reverse(&mut extra[start..]);
+            coset = coset.double();
         }
-        // Pad by any value, to make the size a power of 2.
-        twiddles.push(1);
-        assert_eq!(twiddles.len(), coset.size());
-        for layer in &ifft::get_itwiddle_dbls(coset) {
-            itwiddles.extend(layer);
+        extra.push(M31::one());
+
+        if extra.len() < N_LANES {
+            let twiddles = extra.iter().map(|x| x.0 * 2).collect();
+            let itwiddles = extra.iter().map(|x| x.inverse().0 * 2).collect();
+            return TwiddleTree {
+                root_coset,
+                twiddles,
+                itwiddles,
+            };
         }
-        // Pad by any value, to make the size a power of 2.
-        itwiddles.push(1);
-        assert_eq!(itwiddles.len(), coset.size());
+
+        xs.push(PackedM31::from_array(extra.try_into().unwrap()));
+
+        let mut ixs = unsafe { BaseColumn::uninitialized(root_coset.size()) }.data;
+        PackedBaseField::batch_inverse(&xs, &mut ixs);
+
+        let twiddles = xs
+            .into_iter()
+            .flat_map(|x| x.to_array().map(|x| x.0 * 2))
+            .collect();
+        let itwiddles = ixs
+            .into_iter()
+            .flat_map(|x| x.to_array().map(|x| x.0 * 2))
+            .collect();
 
         TwiddleTree {
-            root_coset: coset,
+            root_coset,
             twiddles,
             itwiddles,
         }
     }
 }
 
+#[allow(clippy::int_plus_one)]
+fn gen_coset_xs(coset: Coset, res: &mut Vec<PackedM31>) {
+    let log_size = coset.log_size() - 1;
+    assert!(log_size >= LOG_N_LANES);
+
+    let initial_points = std::array::from_fn(|i| coset.at(bit_reverse_index(i, log_size)));
+    let mut current = CirclePoint {
+        x: PackedM31::from_array(initial_points.each_ref().map(|p| p.x)),
+        y: PackedM31::from_array(initial_points.each_ref().map(|p| p.y)),
+    };
+
+    let mut flips = [CirclePoint::zero(); (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize];
+    for i in 0..(log_size - LOG_N_LANES) {
+        let prev_mul = bit_reverse_index((1 << i) - 1, log_size - LOG_N_LANES);
+        let new_mul = bit_reverse_index(1 << i, log_size - LOG_N_LANES);
+        let flip = coset.step.mul(new_mul as u128) - coset.step.mul(prev_mul as u128);
+        flips[i as usize] = flip;
+    }
+
+    for i in 0u32..1 << (log_size - LOG_N_LANES) {
+        let x = current.x;
+        let flip_index = i.trailing_ones() as usize;
+        let flip = CirclePoint {
+            x: PackedM31::broadcast(flips[flip_index].x),
+            y: PackedM31::broadcast(flips[flip_index].y),
+        };
+        current = current + flip;
+        res.push(x);
+    }
+}
+
 fn slow_eval_at_point(
     poly: &CirclePoly<SimdBackend>,
     point: CirclePoint<SecureField>,
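The trick that makes `gen_coset_xs` cheap is enumerating coset points directly in bit-reversed order with a single point addition per step: if `i` has `t` trailing one-bits, then `rev(i + 1) - rev(i) = rev(1 << t) - rev((1 << t) - 1)`, because the carry in `i -> i + 1` rewrites only the low `t + 1` bits. Those per-`t` deltas are exactly the precomputed `flips`, selected by `i.trailing_ones()`, and the `* 2` when collecting matches the doubled-twiddle (`twiddle_dbl`) convention the fft kernels expect. A scalar sketch of the same walk over plain integers (illustrative only; the patch performs it with circle-point additions on 16 lanes at a time):

    /// Scalar sketch of the bit-reversed walk behind `gen_coset_xs`:
    /// yields rev(0), rev(1), ..., rev(2^log_size - 1) with one addition
    /// per step instead of reversing every index from scratch.
    fn bit_reverse_walk(log_size: u32) -> Vec<u32> {
        assert!(log_size >= 1 && log_size < 32);
        let rev = |x: u32| x.reverse_bits() >> (32 - log_size);

        // flips[t] = rev(2^t) - rev(2^t - 1): the delta applied when the
        // increment i -> i + 1 carries through exactly t trailing ones.
        let flips: Vec<u32> = (0..log_size)
            .map(|t| rev(1 << t).wrapping_sub(rev((1 << t) - 1)))
            .collect();

        let mut cur = 0u32; // rev(0)
        let mut out = Vec::with_capacity(1 << log_size);
        for i in 0..(1u32 << log_size) {
            out.push(cur);
            if i + 1 < (1u32 << log_size) {
                cur = cur.wrapping_add(flips[i.trailing_ones() as usize]);
            }
        }
        out
    }

In the field the same identity holds with exact group arithmetic, so no wrapping is involved, and the batch inversion above amortizes the cost of producing `itwiddles` from the same `xs`.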