Simd twiddles #820

Open · wants to merge 2 commits into base: spapini/09-05-parallel_fft
crates/prover/src/core/backend/simd/circle.rs (102 changes: 84 additions & 18 deletions)
@@ -2,16 +2,18 @@ use std::iter::zip;
 use std::mem::transmute;
 
 use bytemuck::{cast_slice, Zeroable};
+use itertools::Itertools;
 use num_traits::One;
 
 use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE};
 use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
 use super::qm31::PackedSecureField;
 use super::SimdBackend;
 use crate::core::backend::simd::column::BaseColumn;
-use crate::core::backend::{Col, CpuBackend};
-use crate::core::circle::{CirclePoint, Coset};
-use crate::core::fields::m31::BaseField;
+use crate::core::backend::simd::m31::PackedM31;
+use crate::core::backend::{Col, Column, CpuBackend};
+use crate::core::circle::{CirclePoint, Coset, M31_CIRCLE_LOG_ORDER};
+use crate::core::fields::m31::{BaseField, M31};
 use crate::core::fields::qm31::SecureField;
 use crate::core::fields::{Field, FieldExpOps};
 use crate::core::poly::circle::{
@@ -20,6 +22,7 @@ use crate::core::poly::circle::{
 use crate::core::poly::twiddles::TwiddleTree;
 use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
 use crate::core::poly::BitReversedOrder;
+use crate::core::utils::{bit_reverse, bit_reverse_index};
 
 impl SimdBackend {
     // TODO(Ohad): optimize.
@@ -275,32 +278,95 @@ impl PolyOps for SimdBackend {
         )
     }
 
-    fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
-        let mut twiddles = Vec::with_capacity(coset.size());
-        let mut itwiddles = Vec::with_capacity(coset.size());
+    #[allow(clippy::int_plus_one)]
+    fn precompute_twiddles(mut coset: Coset) -> TwiddleTree<Self> {
+        let root_coset = coset;
 
-        // TODO(spapini): Optimize.
-        for layer in &rfft::get_twiddle_dbls(coset) {
-            twiddles.extend(layer);
+        // Generate xs for descending cosets, each bit reversed.
+        let mut xs = Vec::with_capacity(coset.size() / N_LANES);
+        while coset.log_size() - 1 >= LOG_N_LANES {
+            gen_coset_xs(coset, &mut xs);
+            coset = coset.double();
         }
 
+        let mut extra = Vec::with_capacity(N_LANES);
+        while coset.log_size() > 0 {
+            let start = extra.len();
+            extra.extend(
+                coset
+                    .iter()
+                    .take(coset.size() / 2)
+                    .map(|p| p.x)
+                    .collect_vec(),
+            );
+            bit_reverse(&mut extra[start..]);
+            coset = coset.double();
+        }
         // Pad by any value, to make the size a power of 2.
-        twiddles.push(1);
-        assert_eq!(twiddles.len(), coset.size());
-        for layer in &ifft::get_itwiddle_dbls(coset) {
-            itwiddles.extend(layer);
+        extra.push(M31::one());
+
+        if extra.len() < N_LANES {
+            let twiddles = extra.iter().map(|x| x.0 * 2).collect();
+            let itwiddles = extra.iter().map(|x| x.inverse().0 * 2).collect();
+            return TwiddleTree {
+                root_coset,
+                twiddles,
+                itwiddles,
+            };
         }
-        // Pad by any value, to make the size a power of 2.
-        itwiddles.push(1);
-        assert_eq!(itwiddles.len(), coset.size());
+
+        xs.push(PackedM31::from_array(extra.try_into().unwrap()));
+
+        let mut ixs = unsafe { BaseColumn::uninitialized(root_coset.size()) }.data;
+        PackedBaseField::batch_inverse(&xs, &mut ixs);
+
+        let twiddles = xs
+            .into_iter()
+            .flat_map(|x| x.to_array().map(|x| x.0 * 2))
+            .collect();
+        let itwiddles = ixs
+            .into_iter()
+            .flat_map(|x| x.to_array().map(|x| x.0 * 2))
+            .collect();
 
         TwiddleTree {
-            root_coset: coset,
+            root_coset,
             twiddles,
             itwiddles,
         }
     }
 }
 
+#[allow(clippy::int_plus_one)]
+fn gen_coset_xs(coset: Coset, res: &mut Vec<PackedM31>) {
+    let log_size = coset.log_size() - 1;
+    assert!(log_size >= LOG_N_LANES);
+
+    let initial_points = std::array::from_fn(|i| coset.at(bit_reverse_index(i, log_size)));
+    let mut current = CirclePoint {
+        x: PackedM31::from_array(initial_points.each_ref().map(|p| p.x)),
+        y: PackedM31::from_array(initial_points.each_ref().map(|p| p.y)),
+    };
+
+    let mut flips = [CirclePoint::zero(); (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize];
+    for i in 0..(log_size - LOG_N_LANES) {
+        let prev_mul = bit_reverse_index((1 << i) - 1, log_size - LOG_N_LANES);
+        let new_mul = bit_reverse_index(1 << i, log_size - LOG_N_LANES);
+        let flip = coset.step.mul(new_mul as u128) - coset.step.mul(prev_mul as u128);
+        flips[i as usize] = flip;
+    }
+
+    for i in 0u32..1 << (log_size - LOG_N_LANES) {
+        let x = current.x;
+        let flip_index = i.trailing_ones() as usize;
+        let flip = CirclePoint {
+            x: PackedM31::broadcast(flips[flip_index].x),
+            y: PackedM31::broadcast(flips[flip_index].y),
+        };
+        current = current + flip;
+        res.push(x);
+    }
+}
+
 fn slow_eval_at_point(
     poly: &CirclePoly<SimdBackend>,
     point: CirclePoint<SecureField>,
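Note on the hunk above: the rewritten `precompute_twiddles` no longer materializes each FFT layer's twiddles through `get_twiddle_dbls`. It walks the chain of doubled cosets, collects every layer's x-coordinates in bit-reversed order (packed 16 to a `PackedM31`), computes all inverse twiddles in one `batch_inverse` pass, and stores both arrays in doubled form (the `* 2`, matching the `twiddle_dbl` convention the FFT kernels take). The trick inside `gen_coset_xs` is that advancing a bit-reversed counter moves the point by a "flip" that depends only on `i.trailing_ones()`, so each step is a single packed point addition. Below is a scalar sketch of that identity, with plain integers standing in for circle points and a local `bit_reverse_index` so it runs standalone:

```rust
// Scalar model of the flip table in `gen_coset_xs`: stepping a bit-reversed
// counter adds a delta that depends only on `i.trailing_ones()`. In the PR
// the deltas are multiples of `coset.step` added to a packed point; here
// they are plain integers.
fn bit_reverse_index(i: usize, log_size: u32) -> usize {
    i.reverse_bits() >> (usize::BITS - log_size)
}

fn main() {
    let log_size = 5u32;
    // flips[k] = brev(2^k) - brev(2^k - 1), mirroring the PR's flip table.
    let flips: Vec<i64> = (0..log_size)
        .map(|k| {
            bit_reverse_index(1 << k, log_size) as i64
                - bit_reverse_index((1 << k) - 1, log_size) as i64
        })
        .collect();
    let mut current = 0i64;
    for i in 0u32..(1 << log_size) - 1 {
        // `current` tracks brev(i) using only one addition per step.
        assert_eq!(current, bit_reverse_index(i as usize, log_size) as i64);
        current += flips[i.trailing_ones() as usize];
    }
}
```

The single-inversion claim rests on Montgomery's batch-inversion trick, the standard way a `batch_inverse`-style API is implemented (an assumption here; the implementation of `PackedBaseField::batch_inverse` is not visible in this diff): invert one running product, then peel per-element inverses off backwards. A scalar M31 sketch:

```rust
const P: u64 = (1 << 31) - 1; // M31 modulus

// Fermat inversion: x^(P-2) mod P, used once for the whole batch.
fn pow_mod(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % P;
        }
        base = base * base % P;
        exp >>= 1;
    }
    acc
}

// One inversion plus O(n) multiplications instead of n inversions.
// Assumes all inputs are nonzero.
fn batch_inverse(xs: &[u64]) -> Vec<u64> {
    if xs.is_empty() {
        return vec![];
    }
    // Running prefix products xs[0] * ... * xs[i].
    let mut prefix = Vec::with_capacity(xs.len());
    let mut acc = 1;
    for &x in xs {
        acc = acc * x % P;
        prefix.push(acc);
    }
    let mut inv_acc = pow_mod(acc, P - 2);
    let mut out = vec![0; xs.len()];
    // Peel one element at a time, walking backwards.
    for i in (1..xs.len()).rev() {
        out[i] = inv_acc * prefix[i - 1] % P;
        inv_acc = inv_acc * xs[i] % P;
    }
    out[0] = inv_acc;
    out
}
```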
crates/prover/src/core/backend/simd/fft/ifft.rs (32 changes: 28 additions & 4 deletions)
@@ -3,10 +3,13 @@
 use std::simd::{simd_swizzle, u32x16, u32x2, u32x4};
 
 use itertools::Itertools;
+#[cfg(feature = "parallel")]
+use rayon::prelude::*;
 
 use super::{
     compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
 };
+use crate::core::backend::simd::fft::UnsafeMutI32;
 use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
 use crate::core::circle::Coset;
 use crate::core::fields::FieldExpOps;
@@ -29,6 +32,7 @@ use crate::core::utils::bit_reverse;
 /// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`].
 pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) {
     assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize);
+
     let log_n_vecs = log_n_elements - LOG_N_LANES as usize;
     if log_n_elements <= CACHED_FFT_LOG_SIZE as usize {
         ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements);
@@ -81,7 +85,17 @@ pub unsafe fn ifft_lower_with_vecwise(
 
     assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));
 
-    for index_h in 0..1 << (log_size - fft_layers) {
+    let iter_range = 0..1 << (log_size - fft_layers);
+
+    #[cfg(not(feature = "parallel"))]
+    let iter = iter_range;
+
+    #[cfg(feature = "parallel")]
+    let iter = iter_range.into_par_iter();
+
+    let values = UnsafeMutI32(values);
+    iter.for_each(|index_h| {
+        let values = values.get();
         ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h);
         for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) {
             match fft_layers - layer {
@@ -102,7 +116,7 @@ pub unsafe fn ifft_lower_with_vecwise(
                 }
             }
         }
-    }
+    });
 }
 
 /// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of
@@ -131,7 +145,17 @@ pub unsafe fn ifft_lower_without_vecwise(
 ) {
     assert!(log_size >= LOG_N_LANES as usize);
 
-    for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
+    let iter_range = 0..1 << (log_size - fft_layers - LOG_N_LANES as usize);
+
+    #[cfg(not(feature = "parallel"))]
+    let iter = iter_range;
+
+    #[cfg(feature = "parallel")]
+    let iter = iter_range.into_par_iter();
+
+    let values = UnsafeMutI32(values);
+    iter.for_each(|index_h| {
+        let values = values.get();
         for layer in (0..fft_layers).step_by(3) {
             let fixed_layer = layer + LOG_N_LANES as usize;
             match fft_layers - layer {
@@ -152,7 +176,7 @@ pub unsafe fn ifft_lower_without_vecwise(
                 }
             }
         }
-    }
+    });
 }
 
 /// Runs the first 5 ifft layers across the entire array.
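Both `ifft_lower_with_vecwise` and `ifft_lower_without_vecwise` get the same treatment: the outer `for index_h` loop becomes an iterator that is a plain `Range` without the `parallel` feature and a rayon parallel range with it, and the old loop body moves into `for_each`. Each `index_h` transforms its own chunk, so iterations are independent; the raw `values` pointer is smuggled through the `UnsafeMutI32` wrapper (defined in `fft/mod.rs`, next file) to satisfy rayon's `Send`/`Sync` bounds. A stripped-down sketch of the dispatch pattern, with illustrative names:

```rust
#[cfg(feature = "parallel")]
use rayon::prelude::*;

/// Runs `body` once per chunk index, in parallel when the `parallel`
/// feature is enabled. `body` must be safe to call concurrently.
fn for_each_chunk(n_chunks: usize, body: impl Fn(usize) + Send + Sync) {
    let iter_range = 0..n_chunks;

    // Sequential fallback: a `Range` is already an `Iterator`.
    #[cfg(not(feature = "parallel"))]
    let iter = iter_range;

    // Same range, fanned out over rayon's thread pool.
    #[cfg(feature = "parallel")]
    let iter = iter_range.into_par_iter();

    iter.for_each(|index_h| body(index_h));
}
```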
crates/prover/src/core/backend/simd/fft/mod.rs (38 changes: 35 additions & 3 deletions)
@@ -1,5 +1,8 @@
 use std::simd::{simd_swizzle, u32x16, u32x8};
 
+#[cfg(feature = "parallel")]
+use rayon::prelude::*;
+
 use super::m31::PackedBaseField;
 use crate::core::fields::m31::P;
 
@@ -10,6 +13,26 @@ pub const CACHED_FFT_LOG_SIZE: u32 = 16;
 
 pub const MIN_FFT_LOG_SIZE: u32 = 5;
 
+pub struct UnsafeMutI32(pub *mut u32);
+impl UnsafeMutI32 {
+    pub fn get(&self) -> *mut u32 {
+        self.0
+    }
+}
+
+unsafe impl Send for UnsafeMutI32 {}
+unsafe impl Sync for UnsafeMutI32 {}
+
+pub struct UnsafeConstI32(pub *const u32);
+impl UnsafeConstI32 {
+    pub fn get(&self) -> *const u32 {
+        self.0
+    }
+}
+
+unsafe impl Send for UnsafeConstI32 {}
+unsafe impl Sync for UnsafeConstI32 {}
+
 // TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce
 // it somewhere.
 
@@ -29,8 +52,17 @@ pub const MIN_FFT_LOG_SIZE: u32 = 5;
 /// Behavior is undefined if `values` does not have the same alignment as [`u32x16`].
 pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
     let half = log_n_vecs / 2;
-    for b in 0..1 << (log_n_vecs & 1) {
-        for a in 0..1 << half {
+
+    #[cfg(not(feature = "parallel"))]
+    let iter = 0..1 << half;
+
+    #[cfg(feature = "parallel")]
+    let iter = (0..1 << half).into_par_iter();
+
+    let values = UnsafeMutI32(values);
+    iter.for_each(|a| {
+        let values = values.get();
+        for b in 0..1 << (log_n_vecs & 1) {
             for c in 0..1 << half {
                 let i = (a << (log_n_vecs - half)) | (b << half) | c;
                 let j = (c << (log_n_vecs - half)) | (b << half) | a;
@@ -43,7 +75,7 @@ pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
                 store(values.add(j << 4), val0);
             }
         }
-    }
+    });
 }
 
 /// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers.
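Raw pointers are deliberately neither `Send` nor `Sync`, so a rayon closure cannot capture `*mut u32` or `*const u32` directly; `UnsafeMutI32` and `UnsafeConstI32` opt in by hand. The wrappers do no synchronization, so soundness rests entirely on the call sites: every `for_each` iteration must touch a disjoint region of the buffer. Note also that `transpose_vecs` reorders its loop nest so that `a`, the new parallel axis, is outermost and the old outer `b` loop moves inside each task. A generic variant of the wrapper with the contract spelled out (a sketch, not code from this PR):

```rust
/// Lets a raw pointer cross thread boundaries.
///
/// Safety contract (on the caller, not the type): all concurrent accesses
/// through the wrapped pointer must target disjoint memory.
pub struct UnsafeMut<T>(pub *mut T);

impl<T> UnsafeMut<T> {
    pub fn get(&self) -> *mut T {
        self.0
    }
}

// SAFETY: the wrapper adds no synchronization; see the contract above.
unsafe impl<T> Send for UnsafeMut<T> {}
unsafe impl<T> Sync for UnsafeMut<T> {}
```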
crates/prover/src/core/backend/simd/fft/rfft.rs (37 changes: 31 additions & 6 deletions)
@@ -4,10 +4,13 @@ use std::array;
 use std::simd::{simd_swizzle, u32x16, u32x2, u32x4, u32x8};
 
 use itertools::Itertools;
+#[cfg(feature = "parallel")]
+use rayon::prelude::*;
 
 use super::{
     compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
 };
+use crate::core::backend::simd::fft::{UnsafeConstI32, UnsafeMutI32};
 use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
 use crate::core::circle::Coset;
 use crate::core::utils::bit_reverse;
@@ -86,8 +89,19 @@ pub unsafe fn fft_lower_with_vecwise(
 
     assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));
 
-    for index_h in 0..1 << (log_size - fft_layers) {
-        let mut src = src;
+    let iter_range = 0..1 << (log_size - fft_layers);
+
+    #[cfg(not(feature = "parallel"))]
+    let iter = iter_range;
+
+    #[cfg(feature = "parallel")]
+    let iter = iter_range.into_par_iter();
+
+    let src = UnsafeConstI32(src);
+    let dst = UnsafeMutI32(dst);
+    iter.for_each(|index_h| {
+        let mut src = src.get();
+        let dst = dst.get();
         for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() {
             match fft_layers - layer {
                 1 => {
@@ -116,7 +130,7 @@ pub unsafe fn fft_lower_with_vecwise(
             fft_layers - VECWISE_FFT_BITS,
             index_h,
         );
-    }
+    });
 }
 
 /// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of
@@ -147,8 +161,19 @@ pub unsafe fn fft_lower_without_vecwise(
 ) {
     assert!(log_size >= LOG_N_LANES as usize);
 
-    for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
-        let mut src = src;
+    let iter_range = 0..1 << (log_size - fft_layers - LOG_N_LANES as usize);
+
+    #[cfg(not(feature = "parallel"))]
+    let iter = iter_range;
+
+    #[cfg(feature = "parallel")]
+    let iter = iter_range.into_par_iter();
+
+    let src = UnsafeConstI32(src);
+    let dst = UnsafeMutI32(dst);
+    iter.for_each(|index_h| {
+        let mut src = src.get();
+        let dst = dst.get();
         for layer in (0..fft_layers).step_by(3).rev() {
             let fixed_layer = layer + LOG_N_LANES as usize;
             match fft_layers - layer {
@@ -171,7 +196,7 @@ pub unsafe fn fft_lower_without_vecwise(
             }
             src = dst;
         }
-    }
+    });
 }
 
 /// Runs the last 5 fft layers across the entire array.
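The rfft loops thread two pointers through the closure: the read-only input via `UnsafeConstI32` and the output via `UnsafeMutI32`. Inside each task, `src` is rebound with `src = dst` after a layer group, so only the first pass reads the caller's buffer and later layers iterate in place on `dst`. A minimal model of that handoff (illustrative names, trivial stand-in arithmetic):

```rust
/// One task's share of a layered, double-buffered transform: reads `src`
/// on the first layer, then keeps rewriting its own region of `dst`.
fn task(src: *const u32, dst: *mut u32, chunk: std::ops::Range<usize>, n_layers: usize) {
    let mut src = src;
    for _layer in 0..n_layers {
        for i in chunk.clone() {
            // SAFETY: the caller guarantees `chunk` is in bounds for both
            // buffers and disjoint from every other task's chunk.
            unsafe { *dst.add(i) = (*src.add(i)).wrapping_mul(3) };
        }
        // From now on, read back what this task just wrote.
        src = dst;
    }
}
```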