diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index f687cbe72..9385071cb 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -44,9 +44,6 @@ nonstandard-style = "deny" rust-2018-idioms = "deny" unused = "deny" -[features] -avx512 = [] - [[bench]] name = "bit_rev" harness = false diff --git a/crates/prover/benches/bit_rev.rs b/crates/prover/benches/bit_rev.rs index 219a6d688..e71149e0e 100644 --- a/crates/prover/benches/bit_rev.rs +++ b/crates/prover/benches/bit_rev.rs @@ -32,30 +32,6 @@ pub fn simd_bit_rev(c: &mut Criterion) { }); } -#[cfg(target_arch = "x86_64")] -pub fn avx512_bit_rev(c: &mut Criterion) { - use stwo_prover::core::backend::avx512::bit_reverse::bit_reverse_m31; - use stwo_prover::core::backend::avx512::BaseFieldVec; - const SIZE: usize = 1 << 26; - if !stwo_prover::platform::avx512_detected() { - return; - } - let data = (0..SIZE).map(BaseField::from).collect::<BaseFieldVec>(); - c.bench_function("avx bit_rev 26bit", |b| { - b.iter_batched( - || data.data.clone(), - |mut data| bit_reverse_m31(&mut data), - BatchSize::LargeInput, - ); - }); -} - -#[cfg(target_arch = "x86_64")] -criterion_group!( - name = bit_rev; - config = Criterion::default().sample_size(10); - targets = avx512_bit_rev, simd_bit_rev, cpu_bit_rev); -#[cfg(not(target_arch = "x86_64"))] criterion_group!( name = bit_rev; config = Criterion::default().sample_size(10); diff --git a/crates/prover/benches/eval_at_point.rs b/crates/prover/benches/eval_at_point.rs index d63c78aef..0b2f22a68 100644 --- a/crates/prover/benches/eval_at_point.rs +++ b/crates/prover/benches/eval_at_point.rs @@ -24,11 +24,6 @@ fn bench_eval_at_secure_point(c: &mut Criterion, id: &str) { } fn eval_at_secure_point_benches(c: &mut Criterion) { - #[cfg(target_arch = "x86_64")] - if stwo_prover::platform::avx512_detected() { - use stwo_prover::core::backend::avx512::AVX512Backend; - bench_eval_at_secure_point::<AVX512Backend>(c, "avx"); - } bench_eval_at_secure_point::<SimdBackend>(c, "simd"); bench_eval_at_secure_point::<CPUBackend>(c, "cpu"); } diff --git a/crates/prover/benches/fft.rs b/crates/prover/benches/fft.rs index 34d00797a..a8dd61ff7 100644 --- a/crates/prover/benches/fft.rs +++ b/crates/prover/benches/fft.rs @@ -3,15 +3,19 @@ use std::hint::black_box; use std::mem::{size_of_val, transmute}; -use criterion::{BatchSize, BenchmarkId, Criterion, Throughput}; +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput}; use itertools::Itertools; +use stwo_prover::core::backend::simd::column::BaseFieldVec; +use stwo_prover::core::backend::simd::fft::ifft::{ + get_itwiddle_dbls, ifft, ifft3_loop, ifft_vecwise_loop, +}; +use stwo_prover::core::backend::simd::fft::rfft::{fft, get_twiddle_dbls}; +use stwo_prover::core::backend::simd::fft::transpose_vecs; +use stwo_prover::core::backend::simd::m31::PackedBaseField; use stwo_prover::core::fields::m31::BaseField; use stwo_prover::core::poly::circle::CanonicCoset; pub fn simd_ifft(c: &mut Criterion) { - use stwo_prover::core::backend::simd::column::BaseFieldVec; - use stwo_prover::core::backend::simd::fft::ifft::{get_itwiddle_dbls, ifft}; - let mut group = c.benchmark_group("iffts"); for log_size in 16..=28 { @@ -37,12 +41,6 @@ pub fn simd_ifft_parts(c: &mut Criterion) { } pub fn simd_ifft_parts(c: &mut Criterion) { - use stwo_prover::core::backend::simd::column::BaseFieldVec; - use stwo_prover::core::backend::simd::fft::ifft::{ - get_itwiddle_dbls, ifft3_loop, ifft_vecwise_loop, - }; - use stwo_prover::core::backend::simd::fft::transpose_vecs; - const LOG_SIZE: u32 = 14; let
domain = CanonicCoset::new(LOG_SIZE).circle_domain(); @@ -104,10 +102,6 @@ pub fn simd_ifft_parts(c: &mut Criterion) { } pub fn simd_rfft(c: &mut Criterion) { - use stwo_prover::core::backend::simd::column::BaseFieldVec; - use stwo_prover::core::backend::simd::fft::rfft::{fft, get_twiddle_dbls}; - use stwo_prover::core::backend::simd::m31::PackedBaseField; - const LOG_SIZE: u32 = 20; let domain = CanonicCoset::new(LOG_SIZE).circle_domain(); @@ -131,161 +125,8 @@ pub fn simd_rfft(c: &mut Criterion) { }); } -#[cfg(target_arch = "x86_64")] -pub fn avx512_ifft(c: &mut criterion::Criterion) { - use stwo_prover::core::backend::avx512::fft::ifft; - use stwo_prover::platform; - if !platform::avx512_detected() { - return; - } - - let mut group = c.benchmark_group("iffts"); - for log_size in 16..=28 { - let (values, twiddle_dbls) = prepare_values(log_size); - - group.throughput(Throughput::Bytes( - (std::mem::size_of::<BaseField>() as u64) << log_size, - )); - group.bench_function(BenchmarkId::new("avx ifft", log_size), |b| { - b.iter_batched( - || values.clone().data, - |mut values| unsafe { - ifft::ifft( - std::mem::transmute(values.as_mut_ptr()), - &twiddle_dbls - .iter() - .map(|x| x.as_slice()) - .collect::<Vec<_>>(), - log_size as usize, - ) - }, - BatchSize::LargeInput, - ); - }); - } -} - -#[cfg(target_arch = "x86_64")] -pub fn avx512_ifft_parts(c: &mut criterion::Criterion) { - use stwo_prover::core::backend::avx512::fft::{ifft, transpose_vecs}; - use stwo_prover::platform; - if !platform::avx512_detected() { - return; - } - - let (values, twiddle_dbls) = prepare_values(14); - let mut group = c.benchmark_group("ifft parts"); - - // Note: These benchmarks run only on 2^14 elements because of their parameters. - // Increasing the figure above won't change the runtime of these benchmarks.
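// Aside (added note, not from the original file): the `4 << 14` throughput figure
// below is `size_of::<BaseField>()` = 4 bytes per M31 word, times 2^14 elements.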
- group.throughput(Throughput::Bytes(4 << 14)); - group.bench_function("avx ifft_vecwise_loop 2^14", |b| { - b.iter_batched( - || values.clone().data, - |mut values| unsafe { - ifft::ifft_vecwise_loop( - std::mem::transmute(values.as_mut_ptr()), - &twiddle_dbls - .iter() - .map(|x| x.as_slice()) - .collect::<Vec<_>>(), - 9, - 0, - ) - }, - BatchSize::LargeInput, - ); - }); - - group.bench_function("avx ifft3_loop 2^14", |b| { - b.iter_batched( - || values.clone().data, - |mut values| unsafe { - ifft::ifft3_loop( - std::mem::transmute(values.as_mut_ptr()), - &twiddle_dbls - .iter() - .skip(3) - .map(|x| x.as_slice()) - .collect::<Vec<_>>(), - 7, - 4, - 0, - ) - }, - BatchSize::LargeInput, - ); - }); - - let (values, _twiddle_dbls) = prepare_values(20); - group.throughput(Throughput::Bytes(4 << 20)); - group.bench_function("avx transpose_vecs 2^20", |b| { - b.iter_batched( - || values.clone().data, - |mut values| unsafe { - transpose_vecs(std::mem::transmute(values.as_mut_ptr()), (20 - 4) as usize); - }, - BatchSize::LargeInput, - ); - }); -} - -#[cfg(target_arch = "x86_64")] -pub fn avx512_rfft(c: &mut criterion::Criterion) { - use stwo_prover::core::backend::avx512::fft::rfft; - use stwo_prover::core::backend::avx512::PackedBaseField; - use stwo_prover::platform; - if !platform::avx512_detected() { - return; - } - - const LOG_SIZE: u32 = 20; - let (values, twiddle_dbls) = prepare_values(LOG_SIZE); - - c.bench_function("avx rfft 20bit", |b| { - b.iter_with_large_drop(|| unsafe { - let mut target = Vec::<PackedBaseField>::with_capacity(values.data.len()); - #[allow(clippy::uninit_vec)] - target.set_len(values.data.len()); - - rfft::fft( - black_box(std::mem::transmute(values.data.as_ptr())), - std::mem::transmute(target.as_mut_ptr()), - &twiddle_dbls - .iter() - .map(|x| x.as_slice()) - .collect::<Vec<_>>(), - LOG_SIZE as usize, - ); - }) - }); -} - -#[cfg(target_arch = "x86_64")] -fn prepare_values( - log_size: u32, -) -> ( - stwo_prover::core::backend::avx512::BaseFieldVec, - Vec<Vec<i32>>, -) { - use stwo_prover::core::backend::avx512::fft::ifft::get_itwiddle_dbls; - let domain = CanonicCoset::new(log_size).circle_domain(); - let values = (0..domain.size()) - .map(|i| BaseField::from_u32_unchecked(i as u32)) - .collect::<Vec<_>>(); - let values = stwo_prover::core::backend::avx512::BaseFieldVec::from_iter(values); - let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); - (values, twiddle_dbls) -} - -#[cfg(target_arch = "x86_64")] -criterion::criterion_group!( - name = benches; - config = Criterion::default().sample_size(10); - targets = avx512_ifft, avx512_ifft_parts, avx512_rfft, simd_ifft, simd_ifft_parts, simd_rfft); -#[cfg(not(target_arch = "x86_64"))] -criterion::criterion_group!( +criterion_group!( name = benches; config = Criterion::default().sample_size(10); targets = simd_ifft, simd_ifft_parts, simd_rfft); -criterion::criterion_main!(benches); +criterion_main!(benches); diff --git a/crates/prover/benches/field.rs b/crates/prover/benches/field.rs index 4c5493a2f..acb318cf5 100644 --- a/crates/prover/benches/field.rs +++ b/crates/prover/benches/field.rs @@ -1,8 +1,8 @@ -use criterion::Criterion; +use criterion::{criterion_group, criterion_main, Criterion}; use num_traits::One; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; -use stwo_prover::core::backend::simd::m31::N_LANES; +use stwo_prover::core::backend::simd::m31::{PackedBaseField, N_LANES}; use stwo_prover::core::fields::cm31::CM31; use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fields::qm31::SecureField; @@ -10,7 +10,7 @@ use
stwo_prover::core::fields::qm31::SecureField; pub const N_ELEMENTS: usize = 1 << 16; pub const N_STATE_ELEMENTS: usize = 8; -pub fn m31_operations_bench(c: &mut criterion::Criterion) { +pub fn m31_operations_bench(c: &mut Criterion) { let mut rng = SmallRng::seed_from_u64(0); let elements: Vec<M31> = (0..N_ELEMENTS).map(|_| rng.gen()).collect(); let mut state: [M31; N_STATE_ELEMENTS] = rng.gen(); @@ -40,7 +40,7 @@ pub fn m31_operations_bench(c: &mut criterion::Criterion) { }); } -pub fn cm31_operations_bench(c: &mut criterion::Criterion) { +pub fn cm31_operations_bench(c: &mut Criterion) { let mut rng = SmallRng::seed_from_u64(0); let elements: Vec<CM31> = (0..N_ELEMENTS).map(|_| rng.gen()).collect(); let mut state: [CM31; N_STATE_ELEMENTS] = rng.gen(); @@ -70,7 +70,7 @@ pub fn cm31_operations_bench(c: &mut criterion::Criterion) { }); } -pub fn qm31_operations_bench(c: &mut criterion::Criterion) { +pub fn qm31_operations_bench(c: &mut Criterion) { let mut rng = SmallRng::seed_from_u64(0); let elements: Vec<SecureField> = (0..N_ELEMENTS).map(|_| rng.gen()).collect(); let mut state: [SecureField; N_STATE_ELEMENTS] = rng.gen(); @@ -100,64 +100,7 @@ pub fn qm31_operations_bench(c: &mut criterion::Criterion) { }); } -#[cfg(target_arch = "x86_64")] -pub fn avx512_m31_operations_bench(c: &mut criterion::Criterion) { - use stwo_prover::core::backend::avx512::m31::{PackedBaseField, K_BLOCK_SIZE}; - use stwo_prover::platform; - - if !platform::avx512_detected() { - return; - } - - let mut rng = SmallRng::seed_from_u64(0); - let mut elements: Vec<PackedBaseField> = Vec::new(); - let mut states: Vec<PackedBaseField> = - vec![PackedBaseField::from_array([1.into(); K_BLOCK_SIZE]); N_STATE_ELEMENTS]; - - for _ in 0..(N_ELEMENTS / K_BLOCK_SIZE) { - elements.push(PackedBaseField::from_array(rng.gen())); - } - - c.bench_function("mul_avx512", |b| { - b.iter(|| { - for elem in elements.iter() { - for _ in 0..128 { - for state in states.iter_mut() { - *state *= *elem; - } - } - } - }) - }); - - c.bench_function("add_avx512", |b| { - b.iter(|| { - for elem in elements.iter() { - for _ in 0..128 { - for state in states.iter_mut() { - *state += *elem; - } - } - } - }) - }); - - c.bench_function("sub_avx512", |b| { - b.iter(|| { - for elem in elements.iter() { - for _ in 0..128 { - for state in states.iter_mut() { - *state -= *elem; - } - } - } - }) - }); -} - -pub fn simd_m31_operations_bench(c: &mut criterion::Criterion) { - use stwo_prover::core::backend::simd::m31::PackedBaseField; - +pub fn simd_m31_operations_bench(c: &mut Criterion) { let mut rng = SmallRng::seed_from_u64(0); let elements: Vec<PackedBaseField> = (0..N_ELEMENTS / N_LANES).map(|_| rng.gen()).collect(); let mut states = vec![PackedBaseField::broadcast(BaseField::one()); N_STATE_ELEMENTS]; @@ -199,16 +142,9 @@ pub fn simd_m31_operations_bench(c: &mut criterion::Criterion) { }); } -#[cfg(target_arch = "x86_64")] -criterion::criterion_group!( - name = benches; - config = Criterion::default().sample_size(10); - targets = m31_operations_bench, cm31_operations_bench, qm31_operations_bench, - avx512_m31_operations_bench, simd_m31_operations_bench); -#[cfg(not(target_arch = "x86_64"))] -criterion::criterion_group!( +criterion_group!( name = benches; config = Criterion::default().sample_size(10); targets = m31_operations_bench, cm31_operations_bench, qm31_operations_bench, simd_m31_operations_bench); -criterion::criterion_main!(benches); +criterion_main!(benches); diff --git a/crates/prover/benches/merkle.rs b/crates/prover/benches/merkle.rs index 94dc707e8..89b8904d4 100644 --- a/crates/prover/benches/merkle.rs +++
b/crates/prover/benches/merkle.rs @@ -27,11 +27,6 @@ fn bench_blake2s_merkle<B: MerkleOps<Blake2sMerkleHasher>>(c: &mut Criterion, id } fn blake2s_merkle_benches(c: &mut Criterion) { - #[cfg(target_arch = "x86_64")] - if stwo_prover::platform::avx512_detected() { - use stwo_prover::core::backend::avx512::AVX512Backend; - bench_blake2s_merkle::<AVX512Backend>(c, "avx"); - } bench_blake2s_merkle::<SimdBackend>(c, "simd"); bench_blake2s_merkle::<CPUBackend>(c, "cpu"); } diff --git a/crates/prover/benches/quotients.rs b/crates/prover/benches/quotients.rs index 87b642e6d..bae785640 100644 --- a/crates/prover/benches/quotients.rs +++ b/crates/prover/benches/quotients.rs @@ -13,7 +13,7 @@ use stwo_prover::core::poly::BitReversedOrder; // TODO(andrew): Consider removing const generics and making all sizes the same. fn bench_quotients( - c: &mut criterion::Criterion, + c: &mut Criterion, id: &str, ) { let domain = CanonicCoset::new(LOG_N_ROWS).circle_domain(); @@ -42,12 +42,7 @@ fn bench_quotients ); } -fn quotients_benches(c: &mut criterion::Criterion) { - #[cfg(target_arch = "x86_64")] - if stwo_prover::platform::avx512_detected() { - use stwo_prover::core::backend::avx512::AVX512Backend; - bench_quotients::(c, "avx"); - } +fn quotients_benches(c: &mut Criterion) { bench_quotients::(c, "simd"); bench_quotients::(c, "cpu"); } diff --git a/crates/prover/src/core/backend/avx512/accumulation.rs b/crates/prover/src/core/backend/avx512/accumulation.rs deleted file mode 100644 index 973698555..000000000 --- a/crates/prover/src/core/backend/avx512/accumulation.rs +++ /dev/null @@ -1,12 +0,0 @@ -use super::AVX512Backend; -use crate::core::air::accumulation::AccumulationOps; -use crate::core::fields::secure_column::SecureColumn; - -impl AccumulationOps for AVX512Backend { - fn accumulate(column: &mut SecureColumn<Self>, other: &SecureColumn<Self>) { - for i in 0..column.n_packs() { - let res_coeff = column.packed_at(i) + other.packed_at(i); - unsafe { column.set_packed(i, res_coeff) }; - } - } -} diff --git a/crates/prover/src/core/backend/avx512/bit_reverse.rs b/crates/prover/src/core/backend/avx512/bit_reverse.rs deleted file mode 100644 index f729ca590..000000000 --- a/crates/prover/src/core/backend/avx512/bit_reverse.rs +++ /dev/null @@ -1,171 +0,0 @@ -use std::arch::x86_64::{__m512i, _mm512_permutex2var_epi32}; - -use super::PackedBaseField; -use crate::core::utils::bit_reverse_index; - -const VEC_BITS: u32 = 4; -const W_BITS: u32 = 3; -pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS; - -/// Bit reverses packed M31 values. -/// Given an array `A[0..2^n)`, computes `B[i] = A[bit_reverse(i)]`. -pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { - assert!(data.len().is_power_of_two()); - assert!(data.len().ilog2() >= MIN_LOG_SIZE); - - // Indices in the array are of the form v_h w_h a w_l v_l, with - // |v_h| = |v_l| = VEC_BITS, |w_h| = |w_l| = W_BITS, |a| = n - 2*W_BITS - VEC_BITS. - // The loops go over a, w_l, w_h, and then swaps the 16 by 16 values at: - // * w_h a w_l * <-> * rev(w_h a w_l) *. - // These are 1 or 2 chunks of 2^W_BITS contiguous AVX512 vectors. - - let log_size = data.len().ilog2(); - let a_bits = log_size - 2 * W_BITS - VEC_BITS; - - // TODO(spapini): when doing multithreading, do it over a. - for a in 0u32..(1 << a_bits) { - for w_l in 0u32..(1 << W_BITS) { - let w_l_rev = w_l.reverse_bits() >> (32 - W_BITS); - for w_h in 0u32..(w_l_rev + 1) { - let idx = ((((w_h << a_bits) | a) << W_BITS) | w_l) as usize; - let idx_rev = bit_reverse_index(idx, log_size - VEC_BITS); - - // In order to not swap twice, only swap if idx <= idx_rev.
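// Aside: a minimal scalar sketch (illustrative, not from this file) of the
// permutation that `bit_reverse_m31` implements 16 AVX512 lanes at a time.
// `j` is `i` with its lowest `log_size` bits reversed, and each pair is
// swapped exactly once, mirroring the `idx <= idx_rev` check below.
fn scalar_bit_reverse<T>(data: &mut [T]) {
    assert!(data.len().is_power_of_two());
    let log_size = data.len().ilog2();
    for i in 0..data.len() {
        // Reverse the lowest log_size bits of i.
        let j = ((i as u32).reverse_bits() >> (32 - log_size)) as usize;
        if i < j {
            data.swap(i, j);
        }
    }
}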
- if idx > idx_rev { - continue; - } - - // Read first chunk. - // TODO(spapini): Think about optimizing a_bits. - let chunk0 = std::array::from_fn(|i| unsafe { - *data.get_unchecked(idx + (i << (2 * W_BITS + a_bits))) - }); - let values0 = bit_reverse16(chunk0); - - if idx == idx_rev { - // Palindrome index. Write into the same chunk. - #[allow(clippy::needless_range_loop)] - for i in 0..16 { - unsafe { - *data.get_unchecked_mut(idx + (i << (2 * W_BITS + a_bits))) = - values0[i]; - } - } - continue; - } - - // Read bit reversed chunk. - let chunk1 = std::array::from_fn(|i| unsafe { - *data.get_unchecked(idx_rev + (i << (2 * W_BITS + a_bits))) - }); - let values1 = bit_reverse16(chunk1); - - for i in 0..16 { - unsafe { - *data.get_unchecked_mut(idx + (i << (2 * W_BITS + a_bits))) = values1[i]; - *data.get_unchecked_mut(idx_rev + (i << (2 * W_BITS + a_bits))) = - values0[i]; - } - } - } - } - } -} - -/// Bit reverses 256 M31 values, packed in 16 words of 16 elements each. -fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { - let mut data: [__m512i; 16] = unsafe { std::mem::transmute(data) }; - // L is an input to _mm512_permutex2var_epi32, and it is used to - // interleave the first half of a with the first half of b. - const L: __m512i = unsafe { - core::mem::transmute([ - 0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100, - 0b10100, 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111, - ]) - }; - // H is an input to _mm512_permutex2var_epi32, and it is used to - // interleave the second half of a with the second half of b. - const H: __m512i = unsafe { - core::mem::transmute([ - 0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100, - 0b11100, 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111, - ]) - }; - - // Denote the index of each element in the 16 packed M31 words as abcd:0123, - // where abcd is the index of the packed word and 0123 is the index of the element in the word. - // Bit reversal is achieved by applying the following permutation to the index 4 times: - // abcd:0123 => 0abc:123d - // This is what it looks like at each iteration. - // abcd:0123 - // 0abc:123d - // 10ab:23dc - // 210a:3dcb - // 3210:dcba - for _ in 0..4 { - // Apply the abcd:0123 => 0abc:123d permutation. - // _mm512_permutex2var_epi32() with L allows us to interleave the first half of 2 words. - // For example, the second call interleaves 0010:0xyz (low half of register 2) with - // 0011:0xyz (low half of register 3), and stores the result in register 1 (0001). - // This results in - // 0001:xyz0 (even indices of register 1) <= 0010:0xyz (low half of register 2), and - // 0001:xyz1 (odd indices of register 1) <= 0011:0xyz (low half of register 3) - // or 0001:xyzw <= 001w:0xyz.
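// Aside: a scalar model (hypothetical helper, not from this file) of the index
// permutation abcd:0123 => 0abc:123d on an 8-bit index. Applying it four times
// reverses all 8 bits, which is why the enclosing loop runs four times.
fn rotate_index(i: u8) -> u8 {
    (((i >> 3) & 1) << 7)       // bit "0" (high bit of the low nibble) moves to the top
        | (((i >> 5) & 7) << 4) // bits a, b, c each shift down by one
        | ((i & 7) << 1)        // bits 1, 2, 3 each shift up by one
        | ((i >> 4) & 1)        // bit d moves to the bottom
}
// Sanity check: for all i in 0..=255u8, applying rotate_index four times yields
// i.reverse_bits().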
- unsafe { - data = [ - _mm512_permutex2var_epi32(data[0], L, data[1]), - _mm512_permutex2var_epi32(data[2], L, data[3]), - _mm512_permutex2var_epi32(data[4], L, data[5]), - _mm512_permutex2var_epi32(data[6], L, data[7]), - _mm512_permutex2var_epi32(data[8], L, data[9]), - _mm512_permutex2var_epi32(data[10], L, data[11]), - _mm512_permutex2var_epi32(data[12], L, data[13]), - _mm512_permutex2var_epi32(data[14], L, data[15]), - _mm512_permutex2var_epi32(data[0], H, data[1]), - _mm512_permutex2var_epi32(data[2], H, data[3]), - _mm512_permutex2var_epi32(data[4], H, data[5]), - _mm512_permutex2var_epi32(data[6], H, data[7]), - _mm512_permutex2var_epi32(data[8], H, data[9]), - _mm512_permutex2var_epi32(data[10], H, data[11]), - _mm512_permutex2var_epi32(data[12], H, data[13]), - _mm512_permutex2var_epi32(data[14], H, data[15]), - ]; - } - } - unsafe { std::mem::transmute(data) } -} - -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -#[cfg(test)] -mod tests { - use super::bit_reverse16; - use crate::core::backend::avx512::bit_reverse::bit_reverse_m31; - use crate::core::backend::avx512::BaseFieldVec; - use crate::core::backend::Column; - use crate::core::fields::m31::BaseField; - use crate::core::utils::bit_reverse; - - #[test] - fn test_bit_reverse16() { - let data: [u32; 256] = std::array::from_fn(|i| i as u32); - let expected: [u32; 256] = std::array::from_fn(|i| (i as u32).reverse_bits() >> 24); - unsafe { - let data = bit_reverse16(std::mem::transmute(data)); - assert_eq!(std::mem::transmute::<_, [u32; 256]>(data), expected); - } - } - - #[test] - fn test_bit_reverse() { - const SIZE: usize = 1 << 15; - let data: Vec<_> = (0..SIZE as u32) - .map(BaseField::from_u32_unchecked) - .collect(); - let mut expected = data.clone(); - bit_reverse(&mut expected); - let mut data: BaseFieldVec = data.into_iter().collect(); - - bit_reverse_m31(&mut data.data[..]); - assert_eq!(data.to_cpu(), expected); - } -} diff --git a/crates/prover/src/core/backend/avx512/blake2s.rs b/crates/prover/src/core/backend/avx512/blake2s.rs deleted file mode 100644 index 7a64776eb..000000000 --- a/crates/prover/src/core/backend/avx512/blake2s.rs +++ /dev/null @@ -1,91 +0,0 @@ -use std::arch::x86_64::{__m512i, _mm512_loadu_si512}; - -use itertools::Itertools; - -use super::blake2s_avx::{compress16, set1, transpose_msgs, untranspose_states}; - -use super::{AVX512Backend, VECS_LOG_SIZE}; -use crate::core::backend::{Col, Column, ColumnOps}; -use crate::core::fields::m31::BaseField; -use crate::core::vcs::blake2_hash::Blake2sHash; -use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; -use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; - -impl ColumnOps<Blake2sHash> for AVX512Backend { - type Column = Vec<Blake2sHash>; - - fn bit_reverse_column(_column: &mut Self::Column) { - unimplemented!() - } -} - -impl MerkleOps<Blake2sMerkleHasher> for AVX512Backend { - fn commit_on_layer( - log_size: u32, - prev_layer: Option<&Vec<Blake2sHash>>, - columns: &[&Col<Self, BaseField>], - ) -> Vec<Blake2sHash> { - // Pad prev_layer if too small. - if log_size < VECS_LOG_SIZE as u32 { - return (0..(1 << log_size)) - .map(|i| { - Blake2sMerkleHasher::hash_node( - prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), - &columns.iter().map(|column| column.at(i)).collect_vec(), - ) - }) - .collect(); - } - - if let Some(prev_layer) = prev_layer { - assert_eq!(prev_layer.len(), 1 << (log_size + 1)); - } - - // Commit to columns.
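// Aside (added note, assuming VECS_LOG_SIZE = 4 as in this backend): each
// iteration of the loop below runs one `compress16` pipeline, producing the
// hashes of 16 sibling nodes of the layer in parallel.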
- let mut res = Vec::with_capacity(1 << log_size); - for i in 0..(1 << (log_size - VECS_LOG_SIZE as u32)) { - let mut state: [__m512i; 8] = unsafe { std::mem::zeroed() }; - // Hash prev_layer, if exists. - if let Some(prev_layer) = prev_layer { - let ptr = prev_layer[(i << 5)..((i + 1) << 5)].as_ptr() as *const __m512i; - let msgs: [__m512i; 16] = std::array::from_fn(|j| unsafe { - _mm512_loadu_si512(ptr.add(j) as *const i32) - }); - state = unsafe { - compress16( - state, - transpose_msgs(msgs), - set1(0), - set1(0), - set1(0), - set1(0), - ) - }; - } - - // Hash columns in chunks of 16. - let mut col_chunk_iter = columns.array_chunks(); - for col_chunk in &mut col_chunk_iter { - let msgs = col_chunk.map(|column| column.data[i].0); - state = unsafe { compress16(state, msgs, set1(0), set1(0), set1(0), set1(0)) }; - } - - // Hash remaining columns. - let remainder = col_chunk_iter.remainder(); - if !remainder.is_empty() { - let msgs = remainder - .iter() - .map(|column| column.data[i].0) - .chain(std::iter::repeat(unsafe { set1(0) })) - .take(16) - .collect_vec() - .try_into() - .unwrap(); - state = unsafe { compress16(state, msgs, set1(0), set1(0), set1(0), set1(0)) }; - } - let state: [Blake2sHash; 16] = - unsafe { std::mem::transmute(untranspose_states(state)) }; - res.extend_from_slice(&state); - } - res - } -} diff --git a/crates/prover/src/core/backend/avx512/blake2s_avx.rs b/crates/prover/src/core/backend/avx512/blake2s_avx.rs deleted file mode 100644 index 2c6aef385..000000000 --- a/crates/prover/src/core/backend/avx512/blake2s_avx.rs +++ /dev/null @@ -1,354 +0,0 @@ -//! An AVX512 implementation of the BLAKE2s compression function. -//! Based on . - -use std::arch::x86_64::{ - __m512i, _mm512_add_epi32, _mm512_or_si512, _mm512_permutex2var_epi32, _mm512_set1_epi32, - _mm512_slli_epi32, _mm512_srli_epi32, _mm512_xor_si512, -}; - -use super::tranpose_utils::{ - EVENS_CONCAT_EVENS, HHALF_INTERLEAVE_HHALF, LHALF_INTERLEAVE_LHALF, ODDS_CONCAT_ODDS, -}; - -const IV: [u32; 8] = [ - 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, -]; - -const SIGMA: [[u8; 16]; 10] = [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3], - [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4], - [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8], - [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13], - [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9], - [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11], - [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10], - [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5], - [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0], -]; - -/// # Safety -#[inline(always)] -pub unsafe fn set1(iv: i32) -> __m512i { - _mm512_set1_epi32(iv) -} - -#[inline(always)] -unsafe fn add(a: __m512i, b: __m512i) -> __m512i { - _mm512_add_epi32(a, b) -} - -#[inline(always)] -unsafe fn xor(a: __m512i, b: __m512i) -> __m512i { - _mm512_xor_si512(a, b) -} - -#[inline(always)] -unsafe fn rot16(x: __m512i) -> __m512i { - _mm512_or_si512(_mm512_srli_epi32(x, 16), _mm512_slli_epi32(x, 32 - 16)) -} - -#[inline(always)] -unsafe fn rot12(x: __m512i) -> __m512i { - _mm512_or_si512(_mm512_srli_epi32(x, 12), _mm512_slli_epi32(x, 32 - 12)) -} - -#[inline(always)] -unsafe fn rot8(x: __m512i) -> __m512i { - _mm512_or_si512(_mm512_srli_epi32(x, 8), _mm512_slli_epi32(x, 32 - 8)) -} - -#[inline(always)] -unsafe fn rot7(x: __m512i) -> __m512i { - 
_mm512_or_si512(_mm512_srli_epi32(x, 7), _mm512_slli_epi32(x, 32 - 7)) -} - -#[inline(always)] -unsafe fn round(v: &mut [__m512i; 16], m: [__m512i; 16], r: usize) { - v[0] = add(v[0], m[SIGMA[r][0] as usize]); - v[1] = add(v[1], m[SIGMA[r][2] as usize]); - v[2] = add(v[2], m[SIGMA[r][4] as usize]); - v[3] = add(v[3], m[SIGMA[r][6] as usize]); - v[0] = add(v[0], v[4]); - v[1] = add(v[1], v[5]); - v[2] = add(v[2], v[6]); - v[3] = add(v[3], v[7]); - v[12] = xor(v[12], v[0]); - v[13] = xor(v[13], v[1]); - v[14] = xor(v[14], v[2]); - v[15] = xor(v[15], v[3]); - v[12] = rot16(v[12]); - v[13] = rot16(v[13]); - v[14] = rot16(v[14]); - v[15] = rot16(v[15]); - v[8] = add(v[8], v[12]); - v[9] = add(v[9], v[13]); - v[10] = add(v[10], v[14]); - v[11] = add(v[11], v[15]); - v[4] = xor(v[4], v[8]); - v[5] = xor(v[5], v[9]); - v[6] = xor(v[6], v[10]); - v[7] = xor(v[7], v[11]); - v[4] = rot12(v[4]); - v[5] = rot12(v[5]); - v[6] = rot12(v[6]); - v[7] = rot12(v[7]); - v[0] = add(v[0], m[SIGMA[r][1] as usize]); - v[1] = add(v[1], m[SIGMA[r][3] as usize]); - v[2] = add(v[2], m[SIGMA[r][5] as usize]); - v[3] = add(v[3], m[SIGMA[r][7] as usize]); - v[0] = add(v[0], v[4]); - v[1] = add(v[1], v[5]); - v[2] = add(v[2], v[6]); - v[3] = add(v[3], v[7]); - v[12] = xor(v[12], v[0]); - v[13] = xor(v[13], v[1]); - v[14] = xor(v[14], v[2]); - v[15] = xor(v[15], v[3]); - v[12] = rot8(v[12]); - v[13] = rot8(v[13]); - v[14] = rot8(v[14]); - v[15] = rot8(v[15]); - v[8] = add(v[8], v[12]); - v[9] = add(v[9], v[13]); - v[10] = add(v[10], v[14]); - v[11] = add(v[11], v[15]); - v[4] = xor(v[4], v[8]); - v[5] = xor(v[5], v[9]); - v[6] = xor(v[6], v[10]); - v[7] = xor(v[7], v[11]); - v[4] = rot7(v[4]); - v[5] = rot7(v[5]); - v[6] = rot7(v[6]); - v[7] = rot7(v[7]); - - v[0] = add(v[0], m[SIGMA[r][8] as usize]); - v[1] = add(v[1], m[SIGMA[r][10] as usize]); - v[2] = add(v[2], m[SIGMA[r][12] as usize]); - v[3] = add(v[3], m[SIGMA[r][14] as usize]); - v[0] = add(v[0], v[5]); - v[1] = add(v[1], v[6]); - v[2] = add(v[2], v[7]); - v[3] = add(v[3], v[4]); - v[15] = xor(v[15], v[0]); - v[12] = xor(v[12], v[1]); - v[13] = xor(v[13], v[2]); - v[14] = xor(v[14], v[3]); - v[15] = rot16(v[15]); - v[12] = rot16(v[12]); - v[13] = rot16(v[13]); - v[14] = rot16(v[14]); - v[10] = add(v[10], v[15]); - v[11] = add(v[11], v[12]); - v[8] = add(v[8], v[13]); - v[9] = add(v[9], v[14]); - v[5] = xor(v[5], v[10]); - v[6] = xor(v[6], v[11]); - v[7] = xor(v[7], v[8]); - v[4] = xor(v[4], v[9]); - v[5] = rot12(v[5]); - v[6] = rot12(v[6]); - v[7] = rot12(v[7]); - v[4] = rot12(v[4]); - v[0] = add(v[0], m[SIGMA[r][9] as usize]); - v[1] = add(v[1], m[SIGMA[r][11] as usize]); - v[2] = add(v[2], m[SIGMA[r][13] as usize]); - v[3] = add(v[3], m[SIGMA[r][15] as usize]); - v[0] = add(v[0], v[5]); - v[1] = add(v[1], v[6]); - v[2] = add(v[2], v[7]); - v[3] = add(v[3], v[4]); - v[15] = xor(v[15], v[0]); - v[12] = xor(v[12], v[1]); - v[13] = xor(v[13], v[2]); - v[14] = xor(v[14], v[3]); - v[15] = rot8(v[15]); - v[12] = rot8(v[12]); - v[13] = rot8(v[13]); - v[14] = rot8(v[14]); - v[10] = add(v[10], v[15]); - v[11] = add(v[11], v[12]); - v[8] = add(v[8], v[13]); - v[9] = add(v[9], v[14]); - v[5] = xor(v[5], v[10]); - v[6] = xor(v[6], v[11]); - v[7] = xor(v[7], v[8]); - v[4] = xor(v[4], v[9]); - v[5] = rot7(v[5]); - v[6] = rot7(v[6]); - v[7] = rot7(v[7]); - v[4] = rot7(v[4]); -} - -/// Transposes input chunks (16 chunks of 16 u32s each), to get 16 __m512i, each -/// representing 16 packed instances of a message word. 
-/// # Safety -pub unsafe fn transpose_msgs(mut data: [__m512i; 16]) -> [__m512i; 16] { - // Each _m512i chunk contains 16 u32 words. - // Index abcd:xyzw refers to a specific word in data as follows: - // abcd - chunk index (in base 2) - // xyzw - word offset (in base 2) - // Transpose by applying 4 times the index permutation: - // abcd:xyzw => wabc:dxyz - // In other words, rotate the index to the right by 1. - for _ in 0..4 { - data = [ - _mm512_permutex2var_epi32(data[0], EVENS_CONCAT_EVENS, data[1]), - _mm512_permutex2var_epi32(data[2], EVENS_CONCAT_EVENS, data[3]), - _mm512_permutex2var_epi32(data[4], EVENS_CONCAT_EVENS, data[5]), - _mm512_permutex2var_epi32(data[6], EVENS_CONCAT_EVENS, data[7]), - _mm512_permutex2var_epi32(data[8], EVENS_CONCAT_EVENS, data[9]), - _mm512_permutex2var_epi32(data[10], EVENS_CONCAT_EVENS, data[11]), - _mm512_permutex2var_epi32(data[12], EVENS_CONCAT_EVENS, data[13]), - _mm512_permutex2var_epi32(data[14], EVENS_CONCAT_EVENS, data[15]), - _mm512_permutex2var_epi32(data[0], ODDS_CONCAT_ODDS, data[1]), - _mm512_permutex2var_epi32(data[2], ODDS_CONCAT_ODDS, data[3]), - _mm512_permutex2var_epi32(data[4], ODDS_CONCAT_ODDS, data[5]), - _mm512_permutex2var_epi32(data[6], ODDS_CONCAT_ODDS, data[7]), - _mm512_permutex2var_epi32(data[8], ODDS_CONCAT_ODDS, data[9]), - _mm512_permutex2var_epi32(data[10], ODDS_CONCAT_ODDS, data[11]), - _mm512_permutex2var_epi32(data[12], ODDS_CONCAT_ODDS, data[13]), - _mm512_permutex2var_epi32(data[14], ODDS_CONCAT_ODDS, data[15]), - ]; - } - data -} - -/// Transposes states, from 16 states of 8 words each, to get 8 packed words, each holding one -/// word of all 16 states. -/// # Safety -pub unsafe fn transpose_states(mut states: [__m512i; 8]) -> [__m512i; 8] { - // Each _m512i chunk contains 16 u32 words. - // Index abc:xyzw refers to a specific word in data as follows: - // abc - chunk index (in base 2) - // xyzw - word offset (in base 2) - // Transpose by applying 3 times the index permutation: - // abc:xyzw => wab:cxyz - // In other words, rotate the index to the right by 1. - for _ in 0..3 { - states = [ - _mm512_permutex2var_epi32(states[0], EVENS_CONCAT_EVENS, states[1]), - _mm512_permutex2var_epi32(states[2], EVENS_CONCAT_EVENS, states[3]), - _mm512_permutex2var_epi32(states[4], EVENS_CONCAT_EVENS, states[5]), - _mm512_permutex2var_epi32(states[6], EVENS_CONCAT_EVENS, states[7]), - _mm512_permutex2var_epi32(states[0], ODDS_CONCAT_ODDS, states[1]), - _mm512_permutex2var_epi32(states[2], ODDS_CONCAT_ODDS, states[3]), - _mm512_permutex2var_epi32(states[4], ODDS_CONCAT_ODDS, states[5]), - _mm512_permutex2var_epi32(states[6], ODDS_CONCAT_ODDS, states[7]), - ]; - } - states -} - -/// Untransposes states, from 8 packed words, to get 16 results, each of size 32B. -/// # Safety -pub unsafe fn untranspose_states(mut states: [__m512i; 8]) -> [__m512i; 8] { - // Each _m512i chunk contains 16 u32 words. - // Index abc:xyzw refers to a specific word in data as follows: - // abc - chunk index (in base 2) - // xyzw - word offset (in base 2) - // Transpose by applying 3 times the index permutation: - // abc:xyzw => bcx:yzwa - // In other words, rotate the index to the left by 1.
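// (Added note: three left rotations of the 7-bit index undo the three right
// rotations applied in `transpose_states`, so this is its inverse.)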
- for _ in 0..3 { - states = [ - _mm512_permutex2var_epi32(states[0], LHALF_INTERLEAVE_LHALF, states[4]), - _mm512_permutex2var_epi32(states[0], HHALF_INTERLEAVE_HHALF, states[4]), - _mm512_permutex2var_epi32(states[1], LHALF_INTERLEAVE_LHALF, states[5]), - _mm512_permutex2var_epi32(states[1], HHALF_INTERLEAVE_HHALF, states[5]), - _mm512_permutex2var_epi32(states[2], LHALF_INTERLEAVE_LHALF, states[6]), - _mm512_permutex2var_epi32(states[2], HHALF_INTERLEAVE_HHALF, states[6]), - _mm512_permutex2var_epi32(states[3], LHALF_INTERLEAVE_LHALF, states[7]), - _mm512_permutex2var_epi32(states[3], HHALF_INTERLEAVE_HHALF, states[7]), - ]; - } - states -} - -/// Compress 16 blake2s instances. -/// # Safety -pub unsafe fn compress16( - h_vecs: [__m512i; 8], - msg_vecs: [__m512i; 16], - count_low: __m512i, - count_high: __m512i, - lastblock: __m512i, - lastnode: __m512i, -) -> [__m512i; 8] { - let mut v = [ - h_vecs[0], - h_vecs[1], - h_vecs[2], - h_vecs[3], - h_vecs[4], - h_vecs[5], - h_vecs[6], - h_vecs[7], - set1(IV[0] as i32), - set1(IV[1] as i32), - set1(IV[2] as i32), - set1(IV[3] as i32), - xor(set1(IV[4] as i32), count_low), - xor(set1(IV[5] as i32), count_high), - xor(set1(IV[6] as i32), lastblock), - xor(set1(IV[7] as i32), lastnode), - ]; - - round(&mut v, msg_vecs, 0); - round(&mut v, msg_vecs, 1); - round(&mut v, msg_vecs, 2); - round(&mut v, msg_vecs, 3); - round(&mut v, msg_vecs, 4); - round(&mut v, msg_vecs, 5); - round(&mut v, msg_vecs, 6); - round(&mut v, msg_vecs, 7); - round(&mut v, msg_vecs, 8); - round(&mut v, msg_vecs, 9); - - [ - xor(xor(h_vecs[0], v[0]), v[8]), - xor(xor(h_vecs[1], v[1]), v[9]), - xor(xor(h_vecs[2], v[2]), v[10]), - xor(xor(h_vecs[3], v[3]), v[11]), - xor(xor(h_vecs[4], v[4]), v[12]), - xor(xor(h_vecs[5], v[5]), v[13]), - xor(xor(h_vecs[6], v[6]), v[14]), - xor(xor(h_vecs[7], v[7]), v[15]), - ] -} - -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -#[cfg(test)] -mod tests { - use super::{compress16, set1, transpose_msgs, transpose_states, untranspose_states}; - use crate::core::vcs::blake2s_ref::compress; - - #[test] - fn test_compress16() { - let states: [[u32; 8]; 16] = - std::array::from_fn(|i| std::array::from_fn(|j| (i + j) as u32)); - let msgs: [[u32; 16]; 16] = - std::array::from_fn(|i| std::array::from_fn(|j| (i + j + 20) as u32)); - let count_low = 1; - let count_high = 2; - let lastblock = 3; - let lastnode = 4; - let res_unvectorized = std::array::from_fn(|i| { - compress( - states[i], msgs[i], count_low, count_high, lastblock, lastnode, - ) - }); - - let res_vectorized: [[u32; 8]; 16] = unsafe { - std::mem::transmute(untranspose_states(compress16( - transpose_states(std::mem::transmute(states)), - transpose_msgs(std::mem::transmute(msgs)), - set1(count_low as i32), - set1(count_high as i32), - set1(lastblock as i32), - set1(lastnode as i32), - ))) - }; - - assert_eq!(res_unvectorized, res_vectorized); - } -} diff --git a/crates/prover/src/core/backend/avx512/circle.rs b/crates/prover/src/core/backend/avx512/circle.rs deleted file mode 100644 index 170325a2f..000000000 --- a/crates/prover/src/core/backend/avx512/circle.rs +++ /dev/null @@ -1,458 +0,0 @@ -use bytemuck::{cast_slice, Zeroable}; -use num_traits::One; - -use super::fft::{ifft, CACHED_FFT_LOG_SIZE}; -use super::m31::PackedBaseField; -use super::qm31::PackedSecureField; -use super::{as_cpu_vec, AVX512Backend, K_BLOCK_SIZE, VECS_LOG_SIZE}; -use crate::core::backend::avx512::fft::rfft; -use crate::core::backend::avx512::BaseFieldVec; -use crate::core::backend::{CPUBackend, 
Col}; -use crate::core::circle::{CirclePoint, Coset}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::{Field, FieldExpOps}; -use crate::core::poly::circle::{ - CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps, -}; -use crate::core::poly::twiddles::TwiddleTree; -use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold}; -use crate::core::poly::BitReversedOrder; - -impl AVX512Backend { - // TODO(Ohad): optimize. - fn twiddle_at<F: Field>(mappings: &[F], mut index: usize) -> F { - debug_assert!( - (1 << mappings.len()) as usize >= index, - "Index out of bounds. mappings log len = {}, index = {index}", - mappings.len().ilog2() - ); - - let mut product = F::one(); - for &num in mappings.iter() { - if index & 1 == 1 { - product *= num; - } - index >>= 1; - if index == 0 { - break; - } - } - - product - } - - // TODO(Ohad): consider moving this to a more general place. - // Note: CACHED_FFT_LOG_SIZE is specific to the backend. - fn generate_evaluation_mappings<F: Field>(point: CirclePoint<F>, log_size: u32) -> Vec<F> { - // Mappings are the factors used to compute the evaluation twiddle. - // Every twiddle (i) is of the form (m[0])^b_0 * (m[1])^b_1 * ... * (m[log_size - - // 1])^b_log_size. - // Where (m)_j are the mappings, and b_i is the j'th bit of i. - let mut mappings = vec![point.y, point.x]; - let mut x = point.x; - for _ in 2..log_size { - x = CirclePoint::double_x(x); - mappings.push(x); - } - - // The caller function expects the mapping in natural order. i.e. (y,x,h(x),h(h(x)),...). - // If the polynomial is large, the fft does a transpose in the middle in a granularity of 16 - // (avx512). The coefficients would then be in transposed order of 16-sized chunks. - // i.e. (a_(n-15), a_(n-14), ..., a_(n-1), a_(n-31), ..., a_(n-16), a_(n-32), ...). - // To compute the twiddles in the correct order, we need to transpose the corresponding - // 'transposed bits' in the mappings. The result order of the mappings would then be - // (y, x, h(x), h^2(x), h^(log_n-1)(x), h^(log_n-2)(x) ...). To avoid code - // complexity for now, we just reverse the mappings, transpose, then reverse back. - // TODO(Ohad): optimize. consider changing the caller to expect the mappings in - // reversed-transposed order. - if log_size as usize > CACHED_FFT_LOG_SIZE { - mappings.reverse(); - let n = mappings.len(); - let n0 = (n - VECS_LOG_SIZE) / 2; - let n1 = (n - VECS_LOG_SIZE + 1) / 2; - let (ab, c) = mappings.split_at_mut(n1); - let (a, _b) = ab.split_at_mut(n0); - // Swap content of a,c. - a.swap_with_slice(&mut c[0..n0]); - mappings.reverse(); - } - - mappings - } - - // Generates twiddle steps for efficiently computing the twiddles. - // steps[i] = t_i/(t_0*t_1*...*t_i-1). - fn twiddle_steps<F: Field>(mappings: &[F]) -> Vec<F> - where - F: FieldExpOps, - { - let mut denominators: Vec<F> = vec![mappings[0]]; - - for i in 1..mappings.len() { - denominators.push(denominators[i - 1] * mappings[i]); - } - - let mut denom_inverses = vec![F::zero(); denominators.len()]; - F::batch_inverse(&denominators, &mut denom_inverses); - - let mut steps = vec![mappings[0]]; - - mappings - .iter() - .skip(1) - .zip(denom_inverses.iter()) - .for_each(|(&m, &d)| { - steps.push(m * d); - }); - steps.push(F::one()); - steps - } - - // Advances the twiddle by multiplying it by the next step.
e.g.: - // If idx(t) = 0b100..1010 , then f(t) = t * step[0] - // If idx(t) = 0b100..0111 , then f(t) = t * step[3] - fn advance_twiddle<F: Field>(twiddle: F, steps: &[F], curr_idx: usize) -> F { - twiddle * steps[curr_idx.trailing_ones() as usize] - } -} - -// TODO(spapini): Everything is returned in redundant representation, where values can also be P. -// Decide if and when it's ok and what to do if it's not. -impl PolyOps for AVX512Backend { - // The twiddles type is i32, and not BaseField. This is because the fast AVX mul implementation - // requires one of the numbers to be shifted left by 1 bit. This is not a reduced - // representation of the field. - type Twiddles = Vec<i32>; - - fn new_canonical_ordered( - coset: CanonicCoset, - values: Col<Self, BaseField>, - ) -> CircleEvaluation<Self, BaseField, BitReversedOrder> { - // TODO(spapini): Optimize. - let eval = CPUBackend::new_canonical_ordered(coset, as_cpu_vec(values)); - CircleEvaluation::new( - eval.domain, - Col::<AVX512Backend, BaseField>::from_iter(eval.values), - ) - } - - fn interpolate( - eval: CircleEvaluation<Self, BaseField, BitReversedOrder>, - twiddles: &TwiddleTree<Self>, - ) -> CirclePoly<Self> { - let mut values = eval.values; - let log_size = values.length.ilog2(); - - let twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles); - - // Safe because [PackedBaseField] is aligned on 64 bytes. - unsafe { - ifft::ifft( - std::mem::transmute(values.data.as_mut_ptr()), - &twiddles, - log_size as usize, - ); - } - - // TODO(spapini): Fuse this multiplication / rotation. - let inv = BaseField::from_u32_unchecked(eval.domain.size() as u32).inverse(); - let inv = PackedBaseField::from_array([inv; 16]); - for x in values.data.iter_mut() { - *x *= inv; - } - - CirclePoly::new(values) - } - - fn eval_at_point(poly: &CirclePoly<Self>, point: CirclePoint<SecureField>) -> SecureField { - // If the polynomial is small, fall back to evaluating directly. - // TODO(Ohad): it's possible to avoid falling back. Consider fixing. - if poly.log_size() <= 8 { - return slow_eval_at_point(poly, point); - } - - let mappings = Self::generate_evaluation_mappings(point, poly.log_size()); - - // 8 lowest mappings produce the first 2^8 twiddles. Separate to optimize each calculation. - let (map_low, map_high) = mappings.split_at(4); - let twiddle_lows = - PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_low, i))); - let (map_mid, map_high) = map_high.split_at(4); - let twiddle_mids = - PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_mid, i))); - - // Compute the high twiddle steps. - let twiddle_steps = Self::twiddle_steps(map_high); - - // Every twiddle is a product of mappings that correspond to '1's in the bit representation - // of the current index. For every 2^n aligned chunk of 2^n elements, the twiddle - // array is the same, denoted twiddle_low. Use this to compute sums of (coeff * - // twiddle_high) mod 2^n, then multiply by twiddle_low, and sum to get the final result. - let mut sum = PackedSecureField::zeroed(); - let mut twiddle_high = SecureField::one(); - for (i, coeff_chunk) in poly.coeffs.data.array_chunks::<16>().enumerate() { - // For every chunk of 2 ^ 4 * 2 ^ 4 = 2 ^ 8 elements, the twiddle high is the same. - // Multiply it by every mid twiddle factor to get the factors for the current chunk. - let high_twiddle_factors = - (PackedSecureField::broadcast(twiddle_high) * twiddle_mids).to_array(); - - // Sum the coefficients multiplied by each corresponding twiddle. Result is effectively - // an array[16] where the value at index 'i' is the sum of all coefficients at indices - // that are i mod 16.
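// Aside: a scalar restatement (hypothetical helper, not from this file) of the
// scheme above: the twiddle for index `i` is the product of the mappings selected
// by the set bits of `i`, and the evaluation is sum_i coeffs[i] * twiddle(i).
fn scalar_twiddle_at(mappings: &[SecureField], index: usize) -> SecureField {
    mappings
        .iter()
        .enumerate()
        .filter(|(bit, _)| (index >> bit) & 1 == 1)
        .fold(SecureField::one(), |acc, (_, &m)| acc * m)
}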
- for (&packed_coeffs, &mid_twiddle) in - coeff_chunk.iter().zip(high_twiddle_factors.iter()) - { - sum += PackedSecureField::broadcast(mid_twiddle).mul_packed_m31(packed_coeffs); - } - - // Advance twiddle high. - twiddle_high = Self::advance_twiddle(twiddle_high, &twiddle_steps, i); - } - - (sum * twiddle_lows).pointwise_sum() - } - - fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> { - // TODO(spapini): Optimize or get rid of extend. - poly.evaluate(CanonicCoset::new(log_size).circle_domain()) - .interpolate() - } - - fn evaluate( - poly: &CirclePoly<Self>, - domain: CircleDomain, - twiddles: &TwiddleTree<Self>, - ) -> CircleEvaluation<Self, BaseField, BitReversedOrder> { - // TODO(spapini): Precompute twiddles. - // TODO(spapini): Handle small cases. - let log_size = domain.log_size() as usize; - let fft_log_size = poly.log_size() as usize; - assert!( - log_size >= fft_log_size, - "Can only evaluate on larger domains" - ); - - let twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles); - - // Evaluate on big domains by evaluating on several subdomains. - let log_subdomains = log_size - fft_log_size; - - // Allocate the destination buffer without initializing. - let mut values = Vec::with_capacity(domain.size() >> VECS_LOG_SIZE); - #[allow(clippy::uninit_vec)] - unsafe { - values.set_len(domain.size() >> VECS_LOG_SIZE) - }; - - for i in 0..(1 << log_subdomains) { - // The subdomain twiddles are a slice of the large domain twiddles. - let subdomain_twiddles = (0..(fft_log_size - 1)) - .map(|layer_i| { - &twiddles[layer_i] - [i << (fft_log_size - 2 - layer_i)..(i + 1) << (fft_log_size - 2 - layer_i)] - }) - .collect::<Vec<_>>(); - - // FFT from the coefficients buffer to the values chunk. - unsafe { - rfft::fft( - std::mem::transmute(poly.coeffs.data.as_ptr()), - std::mem::transmute( - values[i << (fft_log_size - VECS_LOG_SIZE) - ..(i + 1) << (fft_log_size - VECS_LOG_SIZE)] - .as_mut_ptr(), - ), - &subdomain_twiddles, - fft_log_size, - ); - } - } - - CircleEvaluation::new( - domain, - BaseFieldVec { - data: values, - length: domain.size(), - }, - ) - } - - fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> { - let mut twiddles = Vec::with_capacity(coset.size()); - let mut itwiddles = Vec::with_capacity(coset.size()); - - // TODO(spapini): Optimize. - for layer in &rfft::get_twiddle_dbls(coset) { - twiddles.extend(layer); - } - // Pad by any value, to make the size a power of 2. - twiddles.push(1); - assert_eq!(twiddles.len(), coset.size()); - for layer in &ifft::get_itwiddle_dbls(coset) { - itwiddles.extend(layer); - } - // Pad by any value, to make the size a power of 2. - itwiddles.push(1); - assert_eq!(itwiddles.len(), coset.size()); - - TwiddleTree { - root_coset: coset, - twiddles, - itwiddles, - } - } -} - -fn slow_eval_at_point( - poly: &CirclePoly<AVX512Backend>, - point: CirclePoint<SecureField>, -) -> SecureField { - let mut mappings = vec![point.y, point.x]; - let mut x = point.x; - for _ in 2..poly.log_size() { - x = CirclePoint::double_x(x); - mappings.push(x); - } - mappings.reverse(); - - // If the polynomial is large, the fft does a transpose in the middle. - if poly.log_size() as usize > CACHED_FFT_LOG_SIZE { - let n = mappings.len(); - let n0 = (n - VECS_LOG_SIZE) / 2; - let n1 = (n - VECS_LOG_SIZE + 1) / 2; - let (ab, c) = mappings.split_at_mut(n1); - let (a, _b) = ab.split_at_mut(n0); - // Swap content of a,c.
- a.swap_with_slice(&mut c[0..n0]); - } - fold(cast_slice::<_, BaseField>(&poly.coeffs.data), &mappings) -} - -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -#[cfg(test)] -mod tests { - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use crate::core::backend::avx512::circle::slow_eval_at_point; - use crate::core::backend::avx512::fft::{CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; - use crate::core::backend::avx512::AVX512Backend; - use crate::core::backend::Column; - use crate::core::circle::CirclePoint; - use crate::core::fields::m31::BaseField; - use crate::core::fields::qm31::SecureField; - use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, PolyOps}; - use crate::core::poly::{BitReversedOrder, NaturalOrder}; - - #[test] - fn test_interpolate_and_eval() { - for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 4) { - let domain = CanonicCoset::new(log_size as u32).circle_domain(); - let evaluation = CircleEvaluation::<AVX512Backend, BaseField, BitReversedOrder>::new( - domain, - (0..(1 << log_size)) - .map(BaseField::from_u32_unchecked) - .collect(), - ); - let poly = evaluation.clone().interpolate(); - let evaluation2 = poly.evaluate(domain); - assert_eq!(evaluation.values.to_cpu(), evaluation2.values.to_cpu()); - } - } - - #[test] - fn test_eval_extension() { - for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 4) { - let log_size = log_size as u32; - let domain = CanonicCoset::new(log_size).circle_domain(); - let domain_ext = CanonicCoset::new(log_size + 3).circle_domain(); - let evaluation = CircleEvaluation::<AVX512Backend, BaseField, BitReversedOrder>::new( - domain, - (0..(1 << log_size)) - .map(BaseField::from_u32_unchecked) - .collect(), - ); - let poly = evaluation.clone().interpolate(); - let evaluation2 = poly.evaluate(domain_ext); - let poly2 = evaluation2.interpolate(); - assert_eq!( - poly.extend(log_size + 3).coeffs.to_cpu(), - poly2.coeffs.to_cpu() - ); - } - } - - #[test] - fn test_eval_at_point() { - for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 4) { - let domain = CanonicCoset::new(log_size as u32).circle_domain(); - let evaluation = CircleEvaluation::<AVX512Backend, BaseField, NaturalOrder>::new( - domain, - (0..(1 << log_size)) - .map(BaseField::from_u32_unchecked) - .collect(), - ); - let poly = evaluation.bit_reverse().interpolate(); - for i in [0, 1, 3, 1 << (log_size - 1), 1 << (log_size - 2)] { - let p = domain.at(i); - assert_eq!( - poly.eval_at_point(p.into_ef()), - BaseField::from_u32_unchecked(i as u32).into(), - "log_size = {log_size} i = {i}" - ); - } - } - } - - #[test] - fn test_circle_poly_extend() { - for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 2) { - let log_size = log_size as u32; - let poly = CirclePoly::<AVX512Backend>::new( - (0..(1 << log_size)) - .map(BaseField::from_u32_unchecked) - .collect(), - ); - let eval0 = poly.evaluate(CanonicCoset::new(log_size + 2).circle_domain()); - let eval1 = poly - .extend(log_size + 2) - .evaluate(CanonicCoset::new(log_size + 2).circle_domain()); - - assert_eq!(eval0.values.to_cpu(), eval1.values.to_cpu()); - } - } - - #[test] - fn test_eval_securefield() { - use crate::core::backend::avx512::fft::MIN_FFT_LOG_SIZE; - let mut rng = SmallRng::seed_from_u64(0); - - for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 2) { - let domain = CanonicCoset::new(log_size as u32).circle_domain(); - let evaluation = CircleEvaluation::<AVX512Backend, BaseField, NaturalOrder>::new( - domain, - (0..(1 << log_size)) - .map(BaseField::from_u32_unchecked) - .collect(), - ); - let poly = evaluation.bit_reverse().interpolate(); - - let x: SecureField = rng.gen(); - let y: SecureField = rng.gen(); - - let p = CirclePoint {
x, y }; - - assert_eq!( - <AVX512Backend as PolyOps>::eval_at_point(&poly, p), - slow_eval_at_point(&poly, p), - "log_size = {log_size}" - ); - - println!( - "log_size = {log_size} passed, eval{}", - <AVX512Backend as PolyOps>::eval_at_point(&poly, p) - ); - } - } -} diff --git a/crates/prover/src/core/backend/avx512/cm31.rs b/crates/prover/src/core/backend/avx512/cm31.rs deleted file mode 100644 index 1122bdb85..000000000 --- a/crates/prover/src/core/backend/avx512/cm31.rs +++ /dev/null @@ -1,133 +0,0 @@ -use std::ops::{Add, Mul, MulAssign, Neg, Sub}; - -use num_traits::{One, Zero}; - -use super::m31::{PackedBaseField, K_BLOCK_SIZE}; -use crate::core::fields::cm31::CM31; -use crate::core::fields::FieldExpOps; - -/// AVX implementation for the complex extension field of M31. -/// See [crate::core::fields::cm31::CM31] for more information. -#[derive(Copy, Clone, Debug)] -pub struct PackedCM31(pub [PackedBaseField; 2]); -impl PackedCM31 { - pub fn broadcast(value: CM31) -> Self { - Self([ - PackedBaseField::broadcast(value.0), - PackedBaseField::broadcast(value.1), - ]) - } - pub fn a(&self) -> PackedBaseField { - self.0[0] - } - pub fn b(&self) -> PackedBaseField { - self.0[1] - } - pub fn to_array(&self) -> [CM31; K_BLOCK_SIZE] { - std::array::from_fn(|i| CM31(self.0[0].to_array()[i], self.0[1].to_array()[i])) - } -} -impl Add for PackedCM31 { - type Output = Self; - fn add(self, rhs: Self) -> Self::Output { - Self([self.a() + rhs.a(), self.b() + rhs.b()]) - } -} -impl Sub for PackedCM31 { - type Output = Self; - fn sub(self, rhs: Self) -> Self::Output { - Self([self.a() - rhs.a(), self.b() - rhs.b()]) - } -} -impl Mul for PackedCM31 { - type Output = Self; - fn mul(self, rhs: Self) -> Self::Output { - // Compute using Karatsuba. - let ac = self.a() * rhs.a(); - let bd = self.b() * rhs.b(); - // Computes (a + b) * (c + d). - let ab_t_cd = (self.a() + self.b()) * (rhs.a() + rhs.b()); - // (ac - bd) + (ad + bc)i. - Self([ac - bd, ab_t_cd - ac - bd]) - } -} -impl Zero for PackedCM31 { - fn zero() -> Self { - Self([PackedBaseField::zero(), PackedBaseField::zero()]) - } - fn is_zero(&self) -> bool { - self.a().is_zero() && self.b().is_zero() - } -} -impl One for PackedCM31 { - fn one() -> Self { - Self([PackedBaseField::one(), PackedBaseField::zero()]) - } -} -impl MulAssign for PackedCM31 { - fn mul_assign(&mut self, rhs: Self) { - *self = *self * rhs; - } -} -impl Neg for PackedCM31 { - type Output = Self; - fn neg(self) -> Self::Output { - Self([-self.a(), -self.b()]) - } -} -impl FieldExpOps for PackedCM31 { - fn inverse(&self) -> Self { - assert!(!self.is_zero(), "0 has no inverse"); - // 1 / (a + bi) = (a - bi) / (a^2 + b^2).
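// (Added derivation: multiplying numerator and denominator by the conjugate
// a - bi gives (a - bi) / ((a + bi)(a - bi)) = (a - bi) / (a^2 + b^2),
// since i^2 = -1 in CM31.)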
- Self([self.a(), -self.b()]) * (self.a().square() + self.b().square()).inverse() - } -} - -impl Add<PackedBaseField> for PackedCM31 { - type Output = Self; - fn add(self, rhs: PackedBaseField) -> Self::Output { - Self([self.a() + rhs, self.b()]) - } -} -impl Sub<PackedBaseField> for PackedCM31 { - type Output = Self; - fn sub(self, rhs: PackedBaseField) -> Self::Output { - Self([self.a() - rhs, self.b()]) - } -} -impl Mul<PackedBaseField> for PackedCM31 { - type Output = Self; - fn mul(self, rhs: PackedBaseField) -> Self::Output { - Self([self.a() * rhs, self.b() * rhs]) - } -} - -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -#[cfg(test)] -mod tests { - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use super::*; - - #[test] - fn test_cm31avx512_basic_ops() { - let mut rng = SmallRng::seed_from_u64(0); - let x = PackedCM31([ - PackedBaseField::from_array(rng.gen()), - PackedBaseField::from_array(rng.gen()), - ]); - let y = PackedCM31([ - PackedBaseField::from_array(rng.gen()), - PackedBaseField::from_array(rng.gen()), - ]); - let sum = x + y; - let diff = x - y; - let prod = x * y; - for i in 0..16 { - assert_eq!(sum.to_array()[i], x.to_array()[i] + y.to_array()[i]); - assert_eq!(diff.to_array()[i], x.to_array()[i] - y.to_array()[i]); - assert_eq!(prod.to_array()[i], x.to_array()[i] * y.to_array()[i]); - } - } -} diff --git a/crates/prover/src/core/backend/avx512/fft/ifft.rs b/crates/prover/src/core/backend/avx512/fft/ifft.rs deleted file mode 100644 index 35a458d02..000000000 --- a/crates/prover/src/core/backend/avx512/fft/ifft.rs +++ /dev/null @@ -1,731 +0,0 @@ -//! Inverse FFT. - -use std::arch::x86_64::{ - __m512i, _mm512_broadcast_i32x4, _mm512_mul_epu32, _mm512_permutex2var_epi32, - _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64, -}; - -use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS}; -use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; -use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE}; -use crate::core::circle::Coset; -use crate::core::fields::FieldExpOps; -use crate::core::utils::bit_reverse; - -/// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values. - -/// -/// # Safety -/// This function is unsafe because it takes a raw pointer to i32 values. -/// `values` must be aligned to 64 bytes. -/// -/// # Arguments -/// * `values`: A mutable pointer to the values on which the ICFFT is to be performed. -/// * `twiddle_dbl`: A reference to the doubles of the twiddle factors. -/// * `log_n_elements`: The log of the number of elements in the `values` array. -/// -/// # Panics -/// This function will panic if `log_n_elements` is less than `MIN_FFT_LOG_SIZE`. -pub unsafe fn ifft(values: *mut i32, twiddle_dbl: &[&[i32]], log_n_elements: usize) { - assert!(log_n_elements >= MIN_FFT_LOG_SIZE); - let log_n_vecs = log_n_elements - VECS_LOG_SIZE; - if log_n_elements <= CACHED_FFT_LOG_SIZE { - ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements); - return; - } - - let fft_layers_pre_transpose = log_n_vecs.div_ceil(2); - let fft_layers_post_transpose = log_n_vecs / 2; - ifft_lower_with_vecwise( - values, - &twiddle_dbl[..(3 + fft_layers_pre_transpose)], - log_n_elements, - fft_layers_pre_transpose + VECS_LOG_SIZE, - ); - transpose_vecs(values, log_n_vecs); - ifft_lower_without_vecwise( - values, - &twiddle_dbl[(3 + fft_layers_pre_transpose)..], - log_n_elements, - fft_layers_post_transpose, - ); -} - -/// Computes partial ifft on `2^log_size` M31 elements.
-/// Parameters: -/// values - Pointer to the entire value array, aligned to 64 bytes. -/// twiddle_dbl - The doubles of the twiddle factors for each layer of the ifft. -/// layer i holds 2^(log_size - 1 - i) twiddles. -/// log_size - The log of the number of M31 elements in the array. -/// fft_layers - The number of ifft layers to apply, out of log_size. -/// # Safety -/// `values` must be aligned to 64 bytes. -/// `log_size` must be at least 5. -/// `fft_layers` must be at least 5. -pub unsafe fn ifft_lower_with_vecwise( - values: *mut i32, - twiddle_dbl: &[&[i32]], - log_size: usize, - fft_layers: usize, -) { - const VECWISE_FFT_BITS: usize = VECS_LOG_SIZE + 1; - assert!(log_size >= VECWISE_FFT_BITS); - - assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); - - for index_h in 0..(1 << (log_size - fft_layers)) { - ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h); - for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) { - match fft_layers - layer { - 1 => { - ifft1_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); - } - 2 => { - ifft2_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); - } - _ => { - ifft3_loop( - values, - &twiddle_dbl[(layer - 1)..], - fft_layers - layer - 3, - layer, - index_h, - ); - } - } - } - } -} - -/// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits -/// of the index). -/// Parameters: -/// values - Pointer to the entire value array, aligned to 64 bytes. -/// twiddle_dbl - The doubles of the twiddle factors for each layer of the ifft. -/// log_size - The log of the number of M31 elements in the array. -/// fft_layers - The number of ifft layers to apply, out of log_size - VECS_LOG_SIZE. -/// -/// # Safety -/// `values` must be aligned to 64 bytes. -/// `log_size` must be at least 4. -/// `fft_layers` must be at least 4. -pub unsafe fn ifft_lower_without_vecwise( - values: *mut i32, - twiddle_dbl: &[&[i32]], - log_size: usize, - fft_layers: usize, -) { - assert!(log_size >= VECS_LOG_SIZE); - - for index_h in 0..(1 << (log_size - fft_layers - VECS_LOG_SIZE)) { - for layer in (0..fft_layers).step_by(3) { - let fixed_layer = layer + VECS_LOG_SIZE; - match fft_layers - layer { - 1 => { - ifft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); - } - 2 => { - ifft2_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); - } - _ => { - ifft3_loop( - values, - &twiddle_dbl[layer..], - fft_layers - layer - 3, - fixed_layer, - index_h, - ); - } - } - } - } -} - -/// Runs the first 5 ifft layers across the entire array. -/// Parameters: -/// values - Pointer to the entire value array, aligned to 64 bytes. -/// twiddle_dbl - The doubles of the twiddle factors for each of the 5 ifft layers. -/// loop_bits - The number of bits this loop needs to run on. -/// index_h - The higher part of the index, iterated by the caller.
-/// Runs the first 5 ifft layers across the entire array.
-/// Parameters:
-/// values - Pointer to the entire value array, aligned to 64 bytes.
-/// twiddle_dbl - The doubles of the twiddle factors for each of the 5 ifft layers.
-/// loop_bits - The number of bits this loop needs to run on.
-/// index_h - The higher part of the index, iterated by the caller.
-/// # Safety
-pub unsafe fn ifft_vecwise_loop(
-    values: *mut i32,
-    twiddle_dbl: &[&[i32]],
-    loop_bits: usize,
-    index_h: usize,
-) {
-    for index_l in 0..(1 << loop_bits) {
-        let index = (index_h << loop_bits) + index_l;
-        let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const());
-        let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const());
-        (val0, val1) = vecwise_ibutterflies(
-            val0,
-            val1,
-            std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)),
-            std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)),
-            std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)),
-        );
-        (val0, val1) = avx_ibutterfly(
-            val0,
-            val1,
-            _mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)),
-        );
-        val0.store(values.add(index * 32));
-        val1.store(values.add(index * 32 + 16));
-    }
-}
-
-/// Runs 3 ifft layers across the entire array.
-/// Parameters:
-/// values - Pointer to the entire value array, aligned to 64 bytes.
-/// twiddle_dbl - The doubles of the twiddle factors for each of the 3 ifft layers.
-/// loop_bits - The number of bits this loop needs to run on.
-/// layer - The layer number of the first ifft layer to apply.
-///   The layers `layer`, `layer + 1`, `layer + 2` are applied.
-/// index_h - The higher part of the index, iterated by the caller.
-/// # Safety
-pub unsafe fn ifft3_loop(
-    values: *mut i32,
-    twiddle_dbl: &[&[i32]],
-    loop_bits: usize,
-    layer: usize,
-    index_h: usize,
-) {
-    for index_l in 0..(1 << loop_bits) {
-        let index = (index_h << loop_bits) + index_l;
-        let offset = index << (layer + 3);
-        for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) {
-            ifft3(
-                values,
-                offset + l,
-                layer,
-                std::array::from_fn(|i| {
-                    *twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1))
-                }),
-                std::array::from_fn(|i| {
-                    *twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1))
-                }),
-                std::array::from_fn(|i| {
-                    *twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1))
-                }),
-            );
-        }
-    }
-}
-
-/// Runs 2 ifft layers across the entire array.
-/// Parameters:
-/// values - Pointer to the entire value array, aligned to 64 bytes.
-/// twiddle_dbl - The doubles of the twiddle factors for each of the 2 ifft layers.
-/// layer - The layer number of the first ifft layer to apply.
-///   The layers `layer`, `layer + 1` are applied.
-/// index - The index, iterated by the caller.
-/// # Safety
-unsafe fn ifft2_loop(values: *mut i32, twiddle_dbl: &[&[i32]], layer: usize, index: usize) {
-    let offset = index << (layer + 2);
-    for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) {
-        ifft2(
-            values,
-            offset + l,
-            layer,
-            std::array::from_fn(|i| {
-                *twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1))
-            }),
-            std::array::from_fn(|i| {
-                *twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1))
-            }),
-        );
-    }
-}
-
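[Editor's aside] These loops, and `avx_ibutterfly` just below, work with *doubled* twiddles. In one scalar lane: with t_dbl = 2t, the product x * t_dbl is even and fits in 63 bits; writing it as H*2^32 + L and using 2^31 = 1 (mod P) gives x*t = H*2^31 + L/2 = H + L/2 (mod P), which is what the interleave-and-shift intrinsics assemble. A sketch under those assumptions, with hypothetical helper names:

```rust
const P: u32 = (1 << 31) - 1;

// Multiply x in [0, P] by a twiddle given as its double t_dbl = 2t.
fn mul_doubled(x: u32, t_dbl: u32) -> u32 {
    let prod = x as u64 * t_dbl as u64; // even, < 2^63
    let hi = prod >> 32; // H, at most P
    let lo_half = (prod & 0xffff_ffff) >> 1; // L / 2, below 2^31
    let sum = hi + lo_half; // at most 2P, so one conditional subtraction suffices
    if sum > P as u64 { (sum - P as u64) as u32 } else { sum as u32 }
}

// Scalar inverse butterfly: (x, y) -> (x + y, (x - y) * t).
fn ibutterfly(x: u32, y: u32, t_dbl: u32) -> (u32, u32) {
    ((x + y) % P, mul_doubled((x + P - y) % P, t_dbl))
}

fn main() {
    let (x, y, t) = (123_456_789, 998_244_353, 1_177_558_791);
    let (r0, r1) = ibutterfly(x, y, 2 * t);
    // Reference: plain modular arithmetic.
    assert_eq!(r0 as u64, (x as u64 + y as u64) % P as u64);
    assert_eq!(
        r1 as u64 % P as u64,
        (x as u64 + P as u64 - y as u64) * t as u64 % P as u64
    );
    println!("ok");
}
```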
-/// Runs 1 ifft layer across the entire array.
-/// Parameters:
-/// values - Pointer to the entire value array, aligned to 64 bytes.
-/// twiddle_dbl - The doubles of the twiddle factors for the ifft layer.
-/// layer - The layer number of the ifft layer to apply.
-/// index - The index, iterated by the caller.
-/// # Safety
-unsafe fn ifft1_loop(values: *mut i32, twiddle_dbl: &[&[i32]], layer: usize, index: usize) {
-    let offset = index << (layer + 1);
-    for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) {
-        ifft1(
-            values,
-            offset + l,
-            layer,
-            std::array::from_fn(|i| {
-                *twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1))
-            }),
-        );
-    }
-}
-
-/// Computes the ibutterfly operation for packed M31 elements.
-/// Returns (val0 + val1, t * (val0 - val1)).
-/// val0, val1 are packed M31 elements. 16 M31 words each.
-/// Each value is assumed to be in unreduced form, [0, P] including P.
-/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in unreduced form.
-/// # Safety
-/// This function is safe.
-pub unsafe fn avx_ibutterfly(
-    val0: PackedBaseField,
-    val1: PackedBaseField,
-    twiddle_dbl: __m512i,
-) -> (PackedBaseField, PackedBaseField) {
-    let r0 = val0 + val1;
-    let r1 = val0 - val1;
-
-    // Extract the even and odd parts of r1 and twiddle_dbl, and spread as 8 64-bit values.
-    let r1_e = r1.0;
-    let r1_o = _mm512_srli_epi64(r1.0, 32);
-    let twiddle_dbl_e = twiddle_dbl;
-    let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32);
-
-    // To compute prod = r1 * twiddle, start by multiplying
-    // r1_e/o by twiddle_dbl_e/o.
-    let prod_e_dbl = _mm512_mul_epu32(r1_e, twiddle_dbl_e);
-    let prod_o_dbl = _mm512_mul_epu32(r1_o, twiddle_dbl_o);
-
-    // The result of the multiplication holds r1*twiddle_dbl as 64-bit values.
-    // Each 64-bit word looks like this:
-    //               1    31       31    1
-    // prod_e_dbl - |0|prod_e_h|prod_e_l|0|
-    // prod_o_dbl - |0|prod_o_h|prod_o_l|0|
-
-    // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl:
-    let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl);
-    // prod_ls -    |prod_o_l|0|prod_e_l|0|
-
-    // Divide by 2:
-    let prod_ls = _mm512_srli_epi64(prod_ls, 1);
-    // prod_ls -    |0|prod_o_l|0|prod_e_l|
-
-    // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl:
-    let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl);
-    // prod_hs -    |0|prod_o_h|0|prod_e_h|
-
-    let prod = PackedBaseField(prod_ls) + PackedBaseField(prod_hs);
-
-    (r0, prod)
-}
-
-/// Runs ifft on 2 vectors of 16 M31 elements.
-/// This amounts to 4 butterfly layers, each with 16 butterflies.
-/// Each of the vectors represents a bit reversed evaluation.
-/// Each value in the vectors is in unreduced form: [0, P] including P.
-/// Takes 3 twiddle arrays, one for each layer after the first, holding the double of the
-/// corresponding twiddle.
-/// The first layer's twiddles (lower bit of the index) are computed from the second layer's
-/// twiddles. The second layer takes 8 twiddles.
-/// The third layer takes 4 twiddles.
-/// The fourth layer takes 2 twiddles.
-/// # Safety
-/// This function is safe.
-pub unsafe fn vecwise_ibutterflies(
-    mut val0: PackedBaseField,
-    mut val1: PackedBaseField,
-    twiddle1_dbl: [i32; 8],
-    twiddle2_dbl: [i32; 4],
-    twiddle3_dbl: [i32; 2],
-) -> (PackedBaseField, PackedBaseField) {
-    // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly.
-
-    // Each avx_ibutterfly takes 2 512-bit registers, and does 16 butterflies element by element.
-    // We need to permute the 512-bit registers to get the right order for the butterflies.
-    // Denote the index of the 16 M31 elements in register i as i:abcd.
-    // At each layer we apply the following permutation to the index:
-    // i:abcd => d:iabc
-    // This is how the index evolves at each iteration:
- // i:abcd - // d:iabc - // ifft on d - // c:diab - // ifft on c - // b:cdia - // ifft on b - // a:bcid - // ifft on a - // i:abcd - - let (t0, t1) = compute_first_twiddles(twiddle1_dbl); - - // Apply the permutation, resulting in indexing d:iabc. - (val0, val1) = val0.deinterleave_with(val1); - (val0, val1) = avx_ibutterfly(val0, val1, t0); - - // Apply the permutation, resulting in indexing c:diab. - (val0, val1) = val0.deinterleave_with(val1); - (val0, val1) = avx_ibutterfly(val0, val1, t1); - - // The twiddles for layer 2 are replicated in the following pattern: - // 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 - let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl)); - // Apply the permutation, resulting in indexing b:cdia. - (val0, val1) = val0.deinterleave_with(val1); - (val0, val1) = avx_ibutterfly(val0, val1, t); - - // The twiddles for layer 3 are replicated in the following pattern: - // 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 - let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl)); - // Apply the permutation, resulting in indexing a:bcid. - (val0, val1) = val0.deinterleave_with(val1); - (val0, val1) = avx_ibutterfly(val0, val1, t); - - // Apply the permutation, resulting in indexing i:abcd. - val0.deinterleave_with(val1) -} - -/// Returns the line twiddles (x points) for an ifft on a coset. -pub fn get_itwiddle_dbls(mut coset: Coset) -> Vec> { - let mut res = vec![]; - for _ in 0..coset.log_size() { - res.push( - coset - .iter() - .take(coset.size() / 2) - .map(|p| (p.x.inverse().0 * 2) as i32) - .collect::>(), - ); - bit_reverse(res.last_mut().unwrap()); - coset = coset.double(); - } - - res -} - -/// Applies 3 ibutterfly layers on 8 vectors of 16 M31 elements. -/// Vectorized over the 16 elements of the vectors. -/// Used for radix-8 ifft. -/// Each butterfly layer, has 3 AVX butterflies. -/// Total of 12 AVX butterflies. -/// Parameters: -/// values - Pointer to the entire value array. -/// offset - The offset of the first value in the array. -/// log_step - The log of the distance in the array, in M31 elements, between each pair of -/// values that need to be transformed. For layer i this is i - 4. -/// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of ibutterflies. -/// Each layer has 4/2/1 twiddles. -/// # Safety -pub unsafe fn ifft3( - values: *mut i32, - offset: usize, - log_step: usize, - twiddles_dbl0: [i32; 4], - twiddles_dbl1: [i32; 2], - twiddles_dbl2: [i32; 1], -) { - // Load the 8 AVX vectors from the array. - let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); - let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); - let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); - let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const()); - let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const()); - let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const()); - let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const()); - - // Apply the first layer of ibutterflies. 
- (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - (val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); - (val4, val5) = avx_ibutterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2])); - (val6, val7) = avx_ibutterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3])); - - // Apply the second layer of ibutterflies. - (val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); - (val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); - (val4, val6) = avx_ibutterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1])); - (val5, val7) = avx_ibutterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1])); - - // Apply the third layer of ibutterflies. - (val0, val4) = avx_ibutterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0])); - (val1, val5) = avx_ibutterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0])); - (val2, val6) = avx_ibutterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0])); - (val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0])); - - // Store the 8 AVX vectors back to the array. - val0.store(values.add(offset + (0 << log_step))); - val1.store(values.add(offset + (1 << log_step))); - val2.store(values.add(offset + (2 << log_step))); - val3.store(values.add(offset + (3 << log_step))); - val4.store(values.add(offset + (4 << log_step))); - val5.store(values.add(offset + (5 << log_step))); - val6.store(values.add(offset + (6 << log_step))); - val7.store(values.add(offset + (7 << log_step))); -} - -/// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements. -/// Vectorized over the 16 elements of the vectors. -/// Used for radix-4 ifft. -/// Each ibutterfly layer, has 2 AVX butterflies. -/// Total of 4 AVX butterflies. -/// Parameters: -/// values - Pointer to the entire value array. -/// offset - The offset of the first value in the array. -/// log_step - The log of the distance in the array, in M31 elements, between each pair of -/// values that need to be transformed. For layer i this is i - 4. -/// twiddles_dbl0/1 - The double of the twiddles for the 2 layers of ibutterflies. -/// Each layer has 2/1 twiddles. -/// # Safety -pub unsafe fn ifft2( - values: *mut i32, - offset: usize, - log_step: usize, - twiddles_dbl0: [i32; 2], - twiddles_dbl1: [i32; 1], -) { - // Load the 4 AVX vectors from the array. - let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); - let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); - let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); - - // Apply the first layer of butterflies. - (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - (val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); - - // Apply the second layer of butterflies. - (val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); - (val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); - - // Store the 4 AVX vectors back to the array. - val0.store(values.add(offset + (0 << log_step))); - val1.store(values.add(offset + (1 << log_step))); - val2.store(values.add(offset + (2 << log_step))); - val3.store(values.add(offset + (3 << log_step))); -} - -/// Applies 1 ibutterfly layers on 2 vectors of 16 M31 elements. -/// Vectorized over the 16 elements of the vectors. 
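[Editor's aside] The three ibutterfly layers of `ifft3` above pair the eight vectors at indices that differ in bit k, with twiddle index i >> (k + 1); the `test_ifft3` reference loops further down replay exactly this schedule. A scalar sketch, with u64 values standing in for whole `PackedBaseField` vectors:

```rust
const P: u64 = (1 << 31) - 1;

fn ibutterfly(x: u64, y: u64, t: u64) -> (u64, u64) {
    ((x + y) % P, (x + P - y) % P * t % P)
}

// Layer k (k = 0, 1, 2) pairs indices differing in bit k.
fn ifft3_scalar(v: &mut [u64; 8], t0: [u64; 4], t1: [u64; 2], t2: [u64; 1]) {
    for (k, tw) in [(0_usize, &t0[..]), (1, &t1[..]), (2, &t2[..])] {
        let bit = 1 << k;
        for i in 0..8 {
            if i & bit == 0 {
                let (a, b) = ibutterfly(v[i], v[i | bit], tw[i >> (k + 1)]);
                (v[i], v[i | bit]) = (a, b);
            }
        }
    }
}

fn main() {
    let mut v = [3, 1, 4, 1, 5, 9, 2, 6];
    // Same toy twiddles as the test below.
    ifft3_scalar(&mut v, [32, 33, 34, 35], [36, 37], [38]);
    println!("{v:?}");
}
```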
-/// Parameters: -/// values - Pointer to the entire value array. -/// offset - The offset of the first value in the array. -/// log_step - The log of the distance in the array, in M31 elements, between each pair of -/// values that need to be transformed. For layer i this is i - 4. -/// twiddles_dbl0 - The double of the twiddles for the ibutterfly layer. -/// # Safety -pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) { - // Load the 2 AVX vectors from the array. - let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); - - (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - - // Store the 2 AVX vectors back to the array. - val0.store(values.add(offset + (0 << log_step))); - val1.store(values.add(offset + (1 << log_step))); -} - -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -#[cfg(test)] -mod tests { - use std::arch::x86_64::{_mm512_add_epi32, _mm512_setr_epi32}; - - use super::*; - use crate::core::backend::avx512::m31::PackedBaseField; - use crate::core::backend::avx512::BaseFieldVec; - use crate::core::backend::cpu::CPUCircleEvaluation; - use crate::core::backend::Column; - use crate::core::fft::ibutterfly; - use crate::core::fields::m31::BaseField; - use crate::core::poly::circle::{CanonicCoset, CircleDomain}; - - #[test] - fn test_ibutterfly() { - unsafe { - let val0 = PackedBaseField(_mm512_setr_epi32( - 2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - )); - let val1 = PackedBaseField(_mm512_setr_epi32( - 3, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - )); - let twiddle = _mm512_setr_epi32( - 1177558791, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - ); - let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle); - let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl); - - let val0: [BaseField; 16] = std::mem::transmute(val0); - let val1: [BaseField; 16] = std::mem::transmute(val1); - let twiddle: [BaseField; 16] = std::mem::transmute(twiddle); - let r0: [BaseField; 16] = std::mem::transmute(r0); - let r1: [BaseField; 16] = std::mem::transmute(r1); - - for i in 0..16 { - let mut x = val0[i]; - let mut y = val1[i]; - let twiddle = twiddle[i]; - ibutterfly(&mut x, &mut y, twiddle); - assert_eq!(x, r0[i]); - assert_eq!(y, r1[i]); - } - } - } - - #[test] - fn test_ifft3() { - unsafe { - let mut values: Vec = (0..8) - .map(|i| { - PackedBaseField::from_array(std::array::from_fn(|_| { - BaseField::from_u32_unchecked(i) - })) - }) - .collect(); - let twiddles0 = [32, 33, 34, 35]; - let twiddles1 = [36, 37]; - let twiddles2 = [38]; - let twiddles0_dbl = std::array::from_fn(|i| twiddles0[i] * 2); - let twiddles1_dbl = std::array::from_fn(|i| twiddles1[i] * 2); - let twiddles2_dbl = std::array::from_fn(|i| twiddles2[i] * 2); - ifft3( - std::mem::transmute(values.as_mut_ptr()), - 0, - VECS_LOG_SIZE, - twiddles0_dbl, - twiddles1_dbl, - twiddles2_dbl, - ); - - let expected: [u32; 8] = std::array::from_fn(|i| i as u32); - let mut expected: [BaseField; 8] = std::mem::transmute(expected); - let twiddles0: [BaseField; 4] = std::mem::transmute(twiddles0); - let twiddles1: [BaseField; 2] = std::mem::transmute(twiddles1); - let twiddles2: [BaseField; 1] = std::mem::transmute(twiddles2); - for i in 0..8 { - let j = i ^ 1; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - ibutterfly(&mut v0, &mut v1, twiddles0[i / 2]); - 
(expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - let j = i ^ 2; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - ibutterfly(&mut v0, &mut v1, twiddles1[i / 4]); - (expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - let j = i ^ 4; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - ibutterfly(&mut v0, &mut v1, twiddles2[0]); - (expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - assert_eq!(values[i].to_array()[0], expected[i]); - } - } - } - - fn ref_ifft(domain: CircleDomain, values: Vec) -> Vec { - let eval = CPUCircleEvaluation::new(domain, values); - let mut expected_coeffs = eval.interpolate().coeffs; - for x in expected_coeffs.iter_mut() { - *x *= BaseField::from_u32_unchecked(domain.size() as u32); - } - expected_coeffs - } - - #[test] - fn test_vecwise_ibutterflies() { - let domain = CanonicCoset::new(5).circle_domain(); - let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); - assert_eq!(twiddle_dbls.len(), 4); - let values0: [i32; 16] = std::array::from_fn(|i| i as i32); - let values1: [i32; 16] = std::array::from_fn(|i| (i + 16) as i32); - let result: [BaseField; 32] = unsafe { - let (val0, val1) = vecwise_ibutterflies( - std::mem::transmute(values0), - std::mem::transmute(values1), - twiddle_dbls[0].clone().try_into().unwrap(), - twiddle_dbls[1].clone().try_into().unwrap(), - twiddle_dbls[2].clone().try_into().unwrap(), - ); - let (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddle_dbls[3][0])); - std::mem::transmute([val0, val1]) - }; - - // ref. - let mut values = values0.to_vec(); - values.extend_from_slice(&values1); - let expected = ref_ifft(domain, values.into_iter().map(BaseField::from).collect()); - - // Compare. - for i in 0..32 { - assert_eq!(result[i], expected[i]); - } - } - - #[test] - fn test_ifft_lower_with_vecwise() { - for log_size in 5..12 { - let domain = CanonicCoset::new(log_size).circle_domain(); - let values = (0..domain.size()) - .map(|i| BaseField::from_u32_unchecked(i as u32)) - .collect::>(); - let expected_coeffs = ref_ifft(domain, values.clone()); - - // Compute. - let mut values = BaseFieldVec::from_iter(values); - let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); - - unsafe { - ifft_lower_with_vecwise( - std::mem::transmute(values.data.as_mut_ptr()), - &twiddle_dbls - .iter() - .map(|x| x.as_slice()) - .collect::>(), - log_size as usize, - log_size as usize, - ); - - // Compare. - assert_eq!(values.to_cpu(), expected_coeffs); - } - } - } - - fn run_ifft_full_test(log_size: u32) { - let domain = CanonicCoset::new(log_size).circle_domain(); - let values = (0..domain.size()) - .map(|i| BaseField::from_u32_unchecked(i as u32)) - .collect::>(); - let expected_coeffs = ref_ifft(domain, values.clone()); - - // Compute. - let mut values = BaseFieldVec::from_iter(values); - let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); - - unsafe { - ifft( - std::mem::transmute(values.data.as_mut_ptr()), - &twiddle_dbls - .iter() - .map(|x| x.as_slice()) - .collect::>(), - log_size as usize, - ); - transpose_vecs( - std::mem::transmute(values.data.as_mut_ptr()), - (log_size - 4) as usize, - ); - - // Compare. 
-            assert_eq!(values.to_cpu(), expected_coeffs);
-        }
-    }
-
-    #[test]
-    fn test_ifft_full() {
-        for i in (CACHED_FFT_LOG_SIZE + 1)..(CACHED_FFT_LOG_SIZE + 3) {
-            run_ifft_full_test(i as u32);
-        }
-    }
-}
diff --git a/crates/prover/src/core/backend/avx512/fft/mod.rs b/crates/prover/src/core/backend/avx512/fft/mod.rs
deleted file mode 100644
index 37dbbd686..000000000
--- a/crates/prover/src/core/backend/avx512/fft/mod.rs
+++ /dev/null
@@ -1,100 +0,0 @@
-use std::arch::x86_64::{
-    __m512i, _mm512_broadcast_i64x4, _mm512_load_epi32, _mm512_permutexvar_epi32,
-    _mm512_store_epi32, _mm512_xor_epi32,
-};
-
-pub mod ifft;
-pub mod rfft;
-
-/// An input to _mm512_permutex2var_epi32, used to interleave the even words of a
-/// with the even words of b.
-const EVENS_INTERLEAVE_EVENS: __m512i = unsafe {
-    core::mem::transmute([
-        0b00000, 0b10000, 0b00010, 0b10010, 0b00100, 0b10100, 0b00110, 0b10110, 0b01000, 0b11000,
-        0b01010, 0b11010, 0b01100, 0b11100, 0b01110, 0b11110,
-    ])
-};
-/// An input to _mm512_permutex2var_epi32, used to interleave the odd words of a
-/// with the odd words of b.
-const ODDS_INTERLEAVE_ODDS: __m512i = unsafe {
-    core::mem::transmute([
-        0b00001, 0b10001, 0b00011, 0b10011, 0b00101, 0b10101, 0b00111, 0b10111, 0b01001, 0b11001,
-        0b01011, 0b11011, 0b01101, 0b11101, 0b01111, 0b11111,
-    ])
-};
-
-pub const CACHED_FFT_LOG_SIZE: usize = 16;
-pub const MIN_FFT_LOG_SIZE: usize = 5;
-
-// TODO(spapini): FFTs return a redundant representation, that can get the value P. Need to reduce
-// it somewhere.
-
-/// Transposes the AVX vectors in the given array.
-/// Swaps the bit index abc <-> cba, where |a|=|c| and |b| = 0 or 1, according to the parity of
-/// `log_n_vecs`.
-/// When `log_n_vecs` is odd, `b` is a single middle bit that stays in place.
-///
-/// # Safety
-/// This function is unsafe because it takes a raw pointer to i32 values.
-/// `values` must be aligned to 64 bytes.
-///
-/// # Arguments
-/// * `values`: A mutable pointer to the values that are to be transposed.
-/// * `log_n_vecs`: The log of the number of AVX vectors in the `values` array.
-pub unsafe fn transpose_vecs(values: *mut i32, log_n_vecs: usize) {
-    let half = log_n_vecs / 2;
-    for b in 0..(1 << (log_n_vecs & 1)) {
-        for a in 0..(1 << half) {
-            for c in 0..(1 << half) {
-                let i = (a << (log_n_vecs - half)) | (b << half) | c;
-                let j = (c << (log_n_vecs - half)) | (b << half) | a;
-                if i >= j {
-                    continue;
-                }
-                let val0 = _mm512_load_epi32(values.add(i << 4).cast_const());
-                let val1 = _mm512_load_epi32(values.add(j << 4).cast_const());
-                _mm512_store_epi32(values.add(i << 4), val1);
-                _mm512_store_epi32(values.add(j << 4), val0);
-            }
-        }
-    }
-}
-
-/// Computes the twiddles for the first fft layer from the second, and loads both to AVX registers.
-/// Returns the twiddles for the first layer and the twiddles for the second layer.
-/// # Safety
-pub unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (__m512i, __m512i) {
-    // Start by loading the twiddles for the second layer (layer 1):
-    // The twiddles for layer 1 are replicated in the following pattern:
-    //   0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
-    let t1 = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl));
-
-    // The twiddles for layer 0 can be computed from the twiddles for layer 1.
-    // Since the twiddles are bit reversed, we consider the circle domain in bit reversed order.
-    // Each consecutive 4 points in the bit reversed order of a coset form a circle coset of size 4.
-    // A circle coset of size 4 in bit reversed order looks like this:
-    //   [(x, y), (-x, -y), (y, -x), (-y, x)]
-    // Note: This is related to the choice of M31_CIRCLE_GEN, and the fact that a quarter rotation
-    // is (0,-1) and not (0,1). (0,1) would yield another relation.
-    // The twiddles for layer 0 are the y coordinates:
-    //   [y, -y, -x, x]
-    // The twiddles for layer 1 in bit reversed order are the x coordinates:
-    //   [x, y]
-    // This also works for the inverses of the twiddles.
-
-    // The twiddles for layer 0 are computed like this:
-    //   t0[4i:4i+3] = [t1[2i+1], -t1[2i+1], -t1[2i], t1[2i]]
-    const INDICES_FROM_T1: __m512i = unsafe {
-        core::mem::transmute([
-            0b0001, 0b0001, 0b0000, 0b0000, 0b0011, 0b0011, 0b0010, 0b0010, 0b0101, 0b0101, 0b0100,
-            0b0100, 0b0111, 0b0111, 0b0110, 0b0110,
-        ])
-    };
-    // Xoring a double twiddle with 2^32-2 transforms it to the double of its negation.
-    // Note that this keeps the values as a double of a value in the range [0, P].
-    const NEGATION_MASK: __m512i = unsafe {
-        core::mem::transmute([0i32, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0])
-    };
-    let t0 = _mm512_xor_epi32(_mm512_permutexvar_epi32(INDICES_FROM_T1, t1), NEGATION_MASK);
-    (t0, t1)
-}
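[Editor's aside] The `NEGATION_MASK` trick above works because a doubled twiddle 2t is even and at most 2P = 2^32 - 2: subtracting it from 0xfffffffe never borrows, so 2P - 2t = 2(P - t) collapses to a XOR. A one-lane check (plain u32, not the packed constants):

```rust
fn main() {
    const P: u32 = (1 << 31) - 1;
    let t: u32 = 1_177_558_791; // an arbitrary twiddle in [0, P)
    let t_dbl = 2 * t;
    // XOR with 2^32 - 2 turns the double of t into the double of P - t.
    assert_eq!(t_dbl ^ 0xffff_fffe, 2 * (P - t));
    // -2i32 reinterpreted as u32 is the same mask value used in NEGATION_MASK.
    assert_eq!(-2_i32 as u32, 0xffff_fffe);
    println!("ok");
}
```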
diff --git a/crates/prover/src/core/backend/avx512/fft/rfft.rs b/crates/prover/src/core/backend/avx512/fft/rfft.rs
deleted file mode 100644
index 43fbf7f32..000000000
--- a/crates/prover/src/core/backend/avx512/fft/rfft.rs
+++ /dev/null
@@ -1,757 +0,0 @@
-//! Regular (forward) fft.
-
-use std::arch::x86_64::{
-    __m512i, _mm512_broadcast_i32x4, _mm512_mul_epu32, _mm512_permutex2var_epi32,
-    _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
-};
-
-use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS};
-use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
-use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE};
-use crate::core::circle::Coset;
-use crate::core::utils::bit_reverse;
-
-/// Performs a (forward) Circle Fast Fourier Transform (CFFT) on the given values.
-///
-/// # Safety
-/// This function is unsafe because it takes raw pointers to i32 values.
-/// `src` and `dst` must be aligned to 64 bytes.
-///
-/// # Arguments
-/// * `src`: A pointer to the values to transform.
-/// * `dst`: A pointer to the destination array.
-/// * `twiddle_dbl`: A reference to the doubles of the twiddle factors.
-/// * `log_n_elements`: The log of the number of elements in the `src` array.
-///
-/// # Panics
-/// This function will panic if `log_n_elements` is less than `MIN_FFT_LOG_SIZE`.
-pub unsafe fn fft(src: *const i32, dst: *mut i32, twiddle_dbl: &[&[i32]], log_n_elements: usize) {
-    assert!(log_n_elements >= MIN_FFT_LOG_SIZE);
-    let log_n_vecs = log_n_elements - VECS_LOG_SIZE;
-    if log_n_elements <= CACHED_FFT_LOG_SIZE {
-        fft_lower_with_vecwise(src, dst, twiddle_dbl, log_n_elements, log_n_elements);
-        return;
-    }
-
-    let fft_layers_pre_transpose = log_n_vecs.div_ceil(2);
-    let fft_layers_post_transpose = log_n_vecs / 2;
-    fft_lower_without_vecwise(
-        src,
-        dst,
-        &twiddle_dbl[(3 + fft_layers_pre_transpose)..],
-        log_n_elements,
-        fft_layers_post_transpose,
-    );
-    transpose_vecs(dst, log_n_vecs);
-    fft_lower_with_vecwise(
-        dst,
-        dst,
-        &twiddle_dbl[..(3 + fft_layers_pre_transpose)],
-        log_n_elements,
-        fft_layers_pre_transpose + VECS_LOG_SIZE,
-    );
-}
-
-/// Computes partial fft on `2^log_size` M31 elements.
-/// Parameters:
-/// src - A pointer to the values to transform, aligned to 64 bytes.
-/// dst - A pointer to the destination array, aligned to 64 bytes.
-/// twiddle_dbl - The doubles of the twiddle factors for each layer of the fft.
-///   layer i holds 2^(log_size - 1 - i) twiddles.
-/// log_size - The log of the number of M31 elements in the array.
-/// fft_layers - The number of fft layers to apply, out of log_size.
-/// # Safety
-/// `src` and `dst` must be aligned to 64 bytes.
-/// `log_size` must be at least 5.
-/// `fft_layers` must be at least 5.
-pub unsafe fn fft_lower_with_vecwise(
-    src: *const i32,
-    dst: *mut i32,
-    twiddle_dbl: &[&[i32]],
-    log_size: usize,
-    fft_layers: usize,
-) {
-    const VECWISE_FFT_BITS: usize = VECS_LOG_SIZE + 1;
-    assert!(log_size >= VECWISE_FFT_BITS);
-
-    assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2));
-
-    for index_h in 0..(1 << (log_size - fft_layers)) {
-        let mut src = src;
-        for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() {
-            match fft_layers - layer {
-                1 => {
-                    fft1_loop(src, dst, &twiddle_dbl[(layer - 1)..], layer, index_h);
-                }
-                2 => {
-                    fft2_loop(src, dst, &twiddle_dbl[(layer - 1)..], layer, index_h);
-                }
-                _ => {
-                    fft3_loop(
-                        src,
-                        dst,
-                        &twiddle_dbl[(layer - 1)..],
-                        fft_layers - layer - 3,
-                        layer,
-                        index_h,
-                    );
-                }
-            }
-            src = dst;
-        }
-        fft_vecwise_loop(
-            src,
-            dst,
-            twiddle_dbl,
-            fft_layers - VECWISE_FFT_BITS,
-            index_h,
-        );
-    }
-}
-
-/// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits
-/// of the index).
-/// Parameters:
-/// src - A pointer to the values to transform, aligned to 64 bytes.
-/// dst - A pointer to the destination array, aligned to 64 bytes.
-/// twiddle_dbl - The doubles of the twiddle factors for each layer of the fft.
-/// log_size - The log of the number of M31 elements in the array.
-/// fft_layers - The number of fft layers to apply, out of log_size - VEC_LOG_SIZE.
-///
-/// # Safety
-/// `src` and `dst` must be aligned to 64 bytes.
-/// `log_size` must be at least 4.
-/// `fft_layers` must be at least 4.
-pub unsafe fn fft_lower_without_vecwise(
-    src: *const i32,
-    dst: *mut i32,
-    twiddle_dbl: &[&[i32]],
-    log_size: usize,
-    fft_layers: usize,
-) {
-    assert!(log_size >= VECS_LOG_SIZE);
-
-    for index_h in 0..(1 << (log_size - fft_layers - VECS_LOG_SIZE)) {
-        let mut src = src;
-        for layer in (0..fft_layers).step_by(3).rev() {
-            let fixed_layer = layer + VECS_LOG_SIZE;
-            match fft_layers - layer {
-                1 => {
-                    fft1_loop(src, dst, &twiddle_dbl[layer..], fixed_layer, index_h);
-                }
-                2 => {
-                    fft2_loop(src, dst, &twiddle_dbl[layer..], fixed_layer, index_h);
-                }
-                _ => {
-                    fft3_loop(
-                        src,
-                        dst,
-                        &twiddle_dbl[layer..],
-                        fft_layers - layer - 3,
-                        fixed_layer,
-                        index_h,
-                    );
-                }
-            }
-            src = dst;
-        }
-    }
-}
-
-/// Runs the last 5 fft layers across the entire array.
-/// Parameters:
-/// src - A pointer to the values to transform, aligned to 64 bytes.
-/// dst - A pointer to the destination array, aligned to 64 bytes.
-/// twiddle_dbl - The doubles of the twiddle factors for each of the 5 fft layers.
-/// loop_bits - The number of bits this loop needs to run on.
-/// index_h - The higher part of the index, iterated by the caller.
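[Editor's aside] Two easy-to-miss details in the forward drivers above: the chunk loop runs high-to-low (`.step_by(3).rev()`), the reverse of the ifft order, and `src` is rebound to `dst` after the first chunk, so the transform is out-of-place only on its first pass. A tiny illustration of the loop order (toy code):

```rust
fn main() {
    const VECWISE_FFT_BITS: usize = 5;
    let fft_layers = 12;
    // The forward fft visits the 3-layer chunks high-to-low...
    let fwd: Vec<usize> = (VECWISE_FFT_BITS..fft_layers).step_by(3).rev().collect();
    assert_eq!(fwd, [11, 8, 5]);
    // ...while the ifft drivers in ifft.rs visit them low-to-high.
    let inv: Vec<usize> = (VECWISE_FFT_BITS..fft_layers).step_by(3).collect();
    assert_eq!(inv, [5, 8, 11]);
    println!("ok");
}
```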
-/// # Safety -unsafe fn fft_vecwise_loop( - src: *const i32, - dst: *mut i32, - twiddle_dbl: &[&[i32]], - loop_bits: usize, - index_h: usize, -) { - for index_l in 0..(1 << loop_bits) { - let index = (index_h << loop_bits) + index_l; - let mut val0 = PackedBaseField::load(src.add(index * 32)); - let mut val1 = PackedBaseField::load(src.add(index * 32 + 16)); - (val0, val1) = avx_butterfly( - val0, - val1, - _mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)), - ); - (val0, val1) = vecwise_butterflies( - val0, - val1, - std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)), - std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), - std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), - ); - val0.store(dst.add(index * 32)); - val1.store(dst.add(index * 32 + 16)); - } -} - -/// Runs 3 fft layers across the entire array. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// twiddle_dbl - The doubles of the twiddle factors for each of the 3 fft layers. -/// loop_bits - The number of bits this loops needs to run on. -/// layer - The layer number of the first fft layer to apply. -/// The layers `layer`, `layer + 1`, `layer + 2` are applied. -/// index_h - The higher part of the index, iterated by the caller. -/// # Safety -unsafe fn fft3_loop( - src: *const i32, - dst: *mut i32, - twiddle_dbl: &[&[i32]], - loop_bits: usize, - layer: usize, - index_h: usize, -) { - for index_l in 0..(1 << loop_bits) { - let index = (index_h << loop_bits) + index_l; - let offset = index << (layer + 3); - for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { - fft3( - src, - dst, - offset + l, - layer, - std::array::from_fn(|i| { - *twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1)) - }), - std::array::from_fn(|i| { - *twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1)) - }), - std::array::from_fn(|i| { - *twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1)) - }), - ); - } - } -} - -/// Runs 2 fft layers across the entire array. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// twiddle_dbl - The doubles of the twiddle factors for each of the 2 fft layers. -/// loop_bits - The number of bits this loops needs to run on. -/// layer - The layer number of the first fft layer to apply. -/// The layers `layer`, `layer + 1` are applied. -/// index - The index, iterated by the caller. -/// # Safety -unsafe fn fft2_loop( - src: *const i32, - dst: *mut i32, - twiddle_dbl: &[&[i32]], - layer: usize, - index: usize, -) { - let offset = index << (layer + 2); - for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { - fft2( - src, - dst, - offset + l, - layer, - std::array::from_fn(|i| { - *twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1)) - }), - std::array::from_fn(|i| { - *twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1)) - }), - ); - } -} - -/// Runs 1 fft layer across the entire array. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// twiddle_dbl - The doubles of the twiddle factors for the fft layer. -/// layer - The layer number of the fft layer to apply. -/// index_h - The higher part of the index, iterated by the caller. 
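[Editor's aside] The forward butterfly defined just below maps (x, y) to (x + t*y, x - t*y), while the ibutterfly in ifft.rs maps (x, y) to (x + y, (x - y)*t). Composing one with the other, using the inverse twiddle, returns (2x, 2y); this factor of 2 per layer is why `ref_ifft` in the ifft tests rescales by the domain size. A scalar round-trip check (standalone sketch, illustrative names):

```rust
const P: u64 = (1 << 31) - 1;

fn butterfly(x: u64, y: u64, t: u64) -> (u64, u64) {
    ((x + t * y % P) % P, (x + P - t * y % P) % P)
}

fn ibutterfly(x: u64, y: u64, t: u64) -> (u64, u64) {
    ((x + y) % P, (x + P - y) % P * t % P)
}

// x^(P-2) = x^-1 mod P (Fermat).
fn inv(x: u64) -> u64 {
    let (mut base, mut exp, mut acc) = (x, P - 2, 1);
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % P;
        }
        base = base * base % P;
        exp >>= 1;
    }
    acc
}

fn main() {
    let (x, y, t) = (123, 456, 789);
    let (u, v) = butterfly(x, y, t);
    // One butterfly followed by an ibutterfly with 1/t doubles each value.
    assert_eq!(ibutterfly(u, v, inv(t)), (2 * x, 2 * y));
    println!("ok");
}
```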
-/// # Safety
-unsafe fn fft1_loop(
-    src: *const i32,
-    dst: *mut i32,
-    twiddle_dbl: &[&[i32]],
-    layer: usize,
-    index: usize,
-) {
-    let offset = index << (layer + 1);
-    for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) {
-        fft1(
-            src,
-            dst,
-            offset + l,
-            layer,
-            std::array::from_fn(|i| {
-                *twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1))
-            }),
-        );
-    }
-}
-
-/// Computes the butterfly operation for packed M31 elements.
-/// Returns (val0 + t * val1, val0 - t * val1).
-/// val0, val1 are packed M31 elements. 16 M31 words each.
-/// Each value is assumed to be in unreduced form, [0, P] including P.
-/// Returned values are in unreduced form, [0, P] including P.
-/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in unreduced form.
-/// # Safety
-/// This function is safe.
-pub unsafe fn avx_butterfly(
-    val0: PackedBaseField,
-    val1: PackedBaseField,
-    twiddle_dbl: __m512i,
-) -> (PackedBaseField, PackedBaseField) {
-    // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of val1.
-    let val1_e = val1.0;
-    // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of val1.
-    let val1_o = _mm512_srli_epi64(val1.0, 32);
-    let twiddle_dbl_e = twiddle_dbl;
-    let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32);
-
-    // To compute prod = val1 * twiddle, start by multiplying
-    // val1_e/o by twiddle_dbl_e/o.
-    let prod_e_dbl = _mm512_mul_epu32(val1_e, twiddle_dbl_e);
-    let prod_o_dbl = _mm512_mul_epu32(val1_o, twiddle_dbl_o);
-
-    // The result of the multiplication holds val1*twiddle_dbl as 64-bit values.
-    // Each 64-bit word looks like this:
-    //               1    31       31    1
-    // prod_e_dbl - |0|prod_e_h|prod_e_l|0|
-    // prod_o_dbl - |0|prod_o_h|prod_o_l|0|
-
-    // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl:
-    let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl);
-    // prod_ls -    |prod_o_l|0|prod_e_l|0|
-
-    // Divide by 2:
-    let prod_ls = _mm512_srli_epi64(prod_ls, 1);
-    // prod_ls -    |0|prod_o_l|0|prod_e_l|
-
-    // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl:
-    let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl);
-    // prod_hs -    |0|prod_o_h|0|prod_e_h|
-
-    let prod = PackedBaseField(prod_ls) + PackedBaseField(prod_hs);
-
-    let r0 = val0 + prod;
-    let r1 = val0 - prod;
-
-    (r0, r1)
-}
-
-/// Runs fft on 2 vectors of 16 M31 elements.
-/// This amounts to 4 butterfly layers, each with 16 butterflies.
-/// Each of the vectors represents naturally ordered polynomial coefficients.
-/// Each value in the vectors is in unreduced form: [0, P] including P.
-/// Takes 4 twiddle arrays, one for each layer, holding the double of the corresponding twiddle.
-/// The first layer (higher bit of the index) takes 2 twiddles.
-/// The second layer takes 4 twiddles.
-/// etc.
-/// # Safety
-pub unsafe fn vecwise_butterflies(
-    mut val0: PackedBaseField,
-    mut val1: PackedBaseField,
-    twiddle1_dbl: [i32; 8],
-    twiddle2_dbl: [i32; 4],
-    twiddle3_dbl: [i32; 2],
-) -> (PackedBaseField, PackedBaseField) {
-    // TODO(spapini): Compute twiddle0 from twiddle1.
-    // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly.
-    // The implementation is the exact reverse of vecwise_ibutterflies().
-    // See the comments in its body for more info.
- let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl)); - (val0, val1) = val0.interleave_with(val1); - (val0, val1) = avx_butterfly(val0, val1, t); - - let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl)); - (val0, val1) = val0.interleave_with(val1); - (val0, val1) = avx_butterfly(val0, val1, t); - - let (t0, t1) = compute_first_twiddles(twiddle1_dbl); - (val0, val1) = val0.interleave_with(val1); - (val0, val1) = avx_butterfly(val0, val1, t1); - - (val0, val1) = val0.interleave_with(val1); - (val0, val1) = avx_butterfly(val0, val1, t0); - - val0.interleave_with(val1) -} - -/// Returns the line twiddles (x points) for an fft on a coset. -pub fn get_twiddle_dbls(mut coset: Coset) -> Vec> { - let mut res = vec![]; - for _ in 0..coset.log_size() { - res.push( - coset - .iter() - .take(coset.size() / 2) - .map(|p| (p.x.0 * 2) as i32) - .collect::>(), - ); - bit_reverse(res.last_mut().unwrap()); - coset = coset.double(); - } - - res -} - -/// Applies 3 butterfly layers on 8 vectors of 16 M31 elements. -/// Vectorized over the 16 elements of the vectors. -/// Used for radix-8 ifft. -/// Each butterfly layer, has 3 AVX butterflies. -/// Total of 12 AVX butterflies. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// offset - The offset of the first value in the array. -/// log_step - The log of the distance in the array, in M31 elements, between each pair of -/// values that need to be transformed. For layer i this is i - 4. -/// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of butterflies. -/// Each layer has 4/2/1 twiddles. -/// # Safety -pub unsafe fn fft3( - src: *const i32, - dst: *mut i32, - offset: usize, - log_step: usize, - twiddles_dbl0: [i32; 4], - twiddles_dbl1: [i32; 2], - twiddles_dbl2: [i32; 1], -) { - // Load the 8 AVX vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); - let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step))); - let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step))); - let mut val4 = PackedBaseField::load(src.add(offset + (4 << log_step))); - let mut val5 = PackedBaseField::load(src.add(offset + (5 << log_step))); - let mut val6 = PackedBaseField::load(src.add(offset + (6 << log_step))); - let mut val7 = PackedBaseField::load(src.add(offset + (7 << log_step))); - - // Apply the third layer of butterflies. - (val0, val4) = avx_butterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0])); - (val1, val5) = avx_butterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0])); - (val2, val6) = avx_butterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0])); - (val3, val7) = avx_butterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0])); - - // Apply the second layer of butterflies. - (val0, val2) = avx_butterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); - (val1, val3) = avx_butterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); - (val4, val6) = avx_butterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1])); - (val5, val7) = avx_butterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1])); - - // Apply the first layer of butterflies. 
- (val0, val1) = avx_butterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - (val2, val3) = avx_butterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); - (val4, val5) = avx_butterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2])); - (val6, val7) = avx_butterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3])); - - // Store the 8 AVX vectors back to the array. - val0.store(dst.add(offset + (0 << log_step))); - val1.store(dst.add(offset + (1 << log_step))); - val2.store(dst.add(offset + (2 << log_step))); - val3.store(dst.add(offset + (3 << log_step))); - val4.store(dst.add(offset + (4 << log_step))); - val5.store(dst.add(offset + (5 << log_step))); - val6.store(dst.add(offset + (6 << log_step))); - val7.store(dst.add(offset + (7 << log_step))); -} - -/// Applies 2 butterfly layers on 4 vectors of 16 M31 elements. -/// Vectorized over the 16 elements of the vectors. -/// Used for radix-4 fft. -/// Each butterfly layer, has 2 AVX butterflies. -/// Total of 4 AVX butterflies. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// offset - The offset of the first value in the array. -/// log_step - The log of the distance in the array, in M31 elements, between each pair of -/// values that need to be transformed. For layer i this is i - 4. -/// twiddles_dbl0/1 - The double of the twiddles for the 2 layers of butterflies. -/// Each layer has 2/1 twiddles. -/// # Safety -pub unsafe fn fft2( - src: *const i32, - dst: *mut i32, - offset: usize, - log_step: usize, - twiddles_dbl0: [i32; 2], - twiddles_dbl1: [i32; 1], -) { - // Load the 4 AVX vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); - let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step))); - let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step))); - - // Apply the second layer of butterflies. - (val0, val2) = avx_butterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); - (val1, val3) = avx_butterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); - - // Apply the first layer of butterflies. - (val0, val1) = avx_butterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - (val2, val3) = avx_butterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); - - // Store the 4 AVX vectors back to the array. - val0.store(dst.add(offset + (0 << log_step))); - val1.store(dst.add(offset + (1 << log_step))); - val2.store(dst.add(offset + (2 << log_step))); - val3.store(dst.add(offset + (3 << log_step))); -} - -/// Applies 1 butterfly layers on 2 vectors of 16 M31 elements. -/// Vectorized over the 16 elements of the vectors. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// offset - The offset of the first value in the array. -/// log_step - The log of the distance in the array, in M31 elements, between each pair of -/// values that need to be transformed. For layer i this is i - 4. -/// twiddles_dbl0 - The double of the twiddles for the butterfly layer. -/// # Safety -pub unsafe fn fft1( - src: *const i32, - dst: *mut i32, - offset: usize, - log_step: usize, - twiddles_dbl0: [i32; 1], -) { - // Load the 2 AVX vectors from the array. 
- let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); - - (val0, val1) = avx_butterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - - // Store the 2 AVX vectors back to the array. - val0.store(dst.add(offset + (0 << log_step))); - val1.store(dst.add(offset + (1 << log_step))); -} - -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -#[cfg(test)] -mod tests { - use std::arch::x86_64::{_mm512_add_epi32, _mm512_set1_epi32, _mm512_setr_epi32}; - - use super::*; - use crate::core::backend::avx512::{BaseFieldVec, PackedBaseField}; - use crate::core::backend::cpu::CPUCirclePoly; - use crate::core::backend::Column; - use crate::core::fft::butterfly; - use crate::core::fields::m31::BaseField; - use crate::core::poly::circle::{CanonicCoset, CircleDomain}; - - #[test] - fn test_butterfly() { - unsafe { - let val0 = PackedBaseField(_mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - )); - let val1 = PackedBaseField(_mm512_setr_epi32( - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - )); - let twiddle = _mm512_setr_epi32( - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - ); - let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle); - let (r0, r1) = avx_butterfly(val0, val1, twiddle_dbl); - - let val0: [BaseField; 16] = std::mem::transmute(val0); - let val1: [BaseField; 16] = std::mem::transmute(val1); - let twiddle: [BaseField; 16] = std::mem::transmute(twiddle); - let r0: [BaseField; 16] = std::mem::transmute(r0); - let r1: [BaseField; 16] = std::mem::transmute(r1); - - for i in 0..16 { - let mut x = val0[i]; - let mut y = val1[i]; - let twiddle = twiddle[i]; - butterfly(&mut x, &mut y, twiddle); - assert_eq!(x, r0[i]); - assert_eq!(y, r1[i]); - } - } - } - - #[test] - fn test_fft3() { - unsafe { - let mut values: Vec = (0..8) - .map(|i| { - PackedBaseField::from_array(std::array::from_fn(|_| { - BaseField::from_u32_unchecked(i) - })) - }) - .collect(); - let twiddles0 = [32, 33, 34, 35]; - let twiddles1 = [36, 37]; - let twiddles2 = [38]; - let twiddles0_dbl = std::array::from_fn(|i| twiddles0[i] * 2); - let twiddles1_dbl = std::array::from_fn(|i| twiddles1[i] * 2); - let twiddles2_dbl = std::array::from_fn(|i| twiddles2[i] * 2); - fft3( - std::mem::transmute(values.as_ptr()), - std::mem::transmute(values.as_mut_ptr()), - 0, - VECS_LOG_SIZE, - twiddles0_dbl, - twiddles1_dbl, - twiddles2_dbl, - ); - - let expected: [u32; 8] = std::array::from_fn(|i| i as u32); - let mut expected: [BaseField; 8] = std::mem::transmute(expected); - let twiddles0: [BaseField; 4] = std::mem::transmute(twiddles0); - let twiddles1: [BaseField; 2] = std::mem::transmute(twiddles1); - let twiddles2: [BaseField; 1] = std::mem::transmute(twiddles2); - for i in 0..8 { - let j = i ^ 4; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - butterfly(&mut v0, &mut v1, twiddles2[0]); - (expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - let j = i ^ 2; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - butterfly(&mut v0, &mut v1, twiddles1[i / 4]); - (expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - let j = i ^ 1; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - butterfly(&mut v0, &mut v1, twiddles0[i / 2]); - (expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - assert_eq!(values[i].to_array()[0], expected[i]); - } - } - } - - fn 
ref_fft(domain: CircleDomain, values: Vec) -> Vec { - let poly = CPUCirclePoly::new(values); - poly.evaluate(domain).values - } - - #[test] - fn test_vecwise_butterflies() { - let domain = CanonicCoset::new(5).circle_domain(); - let twiddle_dbls = get_twiddle_dbls(domain.half_coset); - assert_eq!(twiddle_dbls.len(), 4); - let values0: [i32; 16] = std::array::from_fn(|i| i as i32); - let values1: [i32; 16] = std::array::from_fn(|i| (i + 16) as i32); - let result: [BaseField; 32] = unsafe { - let (val0, val1) = avx_butterfly( - std::mem::transmute(values0), - std::mem::transmute(values1), - _mm512_set1_epi32(twiddle_dbls[3][0]), - ); - let (val0, val1) = vecwise_butterflies( - val0, - val1, - twiddle_dbls[0].clone().try_into().unwrap(), - twiddle_dbls[1].clone().try_into().unwrap(), - twiddle_dbls[2].clone().try_into().unwrap(), - ); - std::mem::transmute([val0, val1]) - }; - - // ref. - let mut values = values0.to_vec(); - values.extend_from_slice(&values1); - let expected = ref_fft(domain, values.into_iter().map(BaseField::from).collect()); - - // Compare. - for i in 0..32 { - assert_eq!(result[i], expected[i]); - } - } - - #[test] - fn test_fft_lower() { - for log_size in 5..12 { - let domain = CanonicCoset::new(log_size).circle_domain(); - let values = (0..domain.size()) - .map(|i| BaseField::from_u32_unchecked(i as u32)) - .collect::>(); - let expected_coeffs = ref_fft(domain, values.clone()); - - // Compute. - let mut values = BaseFieldVec::from_iter(values); - let twiddle_dbls = get_twiddle_dbls(domain.half_coset); - - unsafe { - fft_lower_with_vecwise( - std::mem::transmute(values.data.as_ptr()), - std::mem::transmute(values.data.as_mut_ptr()), - &twiddle_dbls - .iter() - .map(|x| x.as_slice()) - .collect::>(), - log_size as usize, - log_size as usize, - ); - - // Compare. - assert_eq!(values.to_cpu(), expected_coeffs); - } - } - } - - fn run_fft_full_test(log_size: u32) { - let domain = CanonicCoset::new(log_size).circle_domain(); - let values = (0..domain.size()) - .map(|i| BaseField::from_u32_unchecked(i as u32)) - .collect::>(); - let expected_coeffs = ref_fft(domain, values.clone()); - - // Compute. - let mut values = BaseFieldVec::from_iter(values); - let twiddle_dbls = get_twiddle_dbls(domain.half_coset); - - unsafe { - transpose_vecs( - std::mem::transmute(values.data.as_mut_ptr()), - (log_size - 4) as usize, - ); - fft( - std::mem::transmute(values.data.as_ptr()), - std::mem::transmute(values.data.as_mut_ptr()), - &twiddle_dbls - .iter() - .map(|x| x.as_slice()) - .collect::>(), - log_size as usize, - ); - - // Compare. 
- assert_eq!(values.to_cpu(), expected_coeffs); - } - } - - #[test] - fn test_fft_full() { - for i in (CACHED_FFT_LOG_SIZE + 1)..(CACHED_FFT_LOG_SIZE + 3) { - run_fft_full_test(i as u32); - } - } -} diff --git a/crates/prover/src/core/backend/avx512/fri.rs b/crates/prover/src/core/backend/avx512/fri.rs deleted file mode 100644 index 4b5cec68c..000000000 --- a/crates/prover/src/core/backend/avx512/fri.rs +++ /dev/null @@ -1,265 +0,0 @@ -use num_traits::Zero; - -use super::{AVX512Backend, PackedBaseField, K_BLOCK_SIZE}; -use crate::core::backend::avx512::fft::compute_first_twiddles; -use crate::core::backend::avx512::fft::ifft::avx_ibutterfly; -use crate::core::backend::avx512::qm31::PackedSecureField; -use crate::core::backend::avx512::VECS_LOG_SIZE; -use crate::core::backend::Column; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SecureColumn; -use crate::core::fri::{self, FriOps}; -use crate::core::poly::circle::SecureEvaluation; -use crate::core::poly::line::LineEvaluation; -use crate::core::poly::twiddles::TwiddleTree; -use crate::core::poly::utils::domain_line_twiddles_from_tree; - -impl FriOps for AVX512Backend { - fn fold_line( - eval: &LineEvaluation, - alpha: SecureField, - twiddles: &TwiddleTree, - ) -> LineEvaluation { - let log_size = eval.len().ilog2(); - if log_size <= VECS_LOG_SIZE as u32 { - let eval = fri::fold_line(&eval.to_cpu(), alpha); - return LineEvaluation::new(eval.domain(), eval.values.into_iter().collect()); - } - - let domain = eval.domain(); - let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0]; - - let mut folded_values = SecureColumn::::zeros(1 << (log_size - 1)); - - for vec_index in 0..(1 << (log_size - 1 - VECS_LOG_SIZE as u32)) { - let value = unsafe { - let twiddle_dbl: [i32; 16] = - std::array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i)); - let val0 = eval.values.packed_at(vec_index * 2).to_packed_m31s(); - let val1 = eval.values.packed_at(vec_index * 2 + 1).to_packed_m31s(); - let pairs: [_; 4] = std::array::from_fn(|i| { - let (a, b) = val0[i].deinterleave_with(val1[i]); - avx_ibutterfly(a, b, std::mem::transmute(twiddle_dbl)) - }); - let val0 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].0)); - let val1 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].1)); - val0 + PackedSecureField::broadcast(alpha) * val1 - }; - unsafe { folded_values.set_packed(vec_index, value) }; - } - - LineEvaluation::new(domain.double(), folded_values) - } - - fn fold_circle_into_line( - dst: &mut LineEvaluation, - src: &SecureEvaluation, - alpha: SecureField, - twiddles: &TwiddleTree, - ) { - let log_size = src.len().ilog2(); - assert!(log_size > VECS_LOG_SIZE as u32, "Evaluation too small"); - - let domain = src.domain; - let alpha_sq = alpha * alpha; - let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0]; - - for vec_index in 0..(1 << (log_size - 1 - VECS_LOG_SIZE as u32)) { - let value = unsafe { - // The 16 twiddles of the circle domain can be derived from the 8 twiddles of the - // next line domain. See `compute_first_twiddles()`. 
- let twiddle_dbl: [i32; 8] = - std::array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 8 + i)); - let (t0, _) = compute_first_twiddles(twiddle_dbl); - let val0 = src.values.packed_at(vec_index * 2).to_packed_m31s(); - let val1 = src.values.packed_at(vec_index * 2 + 1).to_packed_m31s(); - let pairs: [_; 4] = std::array::from_fn(|i| { - let (a, b) = val0[i].deinterleave_with(val1[i]); - avx_ibutterfly(a, b, t0) - }); - let val0 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].0)); - let val1 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].1)); - val0 + PackedSecureField::broadcast(alpha) * val1 - }; - unsafe { - dst.values.set_packed( - vec_index, - dst.values.packed_at(vec_index) * PackedSecureField::broadcast(alpha_sq) - + value, - ) - }; - } - } - - fn decompose(eval: &SecureEvaluation) -> (SecureEvaluation, SecureField) { - let lambda = Self::decomposition_coefficient(eval); - let broadcasted_lambda = PackedSecureField::broadcast(lambda); - let mut g_values = SecureColumn::::zeros(eval.len()); - - let range = eval.len().div_ceil(K_BLOCK_SIZE); - let half_range = range / 2; - for i in 0..half_range { - let val = eval.packed_at(i) - broadcasted_lambda; - unsafe { g_values.set_packed(i, val) } - } - for i in half_range..range { - let val = eval.packed_at(i) + broadcasted_lambda; - unsafe { g_values.set_packed(i, val) } - } - - let g = SecureEvaluation { - domain: eval.domain, - values: g_values, - }; - (g, lambda) - } -} - -impl AVX512Backend { - /// See [`decomposition_coefficient`]. - /// - /// [`CPUBackend::decomposition_coefficient`]: - /// crate::core::backend::cpu::CPUBackend::decomposition_coefficient - fn decomposition_coefficient(eval: &SecureEvaluation) -> SecureField { - let cols = &eval.values.columns; - let [mut x_sum, mut y_sum, mut z_sum, mut w_sum] = [PackedBaseField::zero(); 4]; - - let range = cols[0].len() / K_BLOCK_SIZE; - let (half_a, half_b) = (range / 2, range); - - for i in 0..half_a { - x_sum += cols[0].data[i]; - y_sum += cols[1].data[i]; - z_sum += cols[2].data[i]; - w_sum += cols[3].data[i]; - } - for i in half_a..half_b { - x_sum -= cols[0].data[i]; - y_sum -= cols[1].data[i]; - z_sum -= cols[2].data[i]; - w_sum -= cols[3].data[i]; - } - - let x = x_sum.pointwise_sum(); - let y = y_sum.pointwise_sum(); - let z = z_sum.pointwise_sum(); - let w = w_sum.pointwise_sum(); - - SecureField::from_m31(x, y, z, w) - / BaseField::from_u32_unchecked(1 << eval.domain.log_size()) - } -} - -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -#[cfg(test)] -mod tests { - use crate::core::backend::avx512::{AVX512Backend, BaseFieldVec}; - use crate::core::backend::{CPUBackend, Column}; - use crate::core::fields::qm31::SecureField; - use crate::core::fields::secure_column::SecureColumn; - use crate::core::fri::FriOps; - use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps, SecureEvaluation}; - use crate::core::poly::line::{LineDomain, LineEvaluation}; - use crate::{m31, qm31}; - - #[test] - fn test_fold_line() { - const LOG_SIZE: u32 = 7; - let values: Vec = (0..(1 << LOG_SIZE)) - .map(|i| qm31!(4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3)) - .collect(); - let alpha = qm31!(1, 3, 5, 7); - let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset()); - let cpu_fold = CPUBackend::fold_line( - &LineEvaluation::new(domain, values.iter().copied().collect()), - alpha, - &CPUBackend::precompute_twiddles(domain.coset()), - ); - - let avx_fold = AVX512Backend::fold_line( - &LineEvaluation::new(domain, 
values.iter().copied().collect()), - alpha, - &AVX512Backend::precompute_twiddles(domain.coset()), - ); - - assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec()); - } - - #[test] - fn test_fold_circle_into_line() { - const LOG_SIZE: u32 = 7; - let values: Vec = (0..(1 << LOG_SIZE)) - .map(|i| qm31!(4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3)) - .collect(); - let alpha = qm31!(1, 3, 5, 7); - let circle_domain = CanonicCoset::new(LOG_SIZE).circle_domain(); - let line_domain = LineDomain::new(circle_domain.half_coset); - - let mut cpu_fold = - LineEvaluation::new(line_domain, SecureColumn::zeros(1 << (LOG_SIZE - 1))); - CPUBackend::fold_circle_into_line( - &mut cpu_fold, - &SecureEvaluation { - domain: circle_domain, - values: values.iter().copied().collect(), - }, - alpha, - &CPUBackend::precompute_twiddles(line_domain.coset()), - ); - - let mut avx_fold = - LineEvaluation::new(line_domain, SecureColumn::zeros(1 << (LOG_SIZE - 1))); - AVX512Backend::fold_circle_into_line( - &mut avx_fold, - &SecureEvaluation { - domain: circle_domain, - values: values.iter().copied().collect(), - }, - alpha, - &AVX512Backend::precompute_twiddles(line_domain.coset()), - ); - - assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec()); - } - - #[test] - fn decomposition_test() { - const DOMAIN_LOG_SIZE: u32 = 5; - const DOMAIN_LOG_HALF_SIZE: u32 = DOMAIN_LOG_SIZE - 1; - let s = CanonicCoset::new(DOMAIN_LOG_SIZE); - let domain = s.circle_domain(); - - let mut coeffs = BaseFieldVec::zeros(1 << DOMAIN_LOG_SIZE); - - // Polynomial is out of FFT space. - coeffs.as_mut_slice()[1 << DOMAIN_LOG_HALF_SIZE] = m31!(1); - let poly = CirclePoly::::new(coeffs); - let values = poly.evaluate(domain); - - let avx_column = SecureColumn:: { - columns: [ - values.values.clone(), - values.values.clone(), - values.values.clone(), - values.values.clone(), - ], - }; - let avx_eval = SecureEvaluation { - domain, - values: avx_column.clone(), - }; - let cpu_eval = SecureEvaluation:: { - domain, - values: avx_eval.to_cpu(), - }; - let (cpu_g, cpu_lambda) = CPUBackend::decompose(&cpu_eval); - - let (avx_g, avx_lambda) = AVX512Backend::decompose(&avx_eval); - - assert_eq!(avx_lambda, cpu_lambda); - for i in 0..(1 << DOMAIN_LOG_SIZE) { - assert_eq!(avx_g.values.at(i), cpu_g.values.at(i)); - } - } -} diff --git a/crates/prover/src/core/backend/avx512/m31.rs b/crates/prover/src/core/backend/avx512/m31.rs deleted file mode 100644 index 0e6ee27dd..000000000 --- a/crates/prover/src/core/backend/avx512/m31.rs +++ /dev/null @@ -1,326 +0,0 @@ -use core::arch::x86_64::{ - __m512i, _mm512_add_epi32, _mm512_min_epu32, _mm512_mul_epu32, _mm512_srli_epi64, - _mm512_sub_epi32, -}; -use std::arch::x86_64::{ - _mm512_load_epi32, _mm512_permutex2var_epi32, _mm512_set1_epi32, _mm512_setzero_si512, - _mm512_store_epi32, -}; -use std::fmt::Display; -use std::iter::Sum; -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; - -use num_traits::{One, Zero}; - -use super::tranpose_utils::{ - EVENS_CONCAT_EVENS, HHALF_INTERLEAVE_HHALF, LHALF_INTERLEAVE_LHALF, ODDS_CONCAT_ODDS, -}; -use crate::core::fields::m31::{pow2147483645, M31, P}; -use crate::core::fields::FieldExpOps; - -pub const K_BLOCK_SIZE: usize = 16; -pub const M512P: __m512i = unsafe { core::mem::transmute([P; K_BLOCK_SIZE]) }; - -/// AVX512 implementation of M31. -/// Stores 16 M31 elements in a single 512-bit register. -/// Each M31 element is unreduced in the range [0, P]. 
-#[derive(Copy, Clone, Debug)]
-pub struct PackedBaseField(pub __m512i);
-
-impl PackedBaseField {
-    pub fn broadcast(value: M31) -> Self {
-        Self(unsafe { _mm512_set1_epi32(value.0 as i32) })
-    }
-
-    pub fn from_array(v: [M31; K_BLOCK_SIZE]) -> PackedBaseField {
-        unsafe { Self(std::mem::transmute(v)) }
-    }
-
-    pub fn from_m512_unchecked(x: __m512i) -> Self {
-        Self(x)
-    }
-
-    pub fn to_array(self) -> [M31; K_BLOCK_SIZE] {
-        unsafe { std::mem::transmute(self.reduce()) }
-    }
-
-    /// Reduces each word in the 512-bit register to the range `[0, P)`, excluding P.
-    pub fn reduce(self) -> PackedBaseField {
-        Self(unsafe { _mm512_min_epu32(self.0, _mm512_sub_epi32(self.0, M512P)) })
-    }
-
-    /// Interleaves self with other.
-    /// Returns the result as two packed M31 elements.
-    pub fn interleave_with(self, other: Self) -> (Self, Self) {
-        (
-            Self(unsafe { _mm512_permutex2var_epi32(self.0, LHALF_INTERLEAVE_LHALF, other.0) }),
-            Self(unsafe { _mm512_permutex2var_epi32(self.0, HHALF_INTERLEAVE_HHALF, other.0) }),
-        )
-    }
-
-    /// Deinterleaves self with other.
-    /// Done by concatenating the even words of self with the even words of other, and the odd
-    /// words of self with the odd words of other.
-    /// The inverse of [Self::interleave_with].
-    /// Returns the result as two packed M31 elements.
-    pub fn deinterleave_with(self, other: Self) -> (Self, Self) {
-        (
-            Self(unsafe { _mm512_permutex2var_epi32(self.0, EVENS_CONCAT_EVENS, other.0) }),
-            Self(unsafe { _mm512_permutex2var_epi32(self.0, ODDS_CONCAT_ODDS, other.0) }),
-        )
-    }
-
-    /// # Safety
-    ///
-    /// This function is unsafe because it performs a load from a raw pointer. The pointer must be
-    /// valid and aligned to 64 bytes.
-    pub unsafe fn load(ptr: *const i32) -> Self {
-        Self(_mm512_load_epi32(ptr))
-    }
-
-    /// # Safety
-    ///
-    /// This function is unsafe because it performs a store to a raw pointer. The pointer must be
-    /// valid and aligned to 64 bytes.
-    pub unsafe fn store(self, ptr: *mut i32) {
-        _mm512_store_epi32(ptr, self.0);
-    }
-
-    /// Sums all the elements in the packed M31 element.
-    pub fn pointwise_sum(self) -> M31 {
-        self.to_array().into_iter().sum()
-    }
-}
-
-impl Display for PackedBaseField {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let v = self.to_array();
-        for elem in v.iter() {
-            write!(f, "{} ", elem)?;
-        }
-        Ok(())
-    }
-}
-
-impl Add for PackedBaseField {
-    type Output = Self;
-
-    /// Adds two packed M31 elements, and reduces the result to the range `[0,P]`.
-    /// Each value is assumed to be in unreduced form, [0, P] including P.
-    #[inline(always)]
-    fn add(self, rhs: Self) -> Self::Output {
-        Self(unsafe {
-            // Add word by word. Each word is in the range [0, 2P].
-            let c = _mm512_add_epi32(self.0, rhs.0);
-            // Apply min(c, c-P) to each word.
-            // When c in [P,2P], then c-P in [0,P] which is always less than [P,2P].
-            // When c in [0,P-1], then c-P in [2^32-P,2^32-1] which is always greater than [0,P-1].
-            _mm512_min_epu32(c, _mm512_sub_epi32(c, M512P))
-        })
-    }
-}
-
-impl AddAssign for PackedBaseField {
-    #[inline(always)]
-    fn add_assign(&mut self, rhs: Self) {
-        *self = *self + rhs;
-    }
-}
-
-impl Mul for PackedBaseField {
-    type Output = Self;
-
-    /// Computes the product of two packed M31 elements.
-    /// Each value is assumed to be in unreduced form, [0, P] including P.
-    /// Returned values are in unreduced form, [0, P] including P.
-    #[inline(always)]
-    fn mul(self, rhs: Self) -> Self::Output {
-        /// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
-        /// with the even words of b.
-        const EVENS_INTERLEAVE_EVENS: __m512i = unsafe {
-            core::mem::transmute([
-                0b00000, 0b10000, 0b00010, 0b10010, 0b00100, 0b10100, 0b00110, 0b10110, 0b01000,
-                0b11000, 0b01010, 0b11010, 0b01100, 0b11100, 0b01110, 0b11110,
-            ])
-        };
-        /// An input to _mm512_permutex2var_epi32, and is used to interleave the odd words of a
-        /// with the odd words of b.
-        const ODDS_INTERLEAVE_ODDS: __m512i = unsafe {
-            core::mem::transmute([
-                0b00001, 0b10001, 0b00011, 0b10011, 0b00101, 0b10101, 0b00111, 0b10111, 0b01001,
-                0b11001, 0b01011, 0b11011, 0b01101, 0b11101, 0b01111, 0b11111,
-            ])
-        };
-
-        unsafe {
-            // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of
-            // the first operand.
-            let val0_e = self.0;
-            // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of
-            // the first operand.
-            let val0_o = _mm512_srli_epi64(self.0, 32);
-
-            // Double the second operand.
-            let val1 = _mm512_add_epi32(rhs.0, rhs.0);
-            let val1_e = val1;
-            let val1_o = _mm512_srli_epi64(val1, 32);
-
-            // To compute prod = val0 * val1 start by multiplying
-            // val0_e/o by val1_e/o.
-            let prod_e_dbl = _mm512_mul_epu32(val0_e, val1_e);
-            let prod_o_dbl = _mm512_mul_epu32(val0_o, val1_o);
-
-            // Each multiplication holds the full 64-bit product of val0 and the doubled val1.
-            // Each 64-bit word looks like this:
-            //               1    31       31    1
-            // prod_e_dbl - |0|prod_e_h|prod_e_l|0|
-            // prod_o_dbl - |0|prod_o_h|prod_o_l|0|
-
-            // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl:
-            let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl);
-            // prod_ls -    |prod_o_l|0|prod_e_l|0|
-
-            // Divide by 2:
-            let prod_ls = Self(_mm512_srli_epi64(prod_ls, 1));
-            // prod_ls -    |0|prod_o_l|0|prod_e_l|
-
-            // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl:
-            let prod_hs = Self(_mm512_permutex2var_epi32(
-                prod_e_dbl,
-                ODDS_INTERLEAVE_ODDS,
-                prod_o_dbl,
-            ));
-            // prod_hs -    |0|prod_o_h|0|prod_e_h|
-
-            Self::add(prod_ls, prod_hs)
-        }
-    }
-}
-
-impl MulAssign for PackedBaseField {
-    #[inline(always)]
-    fn mul_assign(&mut self, rhs: Self) {
-        *self = *self * rhs;
-    }
-}
-
-impl Neg for PackedBaseField {
-    type Output = Self;
-
-    #[inline(always)]
-    fn neg(self) -> Self::Output {
-        Self(unsafe { _mm512_sub_epi32(M512P, self.0) })
-    }
-}
-
-/// Subtracts two packed M31 elements, and reduces the result to the range `[0,P]`.
-/// Each value is assumed to be in unreduced form, [0, P] including P.
-impl Sub for PackedBaseField {
-    type Output = Self;
-
-    #[inline(always)]
-    fn sub(self, rhs: Self) -> Self::Output {
-        Self(unsafe {
-            // Subtract word by word. Each word is in the range [-P, P].
-            let c = _mm512_sub_epi32(self.0, rhs.0);
-            // Apply min(c, c+P) to each word.
-            // When c in [0,P], then c+P in [P,2P] which is always greater than [0,P].
-            // When c in [2^32-P,2^32-1], then c+P in [0,P-1] which is always less than
-            // [2^32-P,2^32-1].
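// Editorial annotation: the same selection in scalar form, assuming wrapping
// u32 arithmetic:
//   let c = a.wrapping_sub(b);
//   c.min(c.wrapping_add(P)) // the representative of a - b lying in [0, P].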
- _mm512_min_epu32(_mm512_add_epi32(c, M512P), c) - }) - } -} - -impl SubAssign for PackedBaseField { - #[inline(always)] - fn sub_assign(&mut self, rhs: Self) { - *self = *self - rhs; - } -} - -impl Zero for PackedBaseField { - fn zero() -> Self { - Self(unsafe { _mm512_setzero_si512() }) - } - fn is_zero(&self) -> bool { - self.to_array().iter().all(|x| x.is_zero()) - } -} - -impl One for PackedBaseField { - fn one() -> Self { - Self(unsafe { core::mem::transmute([M31::one(); K_BLOCK_SIZE]) }) - } -} - -impl FieldExpOps for PackedBaseField { - fn inverse(&self) -> Self { - assert!(!self.is_zero(), "0 has no inverse"); - pow2147483645(*self) - } -} - -impl Sum for PackedBaseField { - fn sum>(iter: I) -> Self { - iter.fold(Self::zero(), Add::add) - } -} - -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -#[cfg(test)] -mod tests { - - use itertools::Itertools; - - use super::PackedBaseField; - use crate::core::fields::m31::{M31, P}; - use crate::core::fields::{Field, FieldExpOps}; - - /// Tests field operations where field elements are in reduced form. - #[test] - fn test_avx512_basic_ops() { - if !crate::platform::avx512_detected() { - return; - } - - let values = [ - 0, - 1, - 2, - 10, - (P - 1) / 2, - (P + 1) / 2, - P - 2, - P - 1, - 0, - 1, - 2, - 10, - (P - 1) / 2, - (P + 1) / 2, - P - 2, - P - 1, - ] - .map(M31::from_u32_unchecked); - let avx_values = PackedBaseField::from_array(values); - - assert_eq!( - (avx_values + avx_values) - .to_array() - .into_iter() - .collect_vec(), - values.iter().map(|x| x.double()).collect_vec() - ); - assert_eq!( - (avx_values * avx_values) - .to_array() - .into_iter() - .collect_vec(), - values.iter().map(|x| x.square()).collect_vec() - ); - assert_eq!( - (-avx_values).to_array().into_iter().collect_vec(), - values.iter().map(|x| -*x).collect_vec() - ); - } -} diff --git a/crates/prover/src/core/backend/avx512/mod.rs b/crates/prover/src/core/backend/avx512/mod.rs deleted file mode 100644 index 7afe5f990..000000000 --- a/crates/prover/src/core/backend/avx512/mod.rs +++ /dev/null @@ -1,342 +0,0 @@ -mod accumulation; -pub mod bit_reverse; -mod blake2s; -pub mod blake2s_avx; -pub mod circle; -pub mod cm31; -pub mod fft; -mod fri; -pub mod m31; -pub mod qm31; -pub mod quotients; -pub mod tranpose_utils; - -use bytemuck::{cast_slice, cast_slice_mut, Pod, Zeroable}; -use itertools::{izip, Itertools}; -use num_traits::Zero; - -use self::bit_reverse::bit_reverse_m31; -use self::cm31::PackedCM31; -pub use self::m31::{PackedBaseField, K_BLOCK_SIZE}; -use self::qm31::PackedSecureField; -use super::{Backend, CPUBackend, Column, ColumnOps}; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SecureColumn; -use crate::core::fields::{FieldExpOps, FieldOps}; -use crate::core::utils; - -pub const VECS_LOG_SIZE: usize = 4; - -#[derive(Copy, Clone, Debug)] -pub struct AVX512Backend; - -impl Backend for AVX512Backend {} - -// BaseField. -// TODO(spapini): Unite with the M31AVX512 type. 
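// Editorial annotation: a minimal scalar sketch (hypothetical helper, not in
// the original source) of the doubling trick used by `PackedBaseField::mul`
// above. With P = 2^31 - 1 we have 2^31 ≡ 1 (mod P), so a 62-bit product
// h * 2^31 + l reduces to h + l. Doubling the second operand shifts the
// product left by one bit, so the two 31-bit halves land on plain 32-bit
// word boundaries, which is what the packed code extracts with interleaves
// and a single shift.
fn scalar_mul_m31(a: u32, b: u32) -> u32 {
    const P: u32 = (1 << 31) - 1;
    // 2 * a * b fits in 63 bits for a, b in [0, P].
    let prod_dbl = (a as u64) * (2 * b as u64);
    let lo = (prod_dbl as u32) >> 1; // Low 31 bits of a * b.
    let hi = (prod_dbl >> 32) as u32; // High 31 bits of a * b.
    let c = hi + lo; // In [0, 2P): the same unreduced form as above.
    c.min(c.wrapping_sub(P)) // min(c, c - P), exactly as in `add`.
}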
- -unsafe impl Pod for PackedBaseField {} -unsafe impl Zeroable for PackedBaseField { - fn zeroed() -> Self { - unsafe { core::mem::zeroed() } - } -} - -#[derive(Clone, Debug)] -pub struct BaseFieldVec { - pub data: Vec, - length: usize, -} - -impl BaseFieldVec { - pub fn as_slice(&self) -> &[BaseField] { - let data: &[BaseField] = cast_slice(&self.data[..]); - &data[..self.length] - } - pub fn as_mut_slice(&mut self) -> &mut [BaseField] { - let data: &mut [BaseField] = cast_slice_mut(&mut self.data[..]); - &mut data[..self.length] - } -} - -impl ColumnOps for AVX512Backend { - type Column = BaseFieldVec; - - fn bit_reverse_column(column: &mut Self::Column) { - // Fallback to cpu bit_reverse. - if column.data.len().ilog2() < bit_reverse::MIN_LOG_SIZE { - utils::bit_reverse(column.as_mut_slice()); - return; - } - bit_reverse_m31(&mut column.data); - } -} - -impl FieldOps for AVX512Backend { - fn batch_inverse(column: &Self::Column, dst: &mut Self::Column) { - PackedBaseField::batch_inverse(&column.data, &mut dst.data); - } -} - -impl Column for BaseFieldVec { - fn zeros(len: usize) -> Self { - Self { - data: vec![PackedBaseField::zeroed(); len.div_ceil(K_BLOCK_SIZE)], - length: len, - } - } - fn to_cpu(&self) -> Vec { - self.data - .iter() - .flat_map(|x| x.to_array()) - .take(self.length) - .collect() - } - fn len(&self) -> usize { - self.length - } - fn at(&self, index: usize) -> BaseField { - self.data[index / K_BLOCK_SIZE].to_array()[index % K_BLOCK_SIZE] - } -} - -fn as_cpu_vec(values: BaseFieldVec) -> Vec { - let capacity = values.data.capacity() * K_BLOCK_SIZE; - unsafe { - let res = Vec::from_raw_parts( - values.data.as_ptr() as *mut BaseField, - values.length, - capacity, - ); - std::mem::forget(values); - res - } -} - -impl FromIterator for BaseFieldVec { - fn from_iter>(iter: I) -> Self { - let mut chunks = iter.into_iter().array_chunks(); - let mut res: Vec<_> = (&mut chunks).map(PackedBaseField::from_array).collect(); - let mut length = res.len() * K_BLOCK_SIZE; - - if let Some(remainder) = chunks.into_remainder() { - if !remainder.is_empty() { - length += remainder.len(); - let pad_len = 16 - remainder.len(); - let last = PackedBaseField::from_array( - remainder - .chain(std::iter::repeat(BaseField::zero()).take(pad_len)) - .collect::>() - .try_into() - .unwrap(), - ); - res.push(last); - } - } - - Self { data: res, length } - } -} - -#[derive(Clone, Debug, Default)] -pub struct SecureFieldVec { - pub data: Vec, - length: usize, -} - -impl ColumnOps for AVX512Backend { - type Column = SecureFieldVec; - - fn bit_reverse_column(column: &mut Self::Column) { - // Fallback to cpu bit_reverse. - // TODO(AlonH): Implement AVX512 bit_reverse for SecureField. 
-        let mut data = column.to_cpu();
-        utils::bit_reverse(&mut data);
-        *column = data.into_iter().collect();
-    }
-}
-
-impl FieldOps for AVX512Backend {
-    fn batch_inverse(column: &Self::Column, dst: &mut Self::Column) {
-        PackedSecureField::batch_inverse(&column.data, &mut dst.data);
-    }
-}
-
-impl Column for SecureFieldVec {
-    fn zeros(len: usize) -> Self {
-        Self {
-            data: vec![PackedSecureField::zeroed(); len.div_ceil(K_BLOCK_SIZE)],
-            length: len,
-        }
-    }
-    fn to_cpu(&self) -> Vec {
-        self.data
-            .iter()
-            .flat_map(|x| x.to_array())
-            .take(self.length)
-            .collect()
-    }
-    fn len(&self) -> usize {
-        self.length
-    }
-    fn at(&self, index: usize) -> SecureField {
-        self.data[index / K_BLOCK_SIZE].to_array()[index % K_BLOCK_SIZE]
-    }
-}
-
-impl Extend for SecureFieldVec {
-    fn extend>(&mut self, iter: T) {
-        self.data.extend(iter);
-        self.length = self.data.len() * K_BLOCK_SIZE;
-    }
-}
-
-impl FromIterator for SecureFieldVec {
-    fn from_iter>(iter: I) -> Self {
-        let mut chunks = iter.into_iter().array_chunks();
-        let mut res: Vec<_> = (&mut chunks).map(PackedSecureField::from_array).collect();
-        let mut length = res.len() * K_BLOCK_SIZE;
-
-        if let Some(remainder) = chunks.into_remainder() {
-            if !remainder.is_empty() {
-                length += remainder.len();
-                let pad_len = 16 - remainder.len();
-                let last = PackedSecureField::from_array(
-                    remainder
-                        .chain(std::iter::repeat(SecureField::zero()).take(pad_len))
-                        .collect::>()
-                        .try_into()
-                        .unwrap(),
-                );
-                res.push(last);
-            }
-        }
-
-        Self { data: res, length }
-    }
-}
-
-impl FromIterator for SecureFieldVec {
-    fn from_iter>(iter: I) -> Self {
-        let data = (&mut iter.into_iter()).collect_vec();
-        let length = data.len() * K_BLOCK_SIZE;
-
-        Self { data, length }
-    }
-}
-
-impl SecureColumn {
-    pub fn n_packs(&self) -> usize {
-        self.columns[0].data.len()
-    }
-
-    /// # Safety
-    ///
-    /// `vec_index` must be a valid index.
-    pub fn packed_at(&self, vec_index: usize) -> PackedSecureField {
-        unsafe {
-            PackedSecureField([
-                PackedCM31([
-                    *self.columns[0].data.get_unchecked(vec_index),
-                    *self.columns[1].data.get_unchecked(vec_index),
-                ]),
-                PackedCM31([
-                    *self.columns[2].data.get_unchecked(vec_index),
-                    *self.columns[3].data.get_unchecked(vec_index),
-                ]),
-            ])
-        }
-    }
-
-    /// # Safety
-    ///
-    /// `vec_index` must be a valid index.
- pub unsafe fn set_packed(&mut self, vec_index: usize, value: PackedSecureField) { - *self.columns[0].data.get_unchecked_mut(vec_index) = value.a().a(); - *self.columns[1].data.get_unchecked_mut(vec_index) = value.a().b(); - *self.columns[2].data.get_unchecked_mut(vec_index) = value.b().a(); - *self.columns[3].data.get_unchecked_mut(vec_index) = value.b().b(); - } - - pub fn to_vec(&self) -> Vec { - izip!( - self.columns[0].to_cpu(), - self.columns[1].to_cpu(), - self.columns[2].to_cpu(), - self.columns[3].to_cpu(), - ) - .map(|(a, b, c, d)| SecureField::from_m31_array([a, b, c, d])) - .collect() - } -} - -impl FromIterator for SecureColumn { - fn from_iter>(iter: I) -> Self { - let cpu_col = SecureColumn::::from_iter(iter); - SecureColumn { - columns: cpu_col.columns.map(|col| col.into_iter().collect()), - } - } -} - -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -#[cfg(test)] -mod tests { - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use super::*; - use crate::core::backend::{Col, Column}; - - type B = AVX512Backend; - - #[test] - fn test_column() { - for i in 0..100 { - let col = Col::::from_iter((0..i).map(BaseField::from)); - assert_eq!( - col.to_cpu(), - (0..i).map(BaseField::from).collect::>() - ); - for j in 0..i { - assert_eq!(col.at(j), BaseField::from(j)); - } - } - } - - #[test] - fn test_bit_reverse() { - for i in 1..16 { - let len = 1 << i; - let mut col = Col::::from_iter((0..len).map(BaseField::from)); - >::bit_reverse_column(&mut col); - assert_eq!( - col.to_cpu(), - (0..len) - .map(|x| BaseField::from(utils::bit_reverse_index(x, i as u32))) - .collect::>() - ); - } - } - - #[test] - fn test_as_cpu_vec() { - let original_vec = (1000..1100).map(BaseField::from).collect::>(); - let col = Col::::from_iter(original_vec.clone()); - let vec = as_cpu_vec(col); - assert_eq!(vec, original_vec); - } - - #[test] - fn test_packed_basefield_batch_inverse() { - let mut rng = SmallRng::seed_from_u64(0); - let column = (0..64).map(|_| rng.gen()).collect::(); - let expected = column.data.iter().map(|e| e.inverse()).collect::>(); - let mut dst = (0..64).map(|_| BaseField::zero()).collect::(); - - >::batch_inverse(&column, &mut dst); - - dst.data.iter().zip(expected.iter()).for_each(|(a, b)| { - assert_eq!(a.to_array(), b.to_array()); - }); - } -} diff --git a/crates/prover/src/core/backend/avx512/qm31.rs b/crates/prover/src/core/backend/avx512/qm31.rs deleted file mode 100644 index 0103c8ecc..000000000 --- a/crates/prover/src/core/backend/avx512/qm31.rs +++ /dev/null @@ -1,204 +0,0 @@ -use std::iter::Sum; -use std::ops::{Add, AddAssign, Mul, MulAssign, Sub}; - -use bytemuck::{Pod, Zeroable}; -use num_traits::{One, Zero}; - -use super::cm31::PackedCM31; -use super::m31::K_BLOCK_SIZE; -use super::PackedBaseField; -use crate::core::fields::qm31::QM31; -use crate::core::fields::FieldExpOps; - -/// AVX implementation for an extension of CM31. -/// See [crate::core::fields::qm31::QM31] for more information. 
-#[derive(Copy, Clone, Debug)]
-pub struct PackedSecureField(pub [PackedCM31; 2]);
-impl PackedSecureField {
-    pub fn zero() -> Self {
-        Self([
-            PackedCM31([PackedBaseField::zero(); 2]),
-            PackedCM31([PackedBaseField::zero(); 2]),
-        ])
-    }
-    pub fn broadcast(value: QM31) -> Self {
-        Self([
-            PackedCM31::broadcast(value.0),
-            PackedCM31::broadcast(value.1),
-        ])
-    }
-    pub fn a(&self) -> PackedCM31 {
-        self.0[0]
-    }
-    pub fn b(&self) -> PackedCM31 {
-        self.0[1]
-    }
-    pub fn to_array(&self) -> [QM31; K_BLOCK_SIZE] {
-        std::array::from_fn(|i| QM31(self.a().to_array()[i], self.b().to_array()[i]))
-    }
-
-    pub fn from_array(array: [QM31; K_BLOCK_SIZE]) -> Self {
-        let a = PackedBaseField::from_array(std::array::from_fn(|i| array[i].0 .0));
-        let b = PackedBaseField::from_array(std::array::from_fn(|i| array[i].0 .1));
-        let c = PackedBaseField::from_array(std::array::from_fn(|i| array[i].1 .0));
-        let d = PackedBaseField::from_array(std::array::from_fn(|i| array[i].1 .1));
-        Self([PackedCM31([a, b]), PackedCM31([c, d])])
-    }
-
-    // Multiply packed QM31 by packed M31.
-    pub fn mul_packed_m31(&self, rhs: PackedBaseField) -> PackedSecureField {
-        Self::from_packed_m31s(self.to_packed_m31s().map(|m31| m31 * rhs))
-    }
-
-    /// Sums all the elements in the packed QM31 element.
-    pub fn pointwise_sum(self) -> QM31 {
-        self.to_array().into_iter().sum()
-    }
-
-    pub fn to_packed_m31s(&self) -> [PackedBaseField; 4] {
-        [self.a().a(), self.a().b(), self.b().a(), self.b().b()]
-    }
-
-    pub fn from_packed_m31s(array: [PackedBaseField; 4]) -> Self {
-        Self([
-            PackedCM31([array[0], array[1]]),
-            PackedCM31([array[2], array[3]]),
-        ])
-    }
-}
-impl Add for PackedSecureField {
-    type Output = Self;
-    fn add(self, rhs: Self) -> Self::Output {
-        Self([self.a() + rhs.a(), self.b() + rhs.b()])
-    }
-}
-impl Sub for PackedSecureField {
-    type Output = Self;
-    fn sub(self, rhs: Self) -> Self::Output {
-        Self([self.a() - rhs.a(), self.b() - rhs.b()])
-    }
-}
-impl Mul for PackedSecureField {
-    type Output = Self;
-    fn mul(self, rhs: Self) -> Self::Output {
-        // Compute using Karatsuba.
-        // (a + ub) * (c + ud) =
-        //   (ac + (2+i)bd) + (ad + bc)u =
-        //   ac + 2bd + ibd + (ad + bc)u.
-        let ac = self.a() * rhs.a();
-        let bd = self.b() * rhs.b();
-        let bd_times_1_plus_i = PackedCM31([bd.a() - bd.b(), bd.a() + bd.b()]);
-        // Computes ac + bd.
-        let ac_p_bd = ac + bd;
-        // Computes ad + bc.
-        let ad_p_bc = (self.a() + self.b()) * (rhs.a() + rhs.b()) - ac_p_bd;
-        // ac + 2bd + ibd =
-        //   ac + bd + bd + ibd
-        let l = PackedCM31([
-            ac_p_bd.a() + bd_times_1_plus_i.a(),
-            ac_p_bd.b() + bd_times_1_plus_i.b(),
-        ]);
-        Self([l, ad_p_bc])
-    }
-}
-impl Zero for PackedSecureField {
-    fn zero() -> Self {
-        Self([PackedCM31::zero(), PackedCM31::zero()])
-    }
-    fn is_zero(&self) -> bool {
-        self.a().is_zero() && self.b().is_zero()
-    }
-}
-impl One for PackedSecureField {
-    fn one() -> Self {
-        Self([PackedCM31::one(), PackedCM31::zero()])
-    }
-}
-impl AddAssign for PackedSecureField {
-    fn add_assign(&mut self, rhs: Self) {
-        *self = *self + rhs;
-    }
-}
-impl MulAssign for PackedSecureField {
-    fn mul_assign(&mut self, rhs: Self) {
-        *self = *self * rhs;
-    }
-}
-impl FieldExpOps for PackedSecureField {
-    fn inverse(&self) -> Self {
-        assert!(!self.is_zero(), "0 has no inverse");
-        // (a + bu)^-1 = (a - bu) / (a^2 - (2+i)b^2).
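// Editorial annotation: a and b here are packed CM31 values and u^2 = 2 + i.
// `ib2` below computes i * b^2 via i * (x + iy) = -y + ix, so the denominator
// a^2 - (2 + i) * b^2 is assembled as a^2 - (b2 + b2 + ib2).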
- let b2 = self.b().square(); - let ib2 = PackedCM31([-b2.b(), b2.a()]); - let denom = self.a().square() - (b2 + b2 + ib2); - let denom_inverse = denom.inverse(); - Self([self.a() * denom_inverse, -self.b() * denom_inverse]) - } -} - -impl Add for PackedSecureField { - type Output = Self; - fn add(self, rhs: PackedBaseField) -> Self::Output { - Self([self.a() + rhs, self.b()]) - } -} -impl Sub for PackedSecureField { - type Output = Self; - fn sub(self, rhs: PackedBaseField) -> Self::Output { - Self([self.a() - rhs, self.b()]) - } -} -impl Mul for PackedSecureField { - type Output = Self; - fn mul(self, rhs: PackedBaseField) -> Self::Output { - Self([self.a() * rhs, self.b() * rhs]) - } -} - -impl Sum for PackedSecureField { - fn sum>(iter: I) -> Self { - iter.fold(Self::zero(), Add::add) - } -} - -unsafe impl Pod for PackedSecureField {} -unsafe impl Zeroable for PackedSecureField { - fn zeroed() -> Self { - unsafe { core::mem::zeroed() } - } -} - -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -#[cfg(test)] -mod tests { - use rand::rngs::SmallRng; - use rand::{Rng, SeedableRng}; - - use super::*; - - #[test] - fn test_qm31avx512_basic_ops() { - let mut rng = SmallRng::seed_from_u64(0); - let x = PackedSecureField::from_array(rng.gen()); - let y = PackedSecureField::from_array(rng.gen()); - let sum = x + y; - let diff = x - y; - let prod = x * y; - for i in 0..16 { - assert_eq!(sum.to_array()[i], x.to_array()[i] + y.to_array()[i]); - assert_eq!(diff.to_array()[i], x.to_array()[i] - y.to_array()[i]); - assert_eq!(prod.to_array()[i], x.to_array()[i] * y.to_array()[i]); - } - } - - #[test] - fn test_from_array() { - let mut rng = SmallRng::seed_from_u64(0); - let x_arr = std::array::from_fn(|_| rng.gen()); - - let packed = PackedSecureField::from_array(x_arr); - let to_arr = packed.to_array(); - - assert_eq!(to_arr, x_arr); - } -} diff --git a/crates/prover/src/core/backend/avx512/quotients.rs b/crates/prover/src/core/backend/avx512/quotients.rs deleted file mode 100644 index 97759fc8d..000000000 --- a/crates/prover/src/core/backend/avx512/quotients.rs +++ /dev/null @@ -1,220 +0,0 @@ -use itertools::{izip, Itertools}; -use num_traits::One; - -use super::qm31::PackedSecureField; -use super::{AVX512Backend, SecureFieldVec, K_BLOCK_SIZE, VECS_LOG_SIZE}; -use crate::core::backend::avx512::PackedBaseField; -use crate::core::backend::cpu::quotients::{ - batch_random_coeffs, column_line_coeffs, QuotientConstants, -}; -use crate::core::backend::{Col, Column}; -use crate::core::circle::CirclePoint; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; -use crate::core::fields::secure_column::SecureColumn; -use crate::core::fields::FieldOps; -use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; -use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation}; -use crate::core::poly::BitReversedOrder; -use crate::core::utils::bit_reverse_index; - -impl QuotientOps for AVX512Backend { - fn accumulate_quotients( - domain: CircleDomain, - columns: &[&CircleEvaluation], - random_coeff: SecureField, - sample_batches: &[ColumnSampleBatch], - ) -> SecureEvaluation { - assert!(domain.log_size() >= VECS_LOG_SIZE as u32); - let mut values = SecureColumn::::zeros(domain.size()); - let quotient_constants = quotient_constants(sample_batches, random_coeff, domain); - - // TODO(spapini): bit reverse iterator. 
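// Editorial annotation: per packed row, the loop below performs a Horner-style
// accumulation over the sample batches,
//   acc = acc * batch_coeff + (sum_j (f_j(p) - s_j)) * denominator_inverse(p),
// operating on 16 lanes at a time.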
-        for vec_row in 0..(1 << (domain.log_size() - VECS_LOG_SIZE as u32)) {
-            // TODO(spapini): Optimize this, for the small number of columns case.
-            let points = std::array::from_fn(|i| {
-                domain.at(bit_reverse_index(
-                    (vec_row << VECS_LOG_SIZE) + i,
-                    domain.log_size(),
-                ))
-            });
-            let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x));
-            let domain_points_y = PackedBaseField::from_array(points.map(|p| p.y));
-            let row_accumulator = accumulate_row_quotients(
-                sample_batches,
-                columns,
-                &quotient_constants,
-                vec_row,
-                (domain_points_x, domain_points_y),
-            );
-            unsafe { values.set_packed(vec_row, row_accumulator) };
-        }
-        SecureEvaluation { domain, values }
-    }
-}
-
-// TODO(Ohad): no longer using pair_vanishing, remove domain_point_vec and line_coeffs, or write a
-// function that deals with quotients over pair_vanishing polynomials.
-pub fn accumulate_row_quotients(
-    sample_batches: &[ColumnSampleBatch],
-    columns: &[&CircleEvaluation],
-    quotient_constants: &QuotientConstants,
-    vec_row: usize,
-    _domain_point_vec: (PackedBaseField, PackedBaseField),
-) -> PackedSecureField {
-    let mut row_accumulator = PackedSecureField::zero();
-    for (sample_batch, _, batch_coeff, denominator_inverses) in izip!(
-        sample_batches,
-        &quotient_constants.line_coeffs,
-        &quotient_constants.batch_random_coeffs,
-        &quotient_constants.denominator_inverses
-    ) {
-        let mut numerator = PackedSecureField::zero();
-        for (column_index, sampled_value) in sample_batch.columns_and_values.iter() {
-            let column = &columns[*column_index];
-            let value = column.data[vec_row];
-            numerator += PackedSecureField::broadcast(-*sampled_value) + value;
-        }
-
-        row_accumulator = row_accumulator * PackedSecureField::broadcast(*batch_coeff)
-            + numerator * denominator_inverses.data[vec_row];
-    }
-    row_accumulator
-}
-
-/// Point vanishing for the packed representation of the points. Skips the division.
-/// See [crate::core::constraints::point_vanishing_fraction] for more details.
-fn packed_point_vanishing_fraction(
-    excluded: CirclePoint,
-    p: (PackedBaseField, PackedBaseField),
-) -> (PackedSecureField, PackedSecureField) {
-    let e_conjugate = excluded.conjugate();
-    let h_x = PackedSecureField::broadcast(e_conjugate.x) * p.0
-        - PackedSecureField::broadcast(e_conjugate.y) * p.1;
-    let h_y = PackedSecureField::broadcast(e_conjugate.y) * p.0
-        + PackedSecureField::broadcast(e_conjugate.x) * p.1;
-    (h_y, (PackedSecureField::one() + h_x))
-}
-
-fn denominator_inverses(
-    sample_batches: &[ColumnSampleBatch],
-    domain: CircleDomain,
-) -> Vec> {
-    let (denominators, numerators): (SecureFieldVec, SecureFieldVec) = sample_batches
-        .iter()
-        .flat_map(|sample_batch| {
-            (0..(1 << (domain.log_size() - VECS_LOG_SIZE as u32)))
-                .map(|vec_row| {
-                    // TODO(spapini): Optimize this, for the small number of columns case.
- let points = std::array::from_fn(|i| { - domain.at(bit_reverse_index( - (vec_row << VECS_LOG_SIZE) + i, - domain.log_size(), - )) - }); - let domain_points_x = PackedBaseField::from_array(points.map(|p| p.x)); - let domain_points_y = PackedBaseField::from_array(points.map(|p| p.y)); - let domain_point_vec = (domain_points_x, domain_points_y); - - packed_point_vanishing_fraction(sample_batch.point, domain_point_vec) - }) - .collect_vec() - }) - .unzip(); - - let mut flat_denominator_inverses = SecureFieldVec::zeros(denominators.len()); - >::batch_inverse( - &denominators, - &mut flat_denominator_inverses, - ); - - flat_denominator_inverses - .data - .iter_mut() - .zip(&numerators.data) - .for_each(|(inv, denom_denom)| *inv *= *denom_denom); - - flat_denominator_inverses - .data - .chunks(domain.size() / K_BLOCK_SIZE) - .map(|denominator_inverses| denominator_inverses.iter().copied().collect()) - .collect() -} - -fn quotient_constants( - sample_batches: &[ColumnSampleBatch], - random_coeff: SecureField, - domain: CircleDomain, -) -> QuotientConstants { - let line_coeffs = column_line_coeffs(sample_batches, random_coeff); - let batch_random_coeffs = batch_random_coeffs(sample_batches, random_coeff); - let denominator_inverses = denominator_inverses(sample_batches, domain); - QuotientConstants { - line_coeffs, - batch_random_coeffs, - denominator_inverses, - } -} - -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] -#[cfg(test)] -mod tests { - use itertools::Itertools; - - use crate::core::backend::avx512::{AVX512Backend, BaseFieldVec}; - use crate::core::backend::{CPUBackend, Column}; - use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; - use crate::core::fields::m31::BaseField; - use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; - use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; - use crate::core::poly::BitReversedOrder; - use crate::qm31; - - #[test] - fn test_avx_accumulate_quotients() { - const LOG_SIZE: u32 = 8; - let domain = CanonicCoset::new(LOG_SIZE).circle_domain(); - let e0: BaseFieldVec = (0..domain.size()).map(BaseField::from).collect(); - let e1: BaseFieldVec = (0..domain.size()).map(|i| BaseField::from(2 * i)).collect(); - let columns = vec![ - CircleEvaluation::::new(domain, e0), - CircleEvaluation::::new(domain, e1), - ]; - let random_coeff = qm31!(1, 2, 3, 4); - let a = qm31!(3, 6, 9, 12); - let b = qm31!(4, 8, 12, 16); - let samples = vec![ColumnSampleBatch { - point: SECURE_FIELD_CIRCLE_GEN, - columns_and_values: vec![(0, a), (1, b)], - }]; - let avx_result = AVX512Backend::accumulate_quotients( - domain, - &columns.iter().collect_vec(), - random_coeff, - &samples, - ) - .values - .to_vec(); - - let cpu_columns = columns - .iter() - .map(|c| { - CircleEvaluation::::new( - c.domain, - c.values.to_cpu(), - ) - }) - .collect::>(); - - let cpu_result = CPUBackend::accumulate_quotients( - domain, - &cpu_columns.iter().collect_vec(), - random_coeff, - &samples, - ) - .values - .to_vec(); - - assert_eq!(avx_result, cpu_result); - } -} diff --git a/crates/prover/src/core/backend/avx512/tranpose_utils.rs b/crates/prover/src/core/backend/avx512/tranpose_utils.rs deleted file mode 100644 index 4e4f5d3a2..000000000 --- a/crates/prover/src/core/backend/avx512/tranpose_utils.rs +++ /dev/null @@ -1,35 +0,0 @@ -use std::arch::x86_64::__m512i; - -/// An input to _mm512_permutex2var_epi32, and is used to interleave the low half of a -/// with the low half of b. 
-pub const LHALF_INTERLEAVE_LHALF: __m512i = unsafe { - core::mem::transmute([ - 0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100, 0b10100, - 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111, - ]) -}; -/// An input to _mm512_permutex2var_epi32, and is used to interleave the high half of a -/// with the high half of b. -pub const HHALF_INTERLEAVE_HHALF: __m512i = unsafe { - core::mem::transmute([ - 0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100, 0b11100, - 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111, - ]) -}; - -/// An input to _mm512_permutex2var_epi32, and is used to concat the even words of a -/// with the even words of b. -pub const EVENS_CONCAT_EVENS: __m512i = unsafe { - core::mem::transmute([ - 0b00000, 0b00010, 0b00100, 0b00110, 0b01000, 0b01010, 0b01100, 0b01110, 0b10000, 0b10010, - 0b10100, 0b10110, 0b11000, 0b11010, 0b11100, 0b11110, - ]) -}; -/// An input to _mm512_permutex2var_epi32, and is used to concat the odd words of a -/// with the odd words of b. -pub const ODDS_CONCAT_ODDS: __m512i = unsafe { - core::mem::transmute([ - 0b00001, 0b00011, 0b00101, 0b00111, 0b01001, 0b01011, 0b01101, 0b01111, 0b10001, 0b10011, - 0b10101, 0b10111, 0b11001, 0b11011, 0b11101, 0b11111, - ]) -}; diff --git a/crates/prover/src/core/backend/mod.rs b/crates/prover/src/core/backend/mod.rs index b8bd2f534..fba8c4d4d 100644 --- a/crates/prover/src/core/backend/mod.rs +++ b/crates/prover/src/core/backend/mod.rs @@ -10,8 +10,6 @@ use super::fri::FriOps; use super::pcs::quotients::QuotientOps; use super::poly::circle::PolyOps; -#[cfg(target_arch = "x86_64")] -pub mod avx512; pub mod cpu; pub mod simd; diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index 9d9a5f926..6d70f257e 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -16,4 +16,3 @@ pub mod core; pub mod examples; pub mod hash_functions; pub mod math; -pub mod platform; diff --git a/crates/prover/src/platform.rs b/crates/prover/src/platform.rs deleted file mode 100644 index 578949b91..000000000 --- a/crates/prover/src/platform.rs +++ /dev/null @@ -1,12 +0,0 @@ -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -#[inline(always)] -#[allow(unreachable_code)] -pub fn avx512_detected() -> bool { - // Static check, e.g. for building with target-cpu=native. - if cfg!(feature = "avx512") { - return true; - } - - // Dynamic check, if std is enabled. - is_x86_feature_detected!("avx512f") -}
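// Editorial annotation: a scalar model (hypothetical helper, not in the
// original source) of how the permutation constants in tranpose_utils.rs
// above drive _mm512_permutex2var_epi32. For each output lane, bit 4 of the
// index selects the source register (0 -> a, 1 -> b) and the low four bits
// select the lane within it; e.g. EVENS_CONCAT_EVENS yields lanes 0, 2, ...,
// 14 of `a` followed by lanes 0, 2, ..., 14 of `b`.
fn permutex2var_epi32_model(a: [u32; 16], b: [u32; 16], idx: [u32; 16]) -> [u32; 16] {
    std::array::from_fn(|k| {
        let i = idx[k] as usize;
        if i & 0b10000 == 0 {
            a[i & 0b1111]
        } else {
            b[i & 0b1111]
        }
    })
}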