diff --git a/crates/prover/src/core/backend/simd/fri.rs b/crates/prover/src/core/backend/simd/fri.rs index edceb18f8..088d07dab 100644 --- a/crates/prover/src/core/backend/simd/fri.rs +++ b/crates/prover/src/core/backend/simd/fri.rs @@ -1,8 +1,11 @@ -use super::AVX512Backend; -use crate::core::backend::avx512::fft::compute_first_twiddles; -use crate::core::backend::avx512::fft::ifft::avx_ibutterfly; -use crate::core::backend::avx512::qm31::PackedSecureField; -use crate::core::backend::avx512::VECS_LOG_SIZE; +use std::array; +use std::simd::u32x8; + +use super::m31::LOG_N_LANES; +use super::SimdBackend; +use crate::core::backend::simd::fft::compute_first_twiddles; +use crate::core::backend::simd::fft::ifft::ibutterfly; +use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumn; use crate::core::fri::{self, FriOps}; @@ -11,14 +14,14 @@ use crate::core::poly::line::LineEvaluation; use crate::core::poly::twiddles::TwiddleTree; use crate::core::poly::utils::domain_line_twiddles_from_tree; -impl FriOps for AVX512Backend { +impl FriOps for SimdBackend { fn fold_line( eval: &LineEvaluation, alpha: SecureField, twiddles: &TwiddleTree, ) -> LineEvaluation { let log_size = eval.len().ilog2(); - if log_size <= VECS_LOG_SIZE as u32 { + if log_size <= LOG_N_LANES { let eval = fri::fold_line(&eval.to_cpu(), alpha); return LineEvaluation::new(eval.domain(), eval.values.into_iter().collect()); } @@ -28,18 +31,18 @@ impl FriOps for AVX512Backend { let mut folded_values = SecureColumn::::zeros(1 << (log_size - 1)); - for vec_index in 0..(1 << (log_size - 1 - VECS_LOG_SIZE as u32)) { + for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) { let value = unsafe { - let twiddle_dbl: [i32; 16] = - std::array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i)); - let val0 = eval.values.packed_at(vec_index * 2).to_packed_m31s(); - let val1 = eval.values.packed_at(vec_index * 2 + 1).to_packed_m31s(); - let pairs: [_; 4] = std::array::from_fn(|i| { - let (a, b) = val0[i].deinterleave_with(val1[i]); - avx_ibutterfly(a, b, std::mem::transmute(twiddle_dbl)) + let twiddle_dbl: [u32; 16] = + array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i)); + let val0 = eval.values.packed_at(vec_index * 2).into_packed_m31s(); + let val1 = eval.values.packed_at(vec_index * 2 + 1).into_packed_m31s(); + let pairs: [_; 4] = array::from_fn(|i| { + let (a, b) = val0[i].deinterleave(val1[i]); + ibutterfly(a, b, std::mem::transmute(twiddle_dbl)) }); - let val0 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].0)); - let val1 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].1)); + let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0)); + let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1)); val0 + PackedSecureField::broadcast(alpha) * val1 }; unsafe { folded_values.set_packed(vec_index, value) }; @@ -55,27 +58,28 @@ impl FriOps for AVX512Backend { twiddles: &TwiddleTree, ) { let log_size = src.len().ilog2(); - assert!(log_size > VECS_LOG_SIZE as u32, "Evaluation too small"); + assert!(log_size > LOG_N_LANES, "Evaluation too small"); let domain = src.domain; let alpha_sq = alpha * alpha; let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0]; - for vec_index in 0..(1 << (log_size - 1 - VECS_LOG_SIZE as u32)) { + for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) { let value = unsafe { // The 16 twiddles of the circle domain can be derived from the 8 twiddles of the // next line domain. See `compute_first_twiddles()`. - let twiddle_dbl: [i32; 8] = - std::array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 8 + i)); + let twiddle_dbl = u32x8::from_array(array::from_fn(|i| { + *itwiddles.get_unchecked(vec_index * 8 + i) + })); let (t0, _) = compute_first_twiddles(twiddle_dbl); - let val0 = src.values.packed_at(vec_index * 2).to_packed_m31s(); - let val1 = src.values.packed_at(vec_index * 2 + 1).to_packed_m31s(); - let pairs: [_; 4] = std::array::from_fn(|i| { - let (a, b) = val0[i].deinterleave_with(val1[i]); - avx_ibutterfly(a, b, t0) + let val0 = src.values.packed_at(vec_index * 2).into_packed_m31s(); + let val1 = src.values.packed_at(vec_index * 2 + 1).into_packed_m31s(); + let pairs: [_; 4] = array::from_fn(|i| { + let (a, b) = val0[i].deinterleave(val1[i]); + ibutterfly(a, b, t0) }); - let val0 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].0)); - let val1 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].1)); + let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0)); + let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1)); val0 + PackedSecureField::broadcast(alpha) * val1 }; unsafe { @@ -89,10 +93,13 @@ impl FriOps for AVX512Backend { } } -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] #[cfg(test)] mod tests { - use crate::core::backend::avx512::AVX512Backend; + use itertools::Itertools; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use crate::core::backend::simd::SimdBackend; use crate::core::backend::CPUBackend; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumn; @@ -104,9 +111,8 @@ mod tests { #[test] fn test_fold_line() { const LOG_SIZE: u32 = 7; - let values: Vec = (0..(1 << LOG_SIZE)) - .map(|i| qm31!(4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3)) - .collect(); + let mut rng = SmallRng::seed_from_u64(0); + let values = (0..1 << LOG_SIZE).map(|_| rng.gen()).collect_vec(); let alpha = qm31!(1, 3, 5, 7); let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset()); let cpu_fold = CPUBackend::fold_line( @@ -115,10 +121,10 @@ mod tests { &CPUBackend::precompute_twiddles(domain.coset()), ); - let avx_fold = AVX512Backend::fold_line( + let avx_fold = SimdBackend::fold_line( &LineEvaluation::new(domain, values.iter().copied().collect()), alpha, - &AVX512Backend::precompute_twiddles(domain.coset()), + &SimdBackend::precompute_twiddles(domain.coset()), ); assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec()); @@ -133,7 +139,6 @@ mod tests { let alpha = qm31!(1, 3, 5, 7); let circle_domain = CanonicCoset::new(LOG_SIZE).circle_domain(); let line_domain = LineDomain::new(circle_domain.half_coset); - let mut cpu_fold = LineEvaluation::new(line_domain, SecureColumn::zeros(1 << (LOG_SIZE - 1))); CPUBackend::fold_circle_into_line( @@ -146,18 +151,18 @@ mod tests { &CPUBackend::precompute_twiddles(line_domain.coset()), ); - let mut avx_fold = + let mut simd_fold = LineEvaluation::new(line_domain, SecureColumn::zeros(1 << (LOG_SIZE - 1))); - AVX512Backend::fold_circle_into_line( - &mut avx_fold, + SimdBackend::fold_circle_into_line( + &mut simd_fold, &SecureEvaluation { domain: circle_domain, values: values.iter().copied().collect(), }, alpha, - &AVX512Backend::precompute_twiddles(line_domain.coset()), + &SimdBackend::precompute_twiddles(line_domain.coset()), ); - assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec()); + assert_eq!(cpu_fold.values.to_vec(), simd_fold.values.to_vec()); } } diff --git a/crates/prover/src/core/backend/simd/qm31.rs b/crates/prover/src/core/backend/simd/qm31.rs index 4ba30b0c5..9b0a01ea1 100644 --- a/crates/prover/src/core/backend/simd/qm31.rs +++ b/crates/prover/src/core/backend/simd/qm31.rs @@ -79,14 +79,14 @@ impl PackedQM31 { /// Returns vectors `a, b, c, d` such that element `i` is represented as /// `QM31(a_i, b_i, c_i, d_i)`. - pub fn into_packed_m31s(self) -> [PackedBaseField; 4] { + pub fn into_packed_m31s(self) -> [PackedM31; 4] { let Self([PackedCM31([a, b]), PackedCM31([c, d])]) = self; [a, b, c, d] } /// Creates an instance from vectors `a, b, c, d` such that element `i` /// is represented as `QM31(a_i, b_i, c_i, d_i)`. - pub fn from_packed_m31s([a, b, c, d]: [PackedBaseField; 4]) -> Self { + pub fn from_packed_m31s([a, b, c, d]: [PackedM31; 4]) -> Self { Self([PackedCM31([a, b]), PackedCM31([c, d])]) } }