Skip to content

Commit

Permalink
Fix compilation issues
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed May 16, 2024
1 parent 9897c48 commit 82c5535
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 43 deletions.
87 changes: 46 additions & 41 deletions crates/prover/src/core/backend/simd/fri.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use super::AVX512Backend;
use crate::core::backend::avx512::fft::compute_first_twiddles;
use crate::core::backend::avx512::fft::ifft::avx_ibutterfly;
use crate::core::backend::avx512::qm31::PackedSecureField;
use crate::core::backend::avx512::VECS_LOG_SIZE;
use std::array;
use std::simd::u32x8;

use super::m31::LOG_N_LANES;
use super::SimdBackend;
use crate::core::backend::simd::fft::compute_first_twiddles;
use crate::core::backend::simd::fft::ifft::ibutterfly;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fri::{self, FriOps};
Expand All @@ -11,14 +14,14 @@ use crate::core::poly::line::LineEvaluation;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::domain_line_twiddles_from_tree;

impl FriOps for AVX512Backend {
impl FriOps for SimdBackend {
fn fold_line(
eval: &LineEvaluation<Self>,
alpha: SecureField,
twiddles: &TwiddleTree<Self>,
) -> LineEvaluation<Self> {
let log_size = eval.len().ilog2();
if log_size <= VECS_LOG_SIZE as u32 {
if log_size <= LOG_N_LANES {
let eval = fri::fold_line(&eval.to_cpu(), alpha);
return LineEvaluation::new(eval.domain(), eval.values.into_iter().collect());
}
Expand All @@ -28,18 +31,18 @@ impl FriOps for AVX512Backend {

let mut folded_values = SecureColumn::<Self>::zeros(1 << (log_size - 1));

for vec_index in 0..(1 << (log_size - 1 - VECS_LOG_SIZE as u32)) {
for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) {
let value = unsafe {
let twiddle_dbl: [i32; 16] =
std::array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i));
let val0 = eval.values.packed_at(vec_index * 2).to_packed_m31s();
let val1 = eval.values.packed_at(vec_index * 2 + 1).to_packed_m31s();
let pairs: [_; 4] = std::array::from_fn(|i| {
let (a, b) = val0[i].deinterleave_with(val1[i]);
avx_ibutterfly(a, b, std::mem::transmute(twiddle_dbl))
let twiddle_dbl: [u32; 16] =
array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i));
let val0 = eval.values.packed_at(vec_index * 2).into_packed_m31s();
let val1 = eval.values.packed_at(vec_index * 2 + 1).into_packed_m31s();
let pairs: [_; 4] = array::from_fn(|i| {
let (a, b) = val0[i].deinterleave(val1[i]);
ibutterfly(a, b, std::mem::transmute(twiddle_dbl))
});
let val0 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].0));
let val1 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].1));
let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0));
let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1));
val0 + PackedSecureField::broadcast(alpha) * val1
};
unsafe { folded_values.set_packed(vec_index, value) };
Expand All @@ -55,27 +58,28 @@ impl FriOps for AVX512Backend {
twiddles: &TwiddleTree<Self>,
) {
let log_size = src.len().ilog2();
assert!(log_size > VECS_LOG_SIZE as u32, "Evaluation too small");
assert!(log_size > LOG_N_LANES, "Evaluation too small");

let domain = src.domain;
let alpha_sq = alpha * alpha;
let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];

for vec_index in 0..(1 << (log_size - 1 - VECS_LOG_SIZE as u32)) {
for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) {
let value = unsafe {
// The 16 twiddles of the circle domain can be derived from the 8 twiddles of the
// next line domain. See `compute_first_twiddles()`.
let twiddle_dbl: [i32; 8] =
std::array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 8 + i));
let twiddle_dbl = u32x8::from_array(array::from_fn(|i| {
*itwiddles.get_unchecked(vec_index * 8 + i)
}));
let (t0, _) = compute_first_twiddles(twiddle_dbl);
let val0 = src.values.packed_at(vec_index * 2).to_packed_m31s();
let val1 = src.values.packed_at(vec_index * 2 + 1).to_packed_m31s();
let pairs: [_; 4] = std::array::from_fn(|i| {
let (a, b) = val0[i].deinterleave_with(val1[i]);
avx_ibutterfly(a, b, t0)
let val0 = src.values.packed_at(vec_index * 2).into_packed_m31s();
let val1 = src.values.packed_at(vec_index * 2 + 1).into_packed_m31s();
let pairs: [_; 4] = array::from_fn(|i| {
let (a, b) = val0[i].deinterleave(val1[i]);
ibutterfly(a, b, t0)
});
let val0 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].0));
let val1 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].1));
let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0));
let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1));
val0 + PackedSecureField::broadcast(alpha) * val1
};
unsafe {
Expand All @@ -89,10 +93,13 @@ impl FriOps for AVX512Backend {
}
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
use crate::core::backend::avx512::AVX512Backend;
use itertools::Itertools;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use crate::core::backend::simd::SimdBackend;
use crate::core::backend::CPUBackend;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
Expand All @@ -104,9 +111,8 @@ mod tests {
#[test]
fn test_fold_line() {
const LOG_SIZE: u32 = 7;
let values: Vec<SecureField> = (0..(1 << LOG_SIZE))
.map(|i| qm31!(4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3))
.collect();
let mut rng = SmallRng::seed_from_u64(0);
let values = (0..1 << LOG_SIZE).map(|_| rng.gen()).collect_vec();
let alpha = qm31!(1, 3, 5, 7);
let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset());
let cpu_fold = CPUBackend::fold_line(
Expand All @@ -115,10 +121,10 @@ mod tests {
&CPUBackend::precompute_twiddles(domain.coset()),
);

let avx_fold = AVX512Backend::fold_line(
let avx_fold = SimdBackend::fold_line(
&LineEvaluation::new(domain, values.iter().copied().collect()),
alpha,
&AVX512Backend::precompute_twiddles(domain.coset()),
&SimdBackend::precompute_twiddles(domain.coset()),
);

assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec());
Expand All @@ -133,7 +139,6 @@ mod tests {
let alpha = qm31!(1, 3, 5, 7);
let circle_domain = CanonicCoset::new(LOG_SIZE).circle_domain();
let line_domain = LineDomain::new(circle_domain.half_coset);

let mut cpu_fold =
LineEvaluation::new(line_domain, SecureColumn::zeros(1 << (LOG_SIZE - 1)));
CPUBackend::fold_circle_into_line(
Expand All @@ -146,18 +151,18 @@ mod tests {
&CPUBackend::precompute_twiddles(line_domain.coset()),
);

let mut avx_fold =
let mut simd_fold =
LineEvaluation::new(line_domain, SecureColumn::zeros(1 << (LOG_SIZE - 1)));
AVX512Backend::fold_circle_into_line(
&mut avx_fold,
SimdBackend::fold_circle_into_line(
&mut simd_fold,
&SecureEvaluation {
domain: circle_domain,
values: values.iter().copied().collect(),
},
alpha,
&AVX512Backend::precompute_twiddles(line_domain.coset()),
&SimdBackend::precompute_twiddles(line_domain.coset()),
);

assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec());
assert_eq!(cpu_fold.values.to_vec(), simd_fold.values.to_vec());
}
}
4 changes: 2 additions & 2 deletions crates/prover/src/core/backend/simd/qm31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ impl PackedQM31 {

/// Returns vectors `a, b, c, d` such that element `i` is represented as
/// `QM31(a_i, b_i, c_i, d_i)`.
pub fn into_packed_m31s(self) -> [PackedBaseField; 4] {
pub fn into_packed_m31s(self) -> [PackedM31; 4] {
let Self([PackedCM31([a, b]), PackedCM31([c, d])]) = self;
[a, b, c, d]
}

/// Creates an instance from vectors `a, b, c, d` such that element `i`
/// is represented as `QM31(a_i, b_i, c_i, d_i)`.
pub fn from_packed_m31s([a, b, c, d]: [PackedBaseField; 4]) -> Self {
pub fn from_packed_m31s([a, b, c, d]: [PackedM31; 4]) -> Self {
Self([PackedCM31([a, b]), PackedCM31([c, d])])
}
}
Expand Down

0 comments on commit 82c5535

Please sign in to comment.