Skip to content

Commit

Permalink
Fix compilation issues
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed May 16, 2024
1 parent 91257ed commit efb39ee
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 45 deletions.
189 changes: 146 additions & 43 deletions crates/prover/src/core/backend/simd/fri.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
use super::AVX512Backend;
use crate::core::backend::avx512::fft::compute_first_twiddles;
use crate::core::backend::avx512::fft::ifft::avx_ibutterfly;
use crate::core::backend::avx512::qm31::PackedSecureField;
use crate::core::backend::avx512::VECS_LOG_SIZE;
use std::array;
use std::simd::u32x8;

use num_traits::Zero;

use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use super::SimdBackend;
use crate::core::backend::simd::fft::compute_first_twiddles;
use crate::core::backend::simd::fft::ifft::simd_ibutterfly;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fri::{self, FriOps};
Expand All @@ -11,14 +18,14 @@ use crate::core::poly::line::LineEvaluation;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::domain_line_twiddles_from_tree;

impl FriOps for AVX512Backend {
impl FriOps for SimdBackend {
fn fold_line(
eval: &LineEvaluation<Self>,
alpha: SecureField,
twiddles: &TwiddleTree<Self>,
) -> LineEvaluation<Self> {
let log_size = eval.len().ilog2();
if log_size <= VECS_LOG_SIZE as u32 {
if log_size <= LOG_N_LANES {
let eval = fri::fold_line(&eval.to_cpu(), alpha);
return LineEvaluation::new(eval.domain(), eval.values.into_iter().collect());
}
Expand All @@ -28,18 +35,18 @@ impl FriOps for AVX512Backend {

let mut folded_values = SecureColumn::<Self>::zeros(1 << (log_size - 1));

for vec_index in 0..(1 << (log_size - 1 - VECS_LOG_SIZE as u32)) {
for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) {
let value = unsafe {
let twiddle_dbl: [i32; 16] =
std::array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i));
let val0 = eval.values.packed_at(vec_index * 2).to_packed_m31s();
let val1 = eval.values.packed_at(vec_index * 2 + 1).to_packed_m31s();
let pairs: [_; 4] = std::array::from_fn(|i| {
let (a, b) = val0[i].deinterleave_with(val1[i]);
avx_ibutterfly(a, b, std::mem::transmute(twiddle_dbl))
let twiddle_dbl: [u32; 16] =
array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i));
let val0 = eval.values.packed_at(vec_index * 2).into_packed_m31s();
let val1 = eval.values.packed_at(vec_index * 2 + 1).into_packed_m31s();
let pairs: [_; 4] = array::from_fn(|i| {
let (a, b) = val0[i].deinterleave(val1[i]);
simd_ibutterfly(a, b, std::mem::transmute(twiddle_dbl))
});
let val0 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].0));
let val1 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].1));
let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0));
let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1));
val0 + PackedSecureField::broadcast(alpha) * val1
};
unsafe { folded_values.set_packed(vec_index, value) };
Expand All @@ -55,27 +62,28 @@ impl FriOps for AVX512Backend {
twiddles: &TwiddleTree<Self>,
) {
let log_size = src.len().ilog2();
assert!(log_size > VECS_LOG_SIZE as u32, "Evaluation too small");
assert!(log_size > LOG_N_LANES, "Evaluation too small");

let domain = src.domain;
let alpha_sq = alpha * alpha;
let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];

for vec_index in 0..(1 << (log_size - 1 - VECS_LOG_SIZE as u32)) {
for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) {
let value = unsafe {
// The 16 twiddles of the circle domain can be derived from the 8 twiddles of the
// next line domain. See `compute_first_twiddles()`.
let twiddle_dbl: [i32; 8] =
std::array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 8 + i));
let twiddle_dbl = u32x8::from_array(array::from_fn(|i| {
*itwiddles.get_unchecked(vec_index * 8 + i)
}));
let (t0, _) = compute_first_twiddles(twiddle_dbl);
let val0 = src.values.packed_at(vec_index * 2).to_packed_m31s();
let val1 = src.values.packed_at(vec_index * 2 + 1).to_packed_m31s();
let pairs: [_; 4] = std::array::from_fn(|i| {
let (a, b) = val0[i].deinterleave_with(val1[i]);
avx_ibutterfly(a, b, t0)
let val0 = src.values.packed_at(vec_index * 2).into_packed_m31s();
let val1 = src.values.packed_at(vec_index * 2 + 1).into_packed_m31s();
let pairs: [_; 4] = array::from_fn(|i| {
let (a, b) = val0[i].deinterleave(val1[i]);
simd_ibutterfly(a, b, t0)
});
let val0 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].0));
let val1 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].1));
let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0));
let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1));
val0 + PackedSecureField::broadcast(alpha) * val1
};
unsafe {
Expand All @@ -87,26 +95,85 @@ impl FriOps for AVX512Backend {
};
}
}

fn decompose(eval: &SecureEvaluation<Self>) -> (SecureEvaluation<Self>, SecureField) {
let lambda = decomposition_coefficient(eval);
let broadcasted_lambda = PackedSecureField::broadcast(lambda);
let mut g_values = SecureColumn::zeros(eval.len());

let range = eval.len().div_ceil(N_LANES);
let half_range = range / 2;
for i in 0..half_range {
let val = unsafe { eval.packed_at(i) } - broadcasted_lambda;
unsafe { g_values.set_packed(i, val) }
}
for i in half_range..range {
let val = unsafe { eval.packed_at(i) } + broadcasted_lambda;
unsafe { g_values.set_packed(i, val) }
}

let g = SecureEvaluation {
domain: eval.domain,
values: g_values,
};
(g, lambda)
}
}

/// See [`decomposition_coefficient`].
///
/// [`decomposition_coefficient`]: crate::core::backend::cpu::CpuBackend::decomposition_coefficient
fn decomposition_coefficient(eval: &SecureEvaluation<SimdBackend>) -> SecureField {
let cols = &eval.values.columns;
let [mut x_sum, mut y_sum, mut z_sum, mut w_sum] = [PackedBaseField::zero(); 4];

let range = cols[0].len() / N_LANES;
let (half_a, half_b) = (range / 2, range);

for i in 0..half_a {
x_sum += cols[0].data[i];
y_sum += cols[1].data[i];
z_sum += cols[2].data[i];
w_sum += cols[3].data[i];
}
for i in half_a..half_b {
x_sum -= cols[0].data[i];
y_sum -= cols[1].data[i];
z_sum -= cols[2].data[i];
w_sum -= cols[3].data[i];
}

let x = x_sum.pointwise_sum();
let y = y_sum.pointwise_sum();
let z = z_sum.pointwise_sum();
let w = w_sum.pointwise_sum();

SecureField::from_m31(x, y, z, w) / BaseField::from_u32_unchecked(1 << eval.domain.log_size())
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
use crate::core::backend::avx512::AVX512Backend;
use crate::core::backend::CPUBackend;
use itertools::Itertools;
use num_traits::One;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use crate::core::backend::simd::column::BaseFieldVec;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{CPUBackend, Column};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fri::FriOps;
use crate::core::poly::circle::{CanonicCoset, PolyOps, SecureEvaluation};
use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps, SecureEvaluation};
use crate::core::poly::line::{LineDomain, LineEvaluation};
use crate::qm31;

#[test]
fn test_fold_line() {
const LOG_SIZE: u32 = 7;
let values: Vec<SecureField> = (0..(1 << LOG_SIZE))
.map(|i| qm31!(4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3))
.collect();
let mut rng = SmallRng::seed_from_u64(0);
let values = (0..1 << LOG_SIZE).map(|_| rng.gen()).collect_vec();
let alpha = qm31!(1, 3, 5, 7);
let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset());
let cpu_fold = CPUBackend::fold_line(
Expand All @@ -115,10 +182,10 @@ mod tests {
&CPUBackend::precompute_twiddles(domain.coset()),
);

let avx_fold = AVX512Backend::fold_line(
let avx_fold = SimdBackend::fold_line(
&LineEvaluation::new(domain, values.iter().copied().collect()),
alpha,
&AVX512Backend::precompute_twiddles(domain.coset()),
&SimdBackend::precompute_twiddles(domain.coset()),
);

assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec());
Expand All @@ -133,7 +200,6 @@ mod tests {
let alpha = qm31!(1, 3, 5, 7);
let circle_domain = CanonicCoset::new(LOG_SIZE).circle_domain();
let line_domain = LineDomain::new(circle_domain.half_coset);

let mut cpu_fold =
LineEvaluation::new(line_domain, SecureColumn::zeros(1 << (LOG_SIZE - 1)));
CPUBackend::fold_circle_into_line(
Expand All @@ -146,18 +212,55 @@ mod tests {
&CPUBackend::precompute_twiddles(line_domain.coset()),
);

let mut avx_fold =
let mut simd_fold =
LineEvaluation::new(line_domain, SecureColumn::zeros(1 << (LOG_SIZE - 1)));
AVX512Backend::fold_circle_into_line(
&mut avx_fold,
SimdBackend::fold_circle_into_line(
&mut simd_fold,
&SecureEvaluation {
domain: circle_domain,
values: values.iter().copied().collect(),
},
alpha,
&AVX512Backend::precompute_twiddles(line_domain.coset()),
&SimdBackend::precompute_twiddles(line_domain.coset()),
);

assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec());
assert_eq!(cpu_fold.values.to_vec(), simd_fold.values.to_vec());
}

#[test]
fn decomposition_test() {
const DOMAIN_LOG_SIZE: u32 = 5;
const DOMAIN_LOG_HALF_SIZE: u32 = DOMAIN_LOG_SIZE - 1;
let s = CanonicCoset::new(DOMAIN_LOG_SIZE);
let domain = s.circle_domain();
let mut coeffs = BaseFieldVec::zeros(1 << DOMAIN_LOG_SIZE);
// Polynomial is out of FFT space.
coeffs.as_mut_slice()[1 << DOMAIN_LOG_HALF_SIZE] = BaseField::one();
let poly = CirclePoly::<SimdBackend>::new(coeffs);
let values = poly.evaluate(domain);
let avx_column = SecureColumn::<SimdBackend> {
columns: [
values.values.clone(),
values.values.clone(),
values.values.clone(),
values.values.clone(),
],
};
let avx_eval = SecureEvaluation {
domain,
values: avx_column.clone(),
};
let cpu_eval = SecureEvaluation::<CPUBackend> {
domain,
values: avx_eval.to_cpu(),
};
let (cpu_g, cpu_lambda) = CPUBackend::decompose(&cpu_eval);

let (avx_g, avx_lambda) = SimdBackend::decompose(&avx_eval);

assert_eq!(avx_lambda, cpu_lambda);
for i in 0..1 << DOMAIN_LOG_SIZE {
assert_eq!(avx_g.values.at(i), cpu_g.values.at(i));
}
}
}
4 changes: 2 additions & 2 deletions crates/prover/src/core/backend/simd/qm31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ impl PackedQM31 {

/// Returns vectors `a, b, c, d` such that element `i` is represented as
/// `QM31(a_i, b_i, c_i, d_i)`.
pub fn into_packed_m31s(self) -> [PackedBaseField; 4] {
pub fn into_packed_m31s(self) -> [PackedM31; 4] {
let Self([PackedCM31([a, b]), PackedCM31([c, d])]) = self;
[a, b, c, d]
}

/// Creates an instance from vectors `a, b, c, d` such that element `i`
/// is represented as `QM31(a_i, b_i, c_i, d_i)`.
pub fn from_packed_m31s([a, b, c, d]: [PackedBaseField; 4]) -> Self {
pub fn from_packed_m31s([a, b, c, d]: [PackedM31; 4]) -> Self {
Self([PackedCM31([a, b]), PackedCM31([c, d])])
}
}
Expand Down

0 comments on commit efb39ee

Please sign in to comment.