Skip to content

Commit

Permalink
Copy FriOps from AVX backend
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed May 16, 2024
1 parent 37f3ef0 commit 91257ed
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 5 deletions.
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/avx512/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl FriOps for AVX512Backend {
let domain = eval.domain();
let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];

let mut folded_values = SecureColumn::zeros(1 << (log_size - 1));
let mut folded_values = SecureColumn::<Self>::zeros(1 << (log_size - 1));

for vec_index in 0..(1 << (log_size - 1 - VECS_LOG_SIZE as u32)) {
let value = unsafe {
Expand Down
56 changes: 53 additions & 3 deletions crates/prover/src/core/backend/simd/column.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use std::mem;

use bytemuck::{cast_slice, cast_slice_mut, Zeroable};
use itertools::Itertools;
use itertools::{izip, Itertools};
use num_traits::Zero;

use super::cm31::PackedCM31;
use super::m31::{PackedBaseField, N_LANES};
use super::qm31::PackedSecureField;
use super::qm31::{PackedQM31, PackedSecureField};
use super::SimdBackend;
use crate::core::backend::Column;
use crate::core::backend::{CPUBackend, Column};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fields::{FieldExpOps, FieldOps};

impl FieldOps<BaseField> for SimdBackend {
Expand Down Expand Up @@ -154,6 +156,54 @@ impl FromIterator<PackedSecureField> for SecureFieldVec {
}
}

impl SecureColumn<SimdBackend> {
    /// Gathers the packed words stored at `vec_index` in each of the four coordinate
    /// columns and assembles them into a single packed secure-field value.
    ///
    /// # Safety
    ///
    /// `vec_index` must be a valid index: the reads from every column's `data` are
    /// unchecked.
    pub unsafe fn packed_at(&self, vec_index: usize) -> PackedSecureField {
        let a = *self.columns[0].data.get_unchecked(vec_index);
        let b = *self.columns[1].data.get_unchecked(vec_index);
        let c = *self.columns[2].data.get_unchecked(vec_index);
        let d = *self.columns[3].data.get_unchecked(vec_index);
        PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])])
    }

    /// Splits `value` into its four packed coordinates and stores each one at `vec_index`
    /// of the matching coordinate column.
    ///
    /// # Safety
    ///
    /// `vec_index` must be a valid index: the writes into every column's `data` are
    /// unchecked.
    pub unsafe fn set_packed(&mut self, vec_index: usize, value: PackedSecureField) {
        let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value;
        // Coordinate `k` of the value goes to column `k`.
        for (column, coordinate) in self.columns.iter_mut().zip([a, b, c, d]) {
            *column.data.get_unchecked_mut(vec_index) = coordinate;
        }
    }

    /// Copies the column to the CPU and returns its contents as a flat vector of
    /// `SecureField` elements, one per row.
    pub fn to_vec(&self) -> Vec<SecureField> {
        let [a_col, b_col, c_col, d_col] = [0, 1, 2, 3].map(|k| self.columns[k].to_cpu());
        izip!(a_col, b_col, c_col, d_col)
            .map(|(a, b, c, d)| SecureField::from_m31_array([a, b, c, d]))
            .collect()
    }
}

impl FromIterator<SecureField> for SecureColumn<SimdBackend> {
    /// Builds the column by first collecting the elements into a `CPUBackend` secure
    /// column, then re-collecting each of its coordinate columns into this backend's
    /// column type.
    fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
        let SecureColumn {
            columns: cpu_columns,
        } = iter.into_iter().collect::<SecureColumn<CPUBackend>>();
        Self {
            columns: cpu_columns.map(|coordinate| coordinate.into_iter().collect()),
        }
    }
}

#[cfg(test)]
mod tests {
use std::array;
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/simd/fft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) {
/// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers.
///
/// Returns the twiddles for the first layer and the twiddles for the second layer.
fn compute_first_twiddles(twiddle1_dbl: u32x8) -> (u32x16, u32x16) {
pub fn compute_first_twiddles(twiddle1_dbl: u32x8) -> (u32x16, u32x16) {
// Start by loading the twiddles for the second layer (layer 1):
let t1 = simd_swizzle!(
twiddle1_dbl,
Expand Down
163 changes: 163 additions & 0 deletions crates/prover/src/core/backend/simd/fri.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
use super::AVX512Backend;
use crate::core::backend::avx512::fft::compute_first_twiddles;
use crate::core::backend::avx512::fft::ifft::avx_ibutterfly;
use crate::core::backend::avx512::qm31::PackedSecureField;
use crate::core::backend::avx512::VECS_LOG_SIZE;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fri::{self, FriOps};
use crate::core::poly::circle::SecureEvaluation;
use crate::core::poly::line::LineEvaluation;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::domain_line_twiddles_from_tree;

// Vectorised FRI folding routines for `AVX512Backend`.
//
// Both methods process two packed input vectors per loop iteration and produce one packed
// output vector, so the output has half as many values as the input.
//
// NOTE(review): the unchecked accesses below rely on the twiddle slice and the value
// columns being long enough for every `vec_index` visited — nothing here checks that
// explicitly; confirm against `domain_line_twiddles_from_tree` and the callers.
impl FriOps for AVX512Backend {
/// Folds the line evaluation `eval` with the folding factor `alpha`.
///
/// When `log2(eval.len()) <= VECS_LOG_SIZE` the work is delegated to the CPU
/// implementation `fri::fold_line` and the result is collected back into this backend's
/// column type. Otherwise the vectorised path below produces `1 << (log_size - 1)`
/// values on `eval.domain().double()`.
fn fold_line(
eval: &LineEvaluation<Self>,
alpha: SecureField,
twiddles: &TwiddleTree<Self>,
) -> LineEvaluation<Self> {
let log_size = eval.len().ilog2();
// Small inputs: fall back to the scalar implementation.
if log_size <= VECS_LOG_SIZE as u32 {
let eval = fri::fold_line(&eval.to_cpu(), alpha);
return LineEvaluation::new(eval.domain(), eval.values.into_iter().collect());
}

let domain = eval.domain();
// Only the first entry returned for this domain is used.
let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];

let mut folded_values = SecureColumn::<Self>::zeros(1 << (log_size - 1));

// One iteration per packed output vector.
for vec_index in 0..(1 << (log_size - 1 - VECS_LOG_SIZE as u32)) {
let value = unsafe {
// Load the 16 twiddles belonging to this output vector (unchecked reads).
let twiddle_dbl: [i32; 16] =
std::array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i));
// Two consecutive packed inputs, each split into its four base-field coordinates.
let val0 = eval.values.packed_at(vec_index * 2).to_packed_m31s();
let val1 = eval.values.packed_at(vec_index * 2 + 1).to_packed_m31s();
// Per coordinate: regroup the lanes of the two inputs, then run the butterfly.
// NOTE(review): presumably `deinterleave_with` separates even/odd positions so
// each butterfly pairs neighbouring evaluations — confirm.
let pairs: [_; 4] = std::array::from_fn(|i| {
let (a, b) = val0[i].deinterleave_with(val1[i]);
// The `[i32; 16]` array is reinterpreted as the twiddle type `avx_ibutterfly`
// expects.
avx_ibutterfly(a, b, std::mem::transmute(twiddle_dbl))
});
// Reassemble the first and second butterfly outputs into packed secure-field values.
let val0 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].0));
let val1 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].1));
// Folded value: first output plus `alpha` times the second output.
val0 + PackedSecureField::broadcast(alpha) * val1
};
unsafe { folded_values.set_packed(vec_index, value) };
}

LineEvaluation::new(domain.double(), folded_values)
}

/// Folds the circle evaluation `src` with `alpha` and accumulates the result into `dst`.
///
/// For every packed vector, `dst` is updated in place to `dst * alpha^2 + folded`, where
/// `folded` is computed the same way as in `fold_line` but with twiddles derived through
/// `compute_first_twiddles`.
///
/// # Panics
///
/// Panics with "Evaluation too small" unless `log2(src.len()) > VECS_LOG_SIZE`.
fn fold_circle_into_line(
dst: &mut LineEvaluation<Self>,
src: &SecureEvaluation<Self>,
alpha: SecureField,
twiddles: &TwiddleTree<Self>,
) {
let log_size = src.len().ilog2();
assert!(log_size > VECS_LOG_SIZE as u32, "Evaluation too small");

let domain = src.domain;
let alpha_sq = alpha * alpha;
// Only the first entry returned for this domain is used.
let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];

// One iteration per packed output vector.
for vec_index in 0..(1 << (log_size - 1 - VECS_LOG_SIZE as u32)) {
let value = unsafe {
// The 16 twiddles of the circle domain can be derived from the 8 twiddles of the
// next line domain. See `compute_first_twiddles()`.
let twiddle_dbl: [i32; 8] =
std::array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 8 + i));
// Only the first of the two returned twiddle vectors is needed here.
let (t0, _) = compute_first_twiddles(twiddle_dbl);
// Two consecutive packed inputs, each split into its four base-field coordinates.
let val0 = src.values.packed_at(vec_index * 2).to_packed_m31s();
let val1 = src.values.packed_at(vec_index * 2 + 1).to_packed_m31s();
// Per coordinate: regroup the lanes of the two inputs, then run the butterfly.
let pairs: [_; 4] = std::array::from_fn(|i| {
let (a, b) = val0[i].deinterleave_with(val1[i]);
avx_ibutterfly(a, b, t0)
});
// Reassemble the first and second butterfly outputs into packed secure-field values.
let val0 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].0));
let val1 = PackedSecureField::from_packed_m31s(std::array::from_fn(|i| pairs[i].1));
// Folded value: first output plus `alpha` times the second output.
val0 + PackedSecureField::broadcast(alpha) * val1
};
// Accumulate: scale what is already stored in `dst` by `alpha^2`, then add the new
// folded value.
unsafe {
dst.values.set_packed(
vec_index,
dst.values.packed_at(vec_index) * PackedSecureField::broadcast(alpha_sq)
+ value,
)
};
}
}
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
    use crate::core::backend::avx512::AVX512Backend;
    use crate::core::backend::CPUBackend;
    use crate::core::fields::qm31::SecureField;
    use crate::core::fields::secure_column::SecureColumn;
    use crate::core::fri::FriOps;
    use crate::core::poly::circle::{CanonicCoset, PolyOps, SecureEvaluation};
    use crate::core::poly::line::{LineDomain, LineEvaluation};
    use crate::qm31;

    /// The vectorised line fold must agree with the CPU reference on the same input.
    #[test]
    fn test_fold_line() {
        const LOG_SIZE: u32 = 7;
        let alpha = qm31!(1, 3, 5, 7);
        let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset());
        let values: Vec<SecureField> = (0..(1 << LOG_SIZE))
            .map(|i| qm31!(4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3))
            .collect();

        // Reference result from the CPU backend.
        let cpu_eval: LineEvaluation<CPUBackend> =
            LineEvaluation::new(domain, values.iter().copied().collect());
        let cpu_twiddles = CPUBackend::precompute_twiddles(domain.coset());
        let cpu_fold = CPUBackend::fold_line(&cpu_eval, alpha, &cpu_twiddles);

        // Result under test from the AVX512 backend.
        let avx_eval: LineEvaluation<AVX512Backend> =
            LineEvaluation::new(domain, values.iter().copied().collect());
        let avx_twiddles = AVX512Backend::precompute_twiddles(domain.coset());
        let avx_fold = AVX512Backend::fold_line(&avx_eval, alpha, &avx_twiddles);

        assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec());
    }

    /// The vectorised circle-to-line fold must agree with the CPU reference on the same
    /// input, starting from a zeroed destination.
    #[test]
    fn test_fold_circle_into_line() {
        const LOG_SIZE: u32 = 7;
        let alpha = qm31!(1, 3, 5, 7);
        let values: Vec<SecureField> = (0..(1 << LOG_SIZE))
            .map(|i| qm31!(4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3))
            .collect();
        let circle_domain = CanonicCoset::new(LOG_SIZE).circle_domain();
        let line_domain = LineDomain::new(circle_domain.half_coset);

        // Reference result from the CPU backend.
        let cpu_src: SecureEvaluation<CPUBackend> = SecureEvaluation {
            domain: circle_domain,
            values: values.iter().copied().collect(),
        };
        let cpu_twiddles = CPUBackend::precompute_twiddles(line_domain.coset());
        let mut cpu_fold =
            LineEvaluation::new(line_domain, SecureColumn::zeros(1 << (LOG_SIZE - 1)));
        CPUBackend::fold_circle_into_line(&mut cpu_fold, &cpu_src, alpha, &cpu_twiddles);

        // Result under test from the AVX512 backend.
        let avx_src: SecureEvaluation<AVX512Backend> = SecureEvaluation {
            domain: circle_domain,
            values: values.iter().copied().collect(),
        };
        let avx_twiddles = AVX512Backend::precompute_twiddles(line_domain.coset());
        let mut avx_fold =
            LineEvaluation::new(line_domain, SecureColumn::zeros(1 << (LOG_SIZE - 1)));
        AVX512Backend::fold_circle_into_line(&mut avx_fold, &avx_src, alpha, &avx_twiddles);

        assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec());
    }
}
1 change: 1 addition & 0 deletions crates/prover/src/core/backend/simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod circle;
pub mod cm31;
pub mod column;
pub mod fft;
pub mod fri;
pub mod m31;
pub mod qm31;
mod utils;
Expand Down
13 changes: 13 additions & 0 deletions crates/prover/src/core/backend/simd/qm31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,19 @@ impl PackedQM31 {
let Self([a, b]) = self;
Self([a.double(), b.double()])
}

/// Decomposes the packed value into its four coordinate vectors `[a, b, c, d]`, where
/// element `i` of the packed value is `QM31(a_i, b_i, c_i, d_i)`.
pub fn into_packed_m31s(self) -> [PackedBaseField; 4] {
    let [PackedCM31([a, b]), PackedCM31([c, d])] = self.0;
    [a, b, c, d]
}

/// Assembles a packed value from the coordinate vectors `[a, b, c, d]`, so that element
/// `i` of the result is `QM31(a_i, b_i, c_i, d_i)`.
pub fn from_packed_m31s(coordinates: [PackedBaseField; 4]) -> Self {
    let [a, b, c, d] = coordinates;
    let first = PackedCM31([a, b]);
    let second = PackedCM31([c, d]);
    Self([first, second])
}
}

impl Add for PackedQM31 {
Expand Down

0 comments on commit 91257ed

Please sign in to comment.