Skip to content

Commit

Permalink
Implement FieldOps for SIMD backend
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed May 2, 2024
1 parent 7ddb0a1 commit 6d71f67
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 50 deletions.
159 changes: 159 additions & 0 deletions crates/prover/src/core/backend/simd/column.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
use bytemuck::{cast_slice, cast_slice_mut, Zeroable};
use itertools::Itertools;
use num_traits::Zero;

use super::m31::{PackedBaseField, N_LANES};
use super::qm31::PackedSecureField;
use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;

/// A column of [`BaseField`] elements stored as packed SIMD vectors.
///
/// The logical element count is `length`; `data` may hold more lanes than that because
/// storage is allocated in whole [`PackedBaseField`] vectors of `N_LANES` elements each.
#[derive(Clone, Debug)]
pub struct BaseFieldVec {
/// Packed storage. The constructors in this file (`zeros`, `from_iter`) allocate
/// `length.div_ceil(N_LANES)` vectors and zero-fill any unused trailing lanes.
pub data: Vec<PackedBaseField>,
/// Number of logical elements; the slice views (`as_ref` / `as_mut`) expose only the
/// first `length` lanes of `data`.
pub length: usize,
}

impl AsRef<[BaseField]> for BaseFieldVec {
    /// Views the packed storage as a flat slice of scalars, truncated to the first
    /// `self.length` elements so that trailing padding lanes are not exposed.
    fn as_ref(&self) -> &[BaseField] {
        let all_lanes: &[BaseField] = cast_slice(&self.data);
        &all_lanes[..self.length]
    }
}

impl AsMut<[BaseField]> for BaseFieldVec {
    /// Mutable counterpart of `as_ref`: a flat scalar view over the packed storage,
    /// truncated to the first `self.length` elements.
    fn as_mut(&mut self) -> &mut [BaseField] {
        let n_elements = self.length;
        let all_lanes: &mut [BaseField] = cast_slice_mut(&mut self.data);
        &mut all_lanes[..n_elements]
    }
}

impl Column<BaseField> for BaseFieldVec {
    /// Creates a column of `length` zero elements, rounding the allocation up to a whole
    /// number of packed vectors.
    fn zeros(length: usize) -> Self {
        let n_packed = length.div_ceil(N_LANES);
        Self {
            data: vec![PackedBaseField::zeroed(); n_packed],
            length,
        }
    }

    /// Copies the first `self.length` elements into a plain `Vec`.
    fn to_cpu(&self) -> Vec<BaseField> {
        let scalars: &[BaseField] = self.as_ref();
        scalars.to_vec()
    }

    /// Returns the number of logical elements in the column.
    fn len(&self) -> usize {
        self.length
    }

    /// Returns the element at `index`.
    ///
    /// The lookup indexes the packed storage directly and is not checked against
    /// `self.length`; it panics only if `index / N_LANES` is outside `self.data`.
    fn at(&self, index: usize) -> BaseField {
        let (vec_index, lane_index) = (index / N_LANES, index % N_LANES);
        let lanes = self.data[vec_index].to_array();
        lanes[lane_index]
    }
}

impl FromIterator<BaseField> for BaseFieldVec {
    /// Collects scalars into packed vectors of `N_LANES` elements each.
    ///
    /// A trailing partial group is zero-padded to a full packed vector, while `length`
    /// records the exact number of items yielded by `iter`.
    fn from_iter<I: IntoIterator<Item = BaseField>>(iter: I) -> Self {
        let mut full_chunks = iter.into_iter().array_chunks();
        // Borrow the adaptor so the leftover items can still be retrieved afterwards.
        let mut data = full_chunks
            .by_ref()
            .map(PackedBaseField::from_array)
            .collect_vec();
        let mut length = data.len() * N_LANES;

        if let Some(remainder) = full_chunks.into_remainder() {
            let tail = remainder.as_slice();
            if !tail.is_empty() {
                let mut padded = [BaseField::zero(); N_LANES];
                padded[..tail.len()].copy_from_slice(tail);
                length += tail.len();
                data.push(PackedBaseField::from_array(padded));
            }
        }

        Self { data, length }
    }
}

/// A column of [`SecureField`] elements stored as packed SIMD vectors.
///
/// The logical element count is `length`; `data` may hold more lanes than that because
/// storage is allocated in whole [`PackedSecureField`] vectors of `N_LANES` elements each.
#[derive(Clone, Debug)]
pub struct SecureFieldVec {
/// Packed storage. `zeros` and the scalar `from_iter` in this file allocate
/// `length.div_ceil(N_LANES)` vectors and zero-fill any unused trailing lanes.
pub data: Vec<PackedSecureField>,
/// Number of logical elements; `to_cpu` returns only the first `length` lanes of `data`.
pub length: usize,
}

impl Column<SecureField> for SecureFieldVec {
    /// Creates a column of `length` zero elements, rounding the allocation up to a whole
    /// number of packed vectors.
    fn zeros(length: usize) -> Self {
        let n_packed = length.div_ceil(N_LANES);
        let data = vec![PackedSecureField::zeroed(); n_packed];
        Self { data, length }
    }

    /// Unpacks the storage lane by lane and returns the first `self.length` elements.
    fn to_cpu(&self) -> Vec<SecureField> {
        let all_lanes = self.data.iter().flat_map(|packed| packed.to_array());
        all_lanes.take(self.length).collect()
    }

    /// Returns the number of logical elements in the column.
    fn len(&self) -> usize {
        self.length
    }

    /// Returns the element at `index`.
    ///
    /// The lookup indexes the packed storage directly and is not checked against
    /// `self.length`; it panics only if `index / N_LANES` is outside `self.data`.
    fn at(&self, index: usize) -> SecureField {
        let (vec_index, lane_index) = (index / N_LANES, index % N_LANES);
        let lanes = self.data[vec_index].to_array();
        lanes[lane_index]
    }
}

impl FromIterator<SecureField> for SecureFieldVec {
    /// Collects scalars into packed vectors of `N_LANES` elements each.
    ///
    /// A trailing partial group is zero-padded to a full packed vector, while `length`
    /// records the exact number of items yielded by `iter`.
    fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
        let mut full_chunks = iter.into_iter().array_chunks();
        // Borrow the adaptor so the leftover items can still be retrieved afterwards.
        let mut data = full_chunks
            .by_ref()
            .map(PackedSecureField::from_array)
            .collect_vec();
        let mut length = data.len() * N_LANES;

        if let Some(remainder) = full_chunks.into_remainder() {
            let tail = remainder.as_slice();
            if !tail.is_empty() {
                let mut padded = [SecureField::zero(); N_LANES];
                padded[..tail.len()].copy_from_slice(tail);
                length += tail.len();
                data.push(PackedSecureField::from_array(padded));
            }
        }

        Self { data, length }
    }
}

impl FromIterator<PackedSecureField> for SecureFieldVec {
    /// Builds a column directly from already-packed vectors.
    ///
    /// Every lane of every packed vector counts as a logical element, so `length` is
    /// always `data.len() * N_LANES`.
    fn from_iter<I: IntoIterator<Item = PackedSecureField>>(iter: I) -> Self {
        let data: Vec<PackedSecureField> = iter.into_iter().collect();
        let length = data.len() * N_LANES;
        Self { data, length }
    }
}

#[cfg(test)]
mod tests {
    use std::array;

    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::BaseFieldVec;
    use crate::core::backend::simd::column::SecureFieldVec;
    use crate::core::backend::Column;
    use crate::core::fields::m31::BaseField;
    use crate::core::fields::qm31::SecureField;

    /// Collecting 30 base field elements (not a multiple of the lane count) must
    /// round-trip through `to_cpu` unchanged.
    #[test]
    fn base_field_vec_from_iter_works() {
        let expected: [BaseField; 30] = array::from_fn(BaseField::from);

        let column: BaseFieldVec = expected.into_iter().collect();

        assert_eq!(column.to_cpu(), expected);
    }

    /// Collecting 30 seeded-random secure field elements must round-trip through
    /// `to_cpu` unchanged.
    #[test]
    fn secure_field_vec_from_iter_works() {
        let mut rng = SmallRng::seed_from_u64(0);
        let expected: [SecureField; 30] = rng.gen();

        let column: SecureFieldVec = expected.into_iter().collect();

        assert_eq!(column.to_cpu(), expected);
    }
}
8 changes: 4 additions & 4 deletions crates/prover/src/core/backend/simd/m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use num_traits::{One, Zero};
use rand::distributions::{Distribution, Standard};

use crate::core::backend::simd::utils::{LoEvensInterleaveHiEvens, LoOddsInterleaveHiOdds};
use crate::core::fields::m31::{BaseField, M31, P};
use crate::core::fields::m31::{pow2147483645, BaseField, M31, P};
use crate::core::fields::FieldExpOps;

pub const LOG_N_LANES: u32 = 4;
Expand Down Expand Up @@ -61,7 +61,7 @@ impl PackedBaseField {
self.to_array().into_iter().sum()
}

/// Doubles each element.
/// Doubles each element in the vector.
pub fn double(self) -> Self {
// TODO: Make more optimal.
self + self
Expand Down Expand Up @@ -122,6 +122,7 @@ impl Mul for PackedBaseField {
#[inline(always)]
fn mul(self, rhs: Self) -> Self {
// TODO: Come up with a better approach than `cfg`ing on target_feature.
// TODO: Ensure all these branches get tested in the CI.
cfg_if::cfg_if! {
if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] {
_mul_neon(self, rhs)
Expand Down Expand Up @@ -194,8 +195,7 @@ impl One for PackedBaseField {

impl FieldExpOps for PackedBaseField {
fn inverse(&self) -> Self {
assert!(!self.is_zero(), "0 has no inverse");
self.pow((P - 2) as u128)
pow2147483645(*self)
}
}

Expand Down
37 changes: 37 additions & 0 deletions crates/prover/src/core/backend/simd/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,44 @@
use self::column::{BaseFieldVec, SecureFieldVec};
use self::m31::PackedBaseField;
use self::qm31::PackedSecureField;
use super::ColumnOps;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{FieldExpOps, FieldOps};

pub mod cm31;
pub mod column;
pub mod m31;
pub mod qm31;
mod utils;

/// Zero-sized marker type for the SIMD backend.
///
/// It carries no data; the `ColumnOps` and `FieldOps` implementations in this module are
/// attached to it and select the packed column types (`BaseFieldVec`, `SecureFieldVec`).
#[derive(Copy, Clone, Debug)]
pub struct SimdBackend;

/// Column operations over `BaseField` for the SIMD backend.
impl ColumnOps<BaseField> for SimdBackend {
// Base field columns are stored as packed SIMD vectors.
type Column = BaseFieldVec;

/// Not implemented yet: calling this panics via `todo!()`.
fn bit_reverse_column(_column: &mut Self::Column) {
todo!()
}
}

impl FieldOps<BaseField> for SimdBackend {
    /// Inverts `column` into `dst` by delegating to `PackedBaseField::batch_inverse` on
    /// the packed storage of both columns.
    ///
    /// NOTE(review): the whole packed buffers are passed, including any padding lanes in
    /// the last packed vector when `length` is not a multiple of the lane count. Confirm
    /// that `batch_inverse` tolerates those inputs. Presumably `dst` must hold the same
    /// number of packed vectors as `column` - verify against the trait implementation.
    fn batch_inverse(column: &Self::Column, dst: &mut Self::Column) {
        let packed_src = &column.data;
        let packed_dst = &mut dst.data;
        PackedBaseField::batch_inverse(packed_src, packed_dst);
    }
}

/// Column operations over `SecureField` for the SIMD backend.
impl ColumnOps<SecureField> for SimdBackend {
// Secure field columns are stored as packed SIMD vectors.
type Column = SecureFieldVec;

/// Not implemented yet: calling this panics via `todo!()`.
fn bit_reverse_column(_column: &mut Self::Column) {
todo!()
}
}

impl FieldOps<SecureField> for SimdBackend {
    /// Inverts `column` into `dst` by delegating to `PackedSecureField::batch_inverse` on
    /// the packed storage of both columns.
    ///
    /// NOTE(review): the whole packed buffers are passed, including any padding lanes in
    /// the last packed vector when `length` is not a multiple of the lane count. Confirm
    /// that `batch_inverse` tolerates those inputs. Presumably `dst` must hold the same
    /// number of packed vectors as `column` - verify against the trait implementation.
    fn batch_inverse(column: &Self::Column, dst: &mut Self::Column) {
        let packed_src = &column.data;
        let packed_dst = &mut dst.data;
        PackedSecureField::batch_inverse(packed_src, packed_dst);
    }
}
47 changes: 1 addition & 46 deletions crates/prover/src/core/backend/simd/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,6 @@ impl<const N: usize> Swizzle<N> for LoHiInterleaveHiHi {
const INDEX: [usize; N] = segment_interleave(true);
}

/// Used with [`Swizzle::concat_swizzle`] to concat the even values of vectors `lo` and `hi`.
pub struct LoEvensConcatHiEvens;

impl<const N: usize> Swizzle<N> for LoEvensConcatHiEvens {
const INDEX: [usize; N] = parity_concat(false);
}

/// Used with [`Swizzle::concat_swizzle`] to concat the odd values of vectors `lo` and `hi`.
pub struct LoOddsConcatHiOdds;

impl<const N: usize> Swizzle<N> for LoOddsConcatHiOdds {
const INDEX: [usize; N] = parity_concat(true);
}

/// Used with [`Swizzle::concat_swizzle`] to interleave the even values of vectors `lo` and `hi`.
pub struct LoEvensInterleaveHiEvens;

Expand All @@ -52,16 +38,6 @@ const fn segment_interleave<const N: usize>(hi: bool) -> [usize; N] {
res
}

const fn parity_concat<const N: usize>(odd: bool) -> [usize; N] {
let mut res = [0; N];
let mut i = 0;
while i < N {
res[i] = i * 2 + if odd { 1 } else { 0 };
i += 1;
}
res
}

const fn parity_interleave<const N: usize>(odd: bool) -> [usize; N] {
let mut res = [0; N];
let mut i = 0;
Expand All @@ -78,8 +54,7 @@ mod tests {

use super::LoLoInterleaveHiLo;
use crate::core::backend::simd::utils::{
LoEvensConcatHiEvens, LoEvensInterleaveHiEvens, LoHiInterleaveHiHi, LoOddsConcatHiOdds,
LoOddsInterleaveHiOdds,
LoEvensInterleaveHiEvens, LoHiInterleaveHiHi, LoOddsInterleaveHiOdds,
};

#[test]
Expand All @@ -102,26 +77,6 @@ mod tests {
assert_eq!(res, u32x4::from_array([2, 6, 3, 7]));
}

#[test]
fn lo_evens_concat_hi_evens() {
let lo = u32x4::from_array([0, 1, 2, 3]);
let hi = u32x4::from_array([4, 5, 6, 7]);

let res = LoEvensConcatHiEvens::concat_swizzle(lo, hi);

assert_eq!(res, u32x4::from_array([0, 2, 4, 6]));
}

#[test]
fn lo_odds_concat_hi_odds() {
let lo = u32x4::from_array([0, 1, 2, 3]);
let hi = u32x4::from_array([4, 5, 6, 7]);

let res = LoOddsConcatHiOdds::concat_swizzle(lo, hi);

assert_eq!(res, u32x4::from_array([1, 3, 5, 7]));
}

#[test]
fn lo_evens_interleave_hi_evens() {
let lo = u32x4::from_array([0, 1, 2, 3]);
Expand Down

0 comments on commit 6d71f67

Please sign in to comment.