Skip to content

Commit

Permalink
SecureColumnSlices
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Jul 1, 2024
1 parent a759ce0 commit 3db6f2f
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 56 deletions.
41 changes: 6 additions & 35 deletions crates/prover/src/core/backend/simd/bit_reverse.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,11 @@
use std::array;

use super::column::{BaseFieldVec, SecureFieldVec};
use super::m31::PackedBaseField;
use super::SimdBackend;
use crate::core::backend::ColumnOps;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::utils::{bit_reverse as cpu_bit_reverse, bit_reverse_index};

const VEC_BITS: u32 = 4;
use super::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::utils::bit_reverse_index;

const W_BITS: u32 = 3;

pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS;

impl ColumnOps<BaseField> for SimdBackend {
    type Column = BaseFieldVec;

    /// Bit-reverses the order of the column's elements in place.
    fn bit_reverse_column(column: &mut Self::Column) {
        // Fallback to cpu bit_reverse.
        // Columns smaller than 2^MIN_LOG_SIZE elements don't have enough data
        // for the vectorized kernel's v_h/w_h/a/w_l/v_l index decomposition.
        if column.data.len().ilog2() < MIN_LOG_SIZE {
            cpu_bit_reverse(column.as_mut_slice());
            return;
        }

        bit_reverse_m31(&mut column.data);
    }
}

impl ColumnOps<SecureField> for SimdBackend {
    type Column = SecureFieldVec;

    // TODO: SIMD bit-reverse for secure-field columns is not implemented yet.
    fn bit_reverse_column(_column: &mut SecureFieldVec) {
        todo!()
    }
}
pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + LOG_N_LANES;

/// Bit reverses M31 values.
///
Expand All @@ -44,21 +15,21 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {
assert!(data.len().ilog2() >= MIN_LOG_SIZE);

// Indices in the array are of the form v_h w_h a w_l v_l, with
// |v_h| = |v_l| = VEC_BITS, |w_h| = |w_l| = W_BITS, |a| = n - 2*W_BITS - VEC_BITS.
// |v_h| = |v_l| = LOG_N_LANES, |w_h| = |w_l| = W_BITS, |a| = n - 2*W_BITS - LOG_N_LANES.
// The loops go over a, w_l, w_h, and then swaps the 16 by 16 values at:
// * w_h a w_l * <-> * rev(w_h a w_l) *.
// These are 1 or 2 chunks of 2^W_BITS contiguous `u32x16` vectors.

let log_size = data.len().ilog2();
let a_bits = log_size - 2 * W_BITS - VEC_BITS;
let a_bits = log_size - 2 * W_BITS - LOG_N_LANES;

// TODO(spapini): when doing multithreading, do it over a.
for a in 0u32..1 << a_bits {
for w_l in 0u32..1 << W_BITS {
let w_l_rev = w_l.reverse_bits() >> (u32::BITS - W_BITS);
for w_h in 0..w_l_rev + 1 {
let idx = ((((w_h << a_bits) | a) << W_BITS) | w_l) as usize;
let idx_rev = bit_reverse_index(idx, log_size - VEC_BITS);
let idx_rev = bit_reverse_index(idx, log_size - LOG_N_LANES);

// In order to not swap twice, only swap if idx <= idx_rev.
if idx > idx_rev {
Expand Down
100 changes: 79 additions & 21 deletions crates/prover/src/core/backend/simd/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,39 @@ use bytemuck::{cast_slice, cast_slice_mut, Zeroable};
use itertools::{izip, Itertools};
use num_traits::Zero;

use super::bit_reverse::{bit_reverse_m31, MIN_LOG_SIZE};
use super::cm31::PackedCM31;
use super::m31::{PackedBaseField, N_LANES};
use super::qm31::{PackedQM31, PackedSecureField};
use super::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::backend::{Column, ColumnOps, CpuBackend};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::fields::secure_column::{SecureColumn, SECURE_EXTENSION_DEGREE};
use crate::core::fields::{FieldExpOps, FieldOps};
use crate::core::utils::bit_reverse as cpu_bit_reverse;

impl ColumnOps<BaseField> for SimdBackend {
    type Column = BaseFieldVec;

    /// Bit-reverses the order of the column's elements in place.
    fn bit_reverse_column(column: &mut Self::Column) {
        // The vectorized kernel requires at least 2^MIN_LOG_SIZE elements;
        // smaller columns fall back to the scalar cpu bit_reverse.
        if column.data.len().ilog2() >= MIN_LOG_SIZE {
            bit_reverse_m31(&mut column.data);
        } else {
            cpu_bit_reverse(column.as_mut_slice());
        }
    }
}

impl ColumnOps<SecureField> for SimdBackend {
    type Column = SecureFieldVec;

    // TODO: SIMD bit-reverse for secure-field columns is not implemented yet.
    fn bit_reverse_column(_column: &mut SecureFieldVec) {
        todo!()
    }
}

impl FieldOps<BaseField> for SimdBackend {
fn batch_inverse(column: &BaseFieldVec, dst: &mut BaseFieldVec) {
Expand Down Expand Up @@ -54,7 +78,6 @@ impl BaseFieldVec {
res
}
}

impl Column<BaseField> for BaseFieldVec {
fn zeros(length: usize) -> Self {
let data = vec![PackedBaseField::zeroed(); length.div_ceil(N_LANES)];
Expand Down Expand Up @@ -156,9 +179,12 @@ impl FromIterator<PackedSecureField> for SecureFieldVec {
}
}

impl SecureColumn<SimdBackend> {
pub struct SecureColumnSlice<'a>(pub [&'a [PackedBaseField]; SECURE_EXTENSION_DEGREE]);
pub struct SecureColumnMutSlice<'a>(pub [&'a mut [PackedBaseField]; SECURE_EXTENSION_DEGREE]);

impl<'a> SecureColumnSlice<'a> {
pub fn packed_len(&self) -> usize {
self.columns[0].data.len()
self.0[0].len()
}

/// # Safety
Expand All @@ -167,36 +193,68 @@ impl SecureColumn<SimdBackend> {
pub unsafe fn packed_at(&self, vec_index: usize) -> PackedSecureField {
PackedQM31([
PackedCM31([
*self.columns[0].data.get_unchecked(vec_index),
*self.columns[1].data.get_unchecked(vec_index),
*self.0[0].get_unchecked(vec_index),
*self.0[1].get_unchecked(vec_index),
]),
PackedCM31([
*self.columns[2].data.get_unchecked(vec_index),
*self.columns[3].data.get_unchecked(vec_index),
*self.0[2].get_unchecked(vec_index),
*self.0[3].get_unchecked(vec_index),
]),
])
}

pub fn to_vec(&self) -> Vec<SecureField> {
izip!(
cast_slice(self.0[0]),
cast_slice(self.0[1]),
cast_slice(self.0[2]),
cast_slice(self.0[3]),
)
.map(|(a, b, c, d)| SecureField::from_m31_array([*a, *b, *c, *d]))
.collect()
}
}

impl<'a> SecureColumnMutSlice<'a> {
    /// Splits `value` into its four base-field coordinates and writes each one
    /// into the corresponding coordinate column at `vec_index`.
    ///
    /// # Safety
    ///
    /// `vec_index` must be a valid index.
    pub unsafe fn set_packed(&mut self, vec_index: usize, value: PackedSecureField) {
        let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value;
        *self.0[0].get_unchecked_mut(vec_index) = a;
        *self.0[1].get_unchecked_mut(vec_index) = b;
        *self.0[2].get_unchecked_mut(vec_index) = c;
        *self.0[3].get_unchecked_mut(vec_index) = d;
    }
}

impl SecureColumn<SimdBackend> {
    /// Borrows the four coordinate columns as an immutable slice view.
    pub fn as_ref(&self) -> SecureColumnSlice<'_> {
        // The slice views expose whole packed vectors, so the column length
        // must be an exact multiple of N_LANES.
        assert_eq!(self.columns[0].length, self.columns[0].data.len() * N_LANES);
        SecureColumnSlice(std::array::from_fn(|i| &self.columns[i].data[..]))
    }

    /// Borrows the four coordinate columns as a mutable slice view.
    pub fn as_mut(&mut self) -> SecureColumnMutSlice<'_> {
        assert_eq!(self.columns[0].length, self.columns[0].data.len() * N_LANES);
        // get_many_mut yields disjoint mutable borrows of all four columns.
        let cols = self.columns.get_many_mut([0, 1, 2, 3]).unwrap();
        SecureColumnMutSlice(cols.map(|c| &mut c.data[..]))
    }

    /// Number of packed vectors in each coordinate column.
    pub fn packed_len(&self) -> usize {
        self.as_ref().packed_len()
    }

    /// Returns the packed secure-field value at `vec_index`.
    ///
    /// # Safety
    ///
    /// `vec_index` must be a valid index.
    pub unsafe fn packed_at(&self, vec_index: usize) -> PackedSecureField {
        self.as_ref().packed_at(vec_index)
    }

    /// Collects the column into a `Vec` of secure-field elements.
    pub fn to_vec(&self) -> Vec<SecureField> {
        self.as_ref().to_vec()
    }

    /// Writes `value` into the four coordinate columns at `vec_index`.
    ///
    /// # Safety
    ///
    /// `vec_index` must be a valid index.
    pub unsafe fn set_packed(&mut self, vec_index: usize, value: PackedSecureField) {
        self.as_mut().set_packed(vec_index, value)
    }
}

Expand Down

0 comments on commit 3db6f2f

Please sign in to comment.