Skip to content

Commit

Permalink
avx512 eval at secure point (#513)
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware authored Mar 20, 2024
1 parent 3bc4c79 commit d3e0c90
Show file tree
Hide file tree
Showing 7 changed files with 387 additions and 4 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,7 @@ harness = false
[[bench]]
name = "fri"
harness = false

[[bench]]
name = "eval_at_point"
harness = false
94 changes: 94 additions & 0 deletions benches/eval_at_point.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use criterion::{black_box, Criterion};

/// Benchmarks `<CPUBackend as PolyOps>::eval_at_point` on a polynomial interpolated from
/// `2^23` evaluations, at a circle point whose coordinates are drawn from a fixed-seed RNG.
#[cfg(target_arch = "x86_64")]
pub fn cpu_eval_at_secure_point(c: &mut criterion::Criterion) {
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use stwo::core::backend::CPUBackend;
    use stwo::core::circle::CirclePoint;
    use stwo::core::fields::m31::BaseField;
    use stwo::core::fields::qm31::QM31;
    use stwo::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
    use stwo::core::poly::NaturalOrder;

    let log_size: u32 = 23;
    // Fixed seed, so every run evaluates at the same point.
    let mut rng = StdRng::seed_from_u64(0);

    // Interpolate the evaluations 0, 1, 2, ... to obtain the polynomial under test.
    let evaluation = CircleEvaluation::<CPUBackend, _, NaturalOrder>::new(
        CanonicCoset::new(log_size).circle_domain(),
        (0..(1 << log_size))
            .map(BaseField::from_u32_unchecked)
            .collect(),
    );
    let poly = evaluation.bit_reverse().interpolate();

    // Each coordinate consumes four `u32` draws; `x` is drawn before `y`.
    let mut random_coordinate = || {
        QM31::from_u32_unchecked(
            rng.gen::<u32>(),
            rng.gen::<u32>(),
            rng.gen::<u32>(),
            rng.gen::<u32>(),
        )
    };
    let x = random_coordinate();
    let y = random_coordinate();
    let point = CirclePoint { x, y };

    c.bench_function("cpu eval_at_secure_field_point", |b| {
        b.iter(|| {
            black_box(<CPUBackend as PolyOps>::eval_at_point(&poly, point));
        })
    });
}

/// Benchmarks `<AVX512Backend as PolyOps>::eval_at_securefield_point` on a polynomial
/// interpolated from `2^23` evaluations, at a circle point whose coordinates are drawn from a
/// fixed-seed RNG.
#[cfg(target_arch = "x86_64")]
pub fn avx512_eval_at_secure_point(c: &mut criterion::Criterion) {
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use stwo::core::backend::avx512::AVX512Backend;
    use stwo::core::circle::CirclePoint;
    use stwo::core::fields::m31::BaseField;
    use stwo::core::fields::qm31::QM31;
    use stwo::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
    use stwo::core::poly::NaturalOrder;

    let log_size: u32 = 23;
    // Fixed seed, so every run evaluates at the same point.
    let mut rng = StdRng::seed_from_u64(0);

    // Interpolate the evaluations 0, 1, 2, ... to obtain the polynomial under test.
    let evaluation = CircleEvaluation::<AVX512Backend, _, NaturalOrder>::new(
        CanonicCoset::new(log_size).circle_domain(),
        (0..(1 << log_size))
            .map(BaseField::from_u32_unchecked)
            .collect(),
    );
    let poly = evaluation.bit_reverse().interpolate();

    // Each coordinate consumes four `u32` draws; `x` is drawn before `y`.
    let mut random_coordinate = || {
        QM31::from_u32_unchecked(
            rng.gen::<u32>(),
            rng.gen::<u32>(),
            rng.gen::<u32>(),
            rng.gen::<u32>(),
        )
    };
    let x = random_coordinate();
    let y = random_coordinate();
    let point = CirclePoint { x, y };

    c.bench_function("avx eval_at_secure_field_point", |b| {
        b.iter(|| {
            black_box(<AVX512Backend as PolyOps>::eval_at_securefield_point(
                &poly, point,
            ));
        })
    });
}

// Both benchmark targets are compiled only on x86_64, so the group and the generated `main`
// must carry the same gate; otherwise `criterion_main!` refers to a group that does not exist
// on other architectures and the bench target fails to build.
#[cfg(target_arch = "x86_64")]
criterion::criterion_group!(
    name=secure_eval;
    config = Criterion::default().sample_size(10);
    targets=avx512_eval_at_secure_point, cpu_eval_at_secure_point);
#[cfg(target_arch = "x86_64")]
criterion::criterion_main!(secure_eval);

// This bench is declared with `harness = false`, so a `main` is still required when the
// benchmarks themselves are compiled out.
#[cfg(not(target_arch = "x86_64"))]
fn main() {}
207 changes: 203 additions & 4 deletions src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,119 @@
use bytemuck::cast_slice;
use bytemuck::{cast_slice, Zeroable};
use num_traits::One;

use super::fft::{ifft, CACHED_FFT_LOG_SIZE};
use super::m31::PackedBaseField;
use super::{as_cpu_vec, AVX512Backend, VECS_LOG_SIZE};
use super::qm31::PackedQM31;
use super::{as_cpu_vec, AVX512Backend, K_BLOCK_SIZE, VECS_LOG_SIZE};
use crate::core::backend::avx512::fft::rfft;
use crate::core::backend::avx512::BaseFieldVec;
use crate::core::backend::{CPUBackend, Col};
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::{ExtensionOf, FieldExpOps};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, Field, FieldExpOps};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
use crate::core::poly::BitReversedOrder;

impl AVX512Backend {
    /// Returns the twiddle for coefficient `index`: the product of every `mappings[j]` whose
    /// bit `j` is set in `index` (the multiplicative identity when no bit is set).
    ///
    /// `index` must be smaller than `2^mappings.len()`. This is checked in debug builds only;
    /// without the check, bits of `index` at position `mappings.len()` and above are ignored.
    // TODO(Ohad): optimize.
    fn twiddle_at<F: Field>(mappings: &[F], mut index: usize) -> F {
        // Valid indices are `0..2^mappings.len()`. The first operand keeps the shift amount
        // below the bit width of `usize`, so the check itself can never overflow.
        debug_assert!(
            mappings.len() >= usize::BITS as usize || (index >> mappings.len()) == 0,
            "Index out of bounds. number of mappings = {}, index = {index}",
            mappings.len()
        );

        let mut product = F::one();
        for &num in mappings.iter() {
            if index & 1 == 1 {
                product *= num;
            }
            index >>= 1;
            // No set bits remain, so the remaining mappings cannot contribute.
            if index == 0 {
                break;
            }
        }

        product
    }

    /// Builds the mappings used to compute evaluation twiddles for `point`.
    ///
    /// Before any reordering the result is `[y, x, h(x), h(h(x)), ...]`, where `h` is
    /// `CirclePoint::double_x`, extended until it holds `log_size` entries (it always holds at
    /// least `y` and `x`). When `log_size` exceeds `CACHED_FFT_LOG_SIZE` some entries are
    /// swapped, as explained in the body.
    // TODO(Ohad): consider moving this to a more general place.
    // Note: CACHED_FFT_LOG_SIZE is specific to the backend.
    fn generate_evaluation_mappings<F: Field>(point: CirclePoint<F>, log_size: u32) -> Vec<F> {
        // Mappings are the factors used to compute the evaluation twiddle.
        // Every twiddle (i) is of the form
        // (m[0])^b_0 * (m[1])^b_1 * ... * (m[log_size - 1])^b_(log_size - 1).
        // Where (m)_j are the mappings, and b_j is the j'th bit of i.
        let mut mappings = vec![point.y, point.x];
        let mut x = point.x;
        for _ in 2..log_size {
            x = CirclePoint::double_x(x);
            mappings.push(x);
        }

        // The caller function expects the mapping in natural order. i.e. (y,x,h(x),h(h(x)),...).
        // If the polynomial is large, the fft does a transpose in the middle in a granularity of 16
        // (avx512). The coefficients would then be in transposed order of 16-sized chunks.
        // i.e. (a_(n-15), a_(n-14), ..., a_(n-1), a_(n-31), ..., a_(n-16), a_(n-32), ...).
        // To compute the twiddles in the correct order, we need to transpose the corresponding
        // 'transposed bits' in the mappings. The result order of the mappings would then be
        // (y, x, h(x), h^2(x), h^(log_n-1)(x), h^(log_n-2)(x) ...). To avoid code
        // complexity for now, we just reverse the mappings, transpose, then reverse back.
        // TODO(Ohad): optimize. consider changing the caller to expect the mappings in
        // reversed-transposed order.
        if log_size as usize > CACHED_FFT_LOG_SIZE {
            mappings.reverse();
            let n = mappings.len();
            let n0 = (n - VECS_LOG_SIZE) / 2;
            let n1 = (n - VECS_LOG_SIZE + 1) / 2;
            // In the reversed vector: `a` is the first `n0` entries, `c` starts at `n1`.
            let (ab, c) = mappings.split_at_mut(n1);
            let (a, _b) = ab.split_at_mut(n0);
            // Swap content of a,c.
            a.swap_with_slice(&mut c[0..n0]);
            mappings.reverse();
        }

        mappings
    }

    /// Generates twiddle steps for efficiently computing the twiddles.
    ///
    /// With `t = mappings`, the result holds `mappings.len() + 1` entries:
    /// `steps[0] = t_0`, `steps[i] = t_i / (t_0 * t_1 * ... * t_(i-1))` for
    /// `1 <= i < mappings.len()`, and a trailing `1`.
    ///
    /// # Panics
    ///
    /// Panics if `mappings` is empty.
    fn twiddle_steps<F: Field>(mappings: &[F]) -> Vec<F> {
        // denominators[i] = t_0 * t_1 * ... * t_i.
        let mut denominators: Vec<F> = vec![mappings[0]];
        for i in 1..mappings.len() {
            denominators.push(denominators[i - 1] * mappings[i]);
        }

        // TODO(Ohad): batch inverse.
        let denom_inverses = denominators.iter().map(|d| d.inverse()).collect::<Vec<F>>();

        let mut steps = vec![mappings[0]];
        // Pair t_i (i >= 1) with the inverse of the product of everything before it. `zip`
        // stops at the shorter side, so the inverse of the full product is never consumed.
        steps.extend(
            mappings
                .iter()
                .skip(1)
                .zip(&denom_inverses)
                .map(|(&m, &d)| m * d),
        );
        steps.push(F::one());
        steps
    }

    /// Advances the twiddle by multiplying it by the next step, selected by the number of
    /// trailing one bits in `curr_idx`. e.g:
    /// If idx(t) = 0b100..1010 , then f(t) = t * step[0]
    /// If idx(t) = 0b100..0111 , then f(t) = t * step[3]
    ///
    /// Incrementing an index with `k` trailing ones clears those `k` bits and sets bit `k`,
    /// which is the factor that entry `k` of `twiddle_steps` provides.
    ///
    /// # Panics
    ///
    /// Panics if `steps` has no entry at that position.
    fn advance_twiddle<F: Field>(twiddle: F, steps: &[F], curr_idx: usize) -> F {
        twiddle * steps[curr_idx.trailing_ones() as usize]
    }
}

// TODO(spapini): Everything is returned in redundant representation, where values can also be P.
// Decide if and when it's ok and what to do if it's not.
impl PolyOps for AVX512Backend {
Expand Down Expand Up @@ -171,17 +269,75 @@ impl PolyOps for AVX512Backend {
itwiddles,
}
}

/// Evaluates `poly` at a secure-field circle `point` using packed arithmetic.
///
/// Polynomials with `log_size() <= 8` are delegated to `Self::eval_at_point`. Otherwise the
/// evaluation mappings are split into the first four (low), the next four (mid) and the rest
/// (high); low and mid twiddles are precomputed as packed values, and the high twiddle is
/// advanced once per coefficient chunk.
fn eval_at_securefield_point(
poly: &CirclePoly<Self>,
point: CirclePoint<SecureField>,
) -> SecureField {
// If the polynomial is small, fallback to evaluate directly.
// TODO(Ohad): it's possible to avoid falling back. Consider fixing.
if poly.log_size() <= 8 {
return Self::eval_at_point(poly, point);
}

let mappings = Self::generate_evaluation_mappings(point, poly.log_size());

// 8 lowest mappings produce the first 2^8 twiddles. Separate to optimize each calculation.
// `map_low` is mappings[0..4], `map_mid` is mappings[4..8], `map_high` is the remainder.
let (map_low, map_high) = mappings.split_at(4);
let twiddle_lows =
PackedQM31::from_array(&std::array::from_fn(|i| Self::twiddle_at(map_low, i)));
let (map_mid, map_high) = map_high.split_at(4);
let twiddle_mids =
PackedQM31::from_array(&std::array::from_fn(|i| Self::twiddle_at(map_mid, i)));

// Compute the high twiddle steps.
let twiddle_steps = Self::twiddle_steps(map_high);

// Every twiddle is a product of mappings that correspond to '1's in the bit representation
// of the current index. For every 2^n aligned chunk of 2^n elements, the twiddle
// array is the same, denoted twiddle_low. Use this to compute sums of (coeff *
// twiddle_high) mod 2^n, then multiply by twiddle_low, and sum to get the final result.
let mut sum = PackedQM31::zeroed();
let mut twiddle_high = SecureField::one();
for (i, coeff_chunk) in poly.coeffs.data.array_chunks::<K_BLOCK_SIZE>().enumerate() {
// For every chunk of 2 ^ 4 * 2 ^ 4 = 2 ^ 8 elements, the twiddle high is the same.
// Multiply it by every mid twiddle factor to get the factors for the current chunk.
let high_twiddle_factors =
(PackedQM31::broadcast(twiddle_high) * twiddle_mids).to_array();

// Sum the coefficients multiplied by each corresponding twiddle. Result is effectively
// an array[16] where the value at index 'i' is the twiddle-weighted sum of all
// coefficients at indices that are i mod 16.
for (&packed_coeffs, &mid_twiddle) in
coeff_chunk.iter().zip(high_twiddle_factors.iter())
{
sum = sum + PackedQM31::broadcast(mid_twiddle).mul_packed_m31(packed_coeffs);
}

// Advance twiddle high.
// `i` is the index of the chunk just processed; its trailing ones select the step.
twiddle_high = Self::advance_twiddle(twiddle_high, &twiddle_steps, i);
}

// Apply the low twiddles lane-wise, then collapse the lanes into a single field element.
(sum * twiddle_lows).pointwise_sum()
}
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::core::backend::avx512::fft::{CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use crate::core::backend::avx512::AVX512Backend;
use crate::core::backend::Column;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};
use crate::core::poly::{BitReversedOrder, NaturalOrder};
use crate::qm31;

#[test]
fn test_interpolate_and_eval() {
Expand Down Expand Up @@ -259,4 +415,47 @@ mod tests {
assert_eq!(eval0.values.to_vec(), eval1.values.to_vec());
}
}

/// Checks that `eval_at_securefield_point` agrees with `eval_at_point` for every log size from
/// `MIN_FFT_LOG_SIZE` up to and including `CACHED_FFT_LOG_SIZE + 1`.
#[test]
fn test_eval_securefield() {
    use crate::core::backend::avx512::fft::MIN_FFT_LOG_SIZE;
    let mut rng = StdRng::seed_from_u64(0);

    for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 2) {
        // Polynomial interpolated from the evaluations 0, 1, 2, ...
        let evaluation = CircleEvaluation::<AVX512Backend, _, NaturalOrder>::new(
            CanonicCoset::new(log_size as u32).circle_domain(),
            (0..(1 << log_size))
                .map(BaseField::from_u32_unchecked)
                .collect(),
        );
        let poly = evaluation.bit_reverse().interpolate();

        // Four draws for `x`, then four for `y`.
        let point = CirclePoint {
            x: qm31!(
                rng.gen::<u32>(),
                rng.gen::<u32>(),
                rng.gen::<u32>(),
                rng.gen::<u32>()
            ),
            y: qm31!(
                rng.gen::<u32>(),
                rng.gen::<u32>(),
                rng.gen::<u32>(),
                rng.gen::<u32>()
            ),
        };

        let packed_result = <AVX512Backend as PolyOps>::eval_at_securefield_point(&poly, point);
        let reference = <AVX512Backend as PolyOps>::eval_at_point(&poly, point);
        assert_eq!(packed_result, reference, "log_size = {log_size}");

        println!("log_size = {log_size} passed, eval{}", packed_result);
    }
}
}
9 changes: 9 additions & 0 deletions src/core/backend/avx512/m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ impl PackedBaseField {
pub unsafe fn store(self, ptr: *mut i32) {
// Writes the packed value `self.0` to memory at `ptr` via `_mm512_store_epi32`.
// NOTE(review): this intrinsic is the aligned 512-bit store, so `ptr` presumably has to be
// 64-byte aligned and valid for a 64-byte write — confirm against the caller contract.
_mm512_store_epi32(ptr, self.0);
}

/// Adds up every element of this packed value and returns the total as a single `M31`.
pub fn pointwise_sum(self) -> M31 {
    let lanes = self.to_array();
    lanes.into_iter().sum()
}

/// Builds a packed value in which every 32-bit lane holds `x`.
pub fn broadcast(x: M31) -> Self {
    let lane = x.0 as i32;
    // SAFETY: NOTE(review) the intrinsic only needs the `avx512f` target feature and touches no
    // memory; presumably this module is compiled only where that feature is enabled — confirm.
    let packed = unsafe { std::arch::x86_64::_mm512_set1_epi32(lane) };
    Self(packed)
}
}

impl Display for PackedBaseField {
Expand Down
Loading

0 comments on commit d3e0c90

Please sign in to comment.