Skip to content

Commit

Permalink
Remove AVX backend (#616)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson authored May 20, 2024
1 parent 23d2946 commit d4372f9
Show file tree
Hide file tree
Showing 25 changed files with 20 additions and 4,499 deletions.
3 changes: 0 additions & 3 deletions crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ nonstandard-style = "deny"
rust-2018-idioms = "deny"
unused = "deny"

[features]
avx512 = []

[[bench]]
name = "bit_rev"
harness = false
Expand Down
24 changes: 0 additions & 24 deletions crates/prover/benches/bit_rev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,6 @@ pub fn simd_bit_rev(c: &mut Criterion) {
});
}

#[cfg(target_arch = "x86_64")]
pub fn avx512_bit_rev(c: &mut Criterion) {
use stwo_prover::core::backend::avx512::bit_reverse::bit_reverse_m31;
use stwo_prover::core::backend::avx512::BaseFieldVec;
const SIZE: usize = 1 << 26;
if !stwo_prover::platform::avx512_detected() {
return;
}
let data = (0..SIZE).map(BaseField::from).collect::<BaseFieldVec>();
c.bench_function("avx bit_rev 26bit", |b| {
b.iter_batched(
|| data.data.clone(),
|mut data| bit_reverse_m31(&mut data),
BatchSize::LargeInput,
);
});
}

#[cfg(target_arch = "x86_64")]
criterion_group!(
name = bit_rev;
config = Criterion::default().sample_size(10);
targets = avx512_bit_rev, simd_bit_rev, cpu_bit_rev);
#[cfg(not(target_arch = "x86_64"))]
criterion_group!(
name = bit_rev;
config = Criterion::default().sample_size(10);
Expand Down
5 changes: 0 additions & 5 deletions crates/prover/benches/eval_at_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@ fn bench_eval_at_secure_point<B: PolyOps>(c: &mut Criterion, id: &str) {
}

fn eval_at_secure_point_benches(c: &mut Criterion) {
#[cfg(target_arch = "x86_64")]
if stwo_prover::platform::avx512_detected() {
use stwo_prover::core::backend::avx512::AVX512Backend;
bench_eval_at_secure_point::<AVX512Backend>(c, "avx");
}
bench_eval_at_secure_point::<SimdBackend>(c, "simd");
bench_eval_at_secure_point::<CPUBackend>(c, "cpu");
}
Expand Down
179 changes: 10 additions & 169 deletions crates/prover/benches/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
use std::hint::black_box;
use std::mem::{size_of_val, transmute};

use criterion::{BatchSize, BenchmarkId, Criterion, Throughput};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput};
use itertools::Itertools;
use stwo_prover::core::backend::simd::column::BaseFieldVec;
use stwo_prover::core::backend::simd::fft::ifft::{
get_itwiddle_dbls, ifft, ifft3_loop, ifft_vecwise_loop,
};
use stwo_prover::core::backend::simd::fft::rfft::{fft, get_twiddle_dbls};
use stwo_prover::core::backend::simd::fft::transpose_vecs;
use stwo_prover::core::backend::simd::m31::PackedBaseField;
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::poly::circle::CanonicCoset;

pub fn simd_ifft(c: &mut Criterion) {
use stwo_prover::core::backend::simd::column::BaseFieldVec;
use stwo_prover::core::backend::simd::fft::ifft::{get_itwiddle_dbls, ifft};

let mut group = c.benchmark_group("iffts");

for log_size in 16..=28 {
Expand All @@ -37,12 +41,6 @@ pub fn simd_ifft(c: &mut Criterion) {
}

pub fn simd_ifft_parts(c: &mut Criterion) {
use stwo_prover::core::backend::simd::column::BaseFieldVec;
use stwo_prover::core::backend::simd::fft::ifft::{
get_itwiddle_dbls, ifft3_loop, ifft_vecwise_loop,
};
use stwo_prover::core::backend::simd::fft::transpose_vecs;

const LOG_SIZE: u32 = 14;

let domain = CanonicCoset::new(LOG_SIZE).circle_domain();
Expand Down Expand Up @@ -104,10 +102,6 @@ pub fn simd_ifft_parts(c: &mut Criterion) {
}

pub fn simd_rfft(c: &mut Criterion) {
use stwo_prover::core::backend::simd::column::BaseFieldVec;
use stwo_prover::core::backend::simd::fft::rfft::{fft, get_twiddle_dbls};
use stwo_prover::core::backend::simd::m31::PackedBaseField;

const LOG_SIZE: u32 = 20;

let domain = CanonicCoset::new(LOG_SIZE).circle_domain();
Expand All @@ -131,161 +125,8 @@ pub fn simd_rfft(c: &mut Criterion) {
});
}

#[cfg(target_arch = "x86_64")]
pub fn avx512_ifft(c: &mut criterion::Criterion) {
use stwo_prover::core::backend::avx512::fft::ifft;
use stwo_prover::platform;
if !platform::avx512_detected() {
return;
}

let mut group = c.benchmark_group("iffts");
for log_size in 16..=28 {
let (values, twiddle_dbls) = prepare_values(log_size);

group.throughput(Throughput::Bytes(
(std::mem::size_of::<BaseField>() as u64) << log_size,
));
group.bench_function(BenchmarkId::new("avx ifft", log_size), |b| {
b.iter_batched(
|| values.clone().data,
|mut values| unsafe {
ifft::ifft(
std::mem::transmute(values.as_mut_ptr()),
&twiddle_dbls
.iter()
.map(|x| x.as_slice())
.collect::<Vec<_>>(),
log_size as usize,
)
},
BatchSize::LargeInput,
);
});
}
}

#[cfg(target_arch = "x86_64")]
pub fn avx512_ifft_parts(c: &mut criterion::Criterion) {
use stwo_prover::core::backend::avx512::fft::{ifft, transpose_vecs};
use stwo_prover::platform;
if !platform::avx512_detected() {
return;
}

let (values, twiddle_dbls) = prepare_values(14);
let mut group = c.benchmark_group("ifft parts");

// Note: These benchmarks run only on 2^14 elements ebcause of their parameters.
// Increasing the figure above won't change the runtime of these benchmarks.
group.throughput(Throughput::Bytes(4 << 14));
group.bench_function("avx ifft_vecwise_loop 2^14", |b| {
b.iter_batched(
|| values.clone().data,
|mut values| unsafe {
ifft::ifft_vecwise_loop(
std::mem::transmute(values.as_mut_ptr()),
&twiddle_dbls
.iter()
.map(|x| x.as_slice())
.collect::<Vec<_>>(),
9,
0,
)
},
BatchSize::LargeInput,
);
});

group.bench_function("avx ifft3_loop 2^14", |b| {
b.iter_batched(
|| values.clone().data,
|mut values| unsafe {
ifft::ifft3_loop(
std::mem::transmute(values.as_mut_ptr()),
&twiddle_dbls
.iter()
.skip(3)
.map(|x| x.as_slice())
.collect::<Vec<_>>(),
7,
4,
0,
)
},
BatchSize::LargeInput,
);
});

let (values, _twiddle_dbls) = prepare_values(20);
group.throughput(Throughput::Bytes(4 << 20));
group.bench_function("avx transpose_vecs 2^20", |b| {
b.iter_batched(
|| values.clone().data,
|mut values| unsafe {
transpose_vecs(std::mem::transmute(values.as_mut_ptr()), (20 - 4) as usize);
},
BatchSize::LargeInput,
);
});
}

#[cfg(target_arch = "x86_64")]
pub fn avx512_rfft(c: &mut criterion::Criterion) {
use stwo_prover::core::backend::avx512::fft::rfft;
use stwo_prover::core::backend::avx512::PackedBaseField;
use stwo_prover::platform;
if !platform::avx512_detected() {
return;
}

const LOG_SIZE: u32 = 20;
let (values, twiddle_dbls) = prepare_values(LOG_SIZE);

c.bench_function("avx rfft 20bit", |b| {
b.iter_with_large_drop(|| unsafe {
let mut target = Vec::<PackedBaseField>::with_capacity(values.data.len());
#[allow(clippy::uninit_vec)]
target.set_len(values.data.len());

rfft::fft(
black_box(std::mem::transmute(values.data.as_ptr())),
std::mem::transmute(target.as_mut_ptr()),
&twiddle_dbls
.iter()
.map(|x| x.as_slice())
.collect::<Vec<_>>(),
LOG_SIZE as usize,
);
})
});
}

#[cfg(target_arch = "x86_64")]
fn prepare_values(
log_size: u32,
) -> (
stwo_prover::core::backend::avx512::BaseFieldVec,
Vec<Vec<i32>>,
) {
use stwo_prover::core::backend::avx512::fft::ifft::get_itwiddle_dbls;
let domain = CanonicCoset::new(log_size).circle_domain();
let values = (0..domain.size())
.map(|i| BaseField::from_u32_unchecked(i as u32))
.collect::<Vec<_>>();
let values = stwo_prover::core::backend::avx512::BaseFieldVec::from_iter(values);
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);
(values, twiddle_dbls)
}

#[cfg(target_arch = "x86_64")]
criterion::criterion_group!(
name = benches;
config = Criterion::default().sample_size(10);
targets = avx512_ifft, avx512_ifft_parts, avx512_rfft, simd_ifft, simd_ifft_parts, simd_rfft);
#[cfg(not(target_arch = "x86_64"))]
criterion::criterion_group!(
criterion_group!(
name = benches;
config = Criterion::default().sample_size(10);
targets = simd_ifft, simd_ifft_parts, simd_rfft);
criterion::criterion_main!(benches);
criterion_main!(benches);
80 changes: 8 additions & 72 deletions crates/prover/benches/field.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use criterion::Criterion;
use criterion::{criterion_group, criterion_main, Criterion};
use num_traits::One;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use stwo_prover::core::backend::simd::m31::N_LANES;
use stwo_prover::core::backend::simd::m31::{PackedBaseField, N_LANES};
use stwo_prover::core::fields::cm31::CM31;
use stwo_prover::core::fields::m31::{BaseField, M31};
use stwo_prover::core::fields::qm31::SecureField;

pub const N_ELEMENTS: usize = 1 << 16;
pub const N_STATE_ELEMENTS: usize = 8;

pub fn m31_operations_bench(c: &mut criterion::Criterion) {
pub fn m31_operations_bench(c: &mut Criterion) {
let mut rng = SmallRng::seed_from_u64(0);
let elements: Vec<M31> = (0..N_ELEMENTS).map(|_| rng.gen()).collect();
let mut state: [M31; N_STATE_ELEMENTS] = rng.gen();
Expand Down Expand Up @@ -40,7 +40,7 @@ pub fn m31_operations_bench(c: &mut criterion::Criterion) {
});
}

pub fn cm31_operations_bench(c: &mut criterion::Criterion) {
pub fn cm31_operations_bench(c: &mut Criterion) {
let mut rng = SmallRng::seed_from_u64(0);
let elements: Vec<CM31> = (0..N_ELEMENTS).map(|_| rng.gen()).collect();
let mut state: [CM31; N_STATE_ELEMENTS] = rng.gen();
Expand Down Expand Up @@ -70,7 +70,7 @@ pub fn cm31_operations_bench(c: &mut criterion::Criterion) {
});
}

pub fn qm31_operations_bench(c: &mut criterion::Criterion) {
pub fn qm31_operations_bench(c: &mut Criterion) {
let mut rng = SmallRng::seed_from_u64(0);
let elements: Vec<SecureField> = (0..N_ELEMENTS).map(|_| rng.gen()).collect();
let mut state: [SecureField; N_STATE_ELEMENTS] = rng.gen();
Expand Down Expand Up @@ -100,64 +100,7 @@ pub fn qm31_operations_bench(c: &mut criterion::Criterion) {
});
}

#[cfg(target_arch = "x86_64")]
pub fn avx512_m31_operations_bench(c: &mut criterion::Criterion) {
use stwo_prover::core::backend::avx512::m31::{PackedBaseField, K_BLOCK_SIZE};
use stwo_prover::platform;

if !platform::avx512_detected() {
return;
}

let mut rng = SmallRng::seed_from_u64(0);
let mut elements: Vec<PackedBaseField> = Vec::new();
let mut states: Vec<PackedBaseField> =
vec![PackedBaseField::from_array([1.into(); K_BLOCK_SIZE]); N_STATE_ELEMENTS];

for _ in 0..(N_ELEMENTS / K_BLOCK_SIZE) {
elements.push(PackedBaseField::from_array(rng.gen()));
}

c.bench_function("mul_avx512", |b| {
b.iter(|| {
for elem in elements.iter() {
for _ in 0..128 {
for state in states.iter_mut() {
*state *= *elem;
}
}
}
})
});

c.bench_function("add_avx512", |b| {
b.iter(|| {
for elem in elements.iter() {
for _ in 0..128 {
for state in states.iter_mut() {
*state += *elem;
}
}
}
})
});

c.bench_function("sub_avx512", |b| {
b.iter(|| {
for elem in elements.iter() {
for _ in 0..128 {
for state in states.iter_mut() {
*state -= *elem;
}
}
}
})
});
}

pub fn simd_m31_operations_bench(c: &mut criterion::Criterion) {
use stwo_prover::core::backend::simd::m31::PackedBaseField;

pub fn simd_m31_operations_bench(c: &mut Criterion) {
let mut rng = SmallRng::seed_from_u64(0);
let elements: Vec<PackedBaseField> = (0..N_ELEMENTS / N_LANES).map(|_| rng.gen()).collect();
let mut states = vec![PackedBaseField::broadcast(BaseField::one()); N_STATE_ELEMENTS];
Expand Down Expand Up @@ -199,16 +142,9 @@ pub fn simd_m31_operations_bench(c: &mut criterion::Criterion) {
});
}

#[cfg(target_arch = "x86_64")]
criterion::criterion_group!(
name = benches;
config = Criterion::default().sample_size(10);
targets = m31_operations_bench, cm31_operations_bench, qm31_operations_bench,
avx512_m31_operations_bench, simd_m31_operations_bench);
#[cfg(not(target_arch = "x86_64"))]
criterion::criterion_group!(
criterion_group!(
name = benches;
config = Criterion::default().sample_size(10);
targets = m31_operations_bench, cm31_operations_bench, qm31_operations_bench,
simd_m31_operations_bench);
criterion::criterion_main!(benches);
criterion_main!(benches);
5 changes: 0 additions & 5 deletions crates/prover/benches/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,6 @@ fn bench_blake2s_merkle<B: MerkleOps<Blake2sMerkleHasher>>(c: &mut Criterion, id
}

fn blake2s_merkle_benches(c: &mut Criterion) {
#[cfg(target_arch = "x86_64")]
if stwo_prover::platform::avx512_detected() {
use stwo_prover::core::backend::avx512::AVX512Backend;
bench_blake2s_merkle::<AVX512Backend>(c, "avx");
}
bench_blake2s_merkle::<SimdBackend>(c, "simd");
bench_blake2s_merkle::<CPUBackend>(c, "cpu");
}
Expand Down
Loading

0 comments on commit d4372f9

Please sign in to comment.