Skip to content

Commit

Permalink
Insert MSM and FFT code and their benchmarks.
Browse files Browse the repository at this point in the history
  • Loading branch information
einar-taiko committed Sep 8, 2023
1 parent 6e2ff38 commit 8797915
Show file tree
Hide file tree
Showing 8 changed files with 436 additions and 1 deletion.
17 changes: 16 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ criterion = { version = "0.3", features = ["html_reports"] }
rand_xorshift = "0.3"
ark-std = { version = "0.3" }
bincode = "1.3.3"
halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", rev="7a21656" }

[dependencies]
subtle = "2.4"
Expand All @@ -31,9 +32,11 @@ paste = "1.0.11"
serde = { version = "1.0", default-features = false, optional = true }
serde_arrays = { version = "0.1.0", optional = true }
blake2b_simd = "1"
maybe-rayon = { version = "0.1.0", default-features = false }

[features]
default = ["reexport", "bits"]
default = ["reexport", "bits", "multicore"]
multicore = ["maybe-rayon/threads"]
asm = []
bits = ["ff/bits"]
bn256-table = []
Expand Down Expand Up @@ -67,3 +70,15 @@ harness = false
[[bench]]
name = "hash_to_curve"
harness = false

[[bench]]
name = "fft"
harness = false

[[bench]]
name = "msm"
harness = false

[[bench]]
name = "msm-alt"
harness = false
24 changes: 24 additions & 0 deletions benches/fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#[macro_use]
extern crate criterion;

use group::ff::Field;
use halo2curves::{fft::best_fft, pasta::Fp};

use criterion::{BenchmarkId, Criterion};
use rand_core::OsRng;

fn criterion_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("fft");
for k in 3..19 {
group.bench_function(BenchmarkId::new("k", k), |b| {
let mut a = (0..(1 << k)).map(|_| Fp::random(OsRng)).collect::<Vec<_>>();
let omega = Fp::random(OsRng); // would be weird if this mattered
b.iter(|| {
best_fft(&mut a, omega, k as u32);
});
});
}
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
56 changes: 56 additions & 0 deletions benches/msm-alt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
//! This benchmark allows testing msm without depending on the `halo2_proofs`
//! crate. This code originates in an older version of `halo2_proofs` from
//! before the `hash_to_curve` method was implemented. It currently only uses
//! curve `Secp256k1Affine`

#[macro_use]
extern crate criterion;

use criterion::{black_box, BenchmarkId, Criterion};
use ff::Field;
use halo2_proofs::arithmetic::small_multiexp;
use halo2curves::secp256k1::Fq as Scalar;
use halo2curves::secp256k1::Secp256k1Affine;
use halo2curves::CurveAffine;
use rand_core::OsRng;
use rand_core::SeedableRng;
use rand_xorshift::XorShiftRng;
use std::iter::zip;

fn random_curve_points<C: CurveAffine>(k: u8) -> Vec<Secp256k1Affine> {
debug_assert!(k < 64);
let n: u64 = 1 << k;

let mut rng = XorShiftRng::from_seed([
0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc,
0xe5,
]);

(0..n).map(|_n| Secp256k1Affine::random(&mut rng)).collect()
}

fn criterion_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("msm-alt");
for k in 8..16 {
group
.bench_function(BenchmarkId::new("k", k), |b| {
let rng = OsRng;

let mut g = random_curve_points::<Secp256k1Affine>(k);
let half_len = g.len() / 2;
let (g_lo, g_hi) = g.split_at_mut(half_len);
let coeff_1 = Scalar::random(rng);
let coeff_2 = Scalar::random(rng);

b.iter(|| {
for (g_lo, g_hi) in zip(g_lo.iter(), g_hi.iter()) {
small_multiexp(&[black_box(coeff_1), black_box(coeff_2)], &[*g_lo, *g_hi]);
}
})
})
.sample_size(30);
}
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
34 changes: 34 additions & 0 deletions benches/msm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#[macro_use]
extern crate criterion;
use criterion::{black_box, Criterion};
use ff::Field;
use halo2_proofs::poly::{commitment::ParamsProver, ipa::commitment::ParamsIPA};
use halo2curves::msm::small_multiexp;
use pasta_curves::{EqAffine, Fp};
use rand_core::OsRng;

fn criterion_benchmark(c: &mut Criterion) {
let rng = OsRng;

// small multiexp
{
let params: ParamsIPA<EqAffine> = ParamsIPA::new(5);
let g = &mut params.get_g().to_vec();
let len = g.len() / 2;
let (g_lo, g_hi) = g.split_at_mut(len);

let coeff_1 = Fp::random(rng);
let coeff_2 = Fp::random(rng);

c.bench_function("double-and-add", |b| {
b.iter(|| {
for (g_lo, g_hi) in g_lo.iter().zip(g_hi.iter()) {
small_multiexp(&[black_box(coeff_1), black_box(coeff_2)], &[*g_lo, *g_hi]);
}
})
});
}
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
134 changes: 134 additions & 0 deletions src/fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use crate::multicore;
pub use crate::{CurveAffine, CurveExt};
use ff::Field;
use group::{GroupOpsOwned, ScalarMulOwned};

/// This represents an element of a group with basic operations that can be
/// performed. This allows an FFT implementation (for example) to operate
/// generically over either a field or elliptic curve group.
pub trait FftGroup<Scalar: Field>:
Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
{
}

impl<T, Scalar> FftGroup<Scalar> for T
where
Scalar: Field,
T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
{
}

/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size
/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative
/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when
/// interpreted as the coefficients of a polynomial of degree $n - 1$, is
/// transformed into the evaluations of this polynomial at each of the $n$
/// distinct powers of $\omega$. This transformation is invertible by providing
/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element
/// by $n$.
///
/// This will use multithreading if beneficial.
pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
fn bitreverse(mut n: usize, l: usize) -> usize {
let mut r = 0;
for _ in 0..l {
r = (r << 1) | (n & 1);
n >>= 1;
}
r
}

let threads = multicore::current_num_threads();
let log_threads = threads.ilog2();
let n = a.len();
assert_eq!(n, 1 << log_n);

for k in 0..n {
let rk = bitreverse(k, log_n as usize);
if k < rk {
a.swap(rk, k);
}
}

// precompute twiddle factors
let twiddles: Vec<_> = (0..(n / 2))
.scan(Scalar::ONE, |w, _| {
let tw = *w;
*w *= &omega;
Some(tw)
})
.collect();

if log_n <= log_threads {
let mut chunk = 2_usize;
let mut twiddle_chunk = n / 2;
for _ in 0..log_n {
a.chunks_mut(chunk).for_each(|coeffs| {
let (left, right) = coeffs.split_at_mut(chunk / 2);

// case when twiddle factor is one
let (a, left) = left.split_at_mut(1);
let (b, right) = right.split_at_mut(1);
let t = b[0];
b[0] = a[0];
a[0] += &t;
b[0] -= &t;

left.iter_mut()
.zip(right.iter_mut())
.enumerate()
.for_each(|(i, (a, b))| {
let mut t = *b;
t *= &twiddles[(i + 1) * twiddle_chunk];
*b = *a;
*a += &t;
*b -= &t;
});
});
chunk *= 2;
twiddle_chunk /= 2;
}
} else {
recursive_butterfly_arithmetic(a, n, 1, &twiddles)
}
}

/// This perform recursive butterfly arithmetic
pub fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
a: &mut [G],
n: usize,
twiddle_chunk: usize,
twiddles: &[Scalar],
) {
if n == 2 {
let t = a[1];
a[1] = a[0];
a[0] += &t;
a[1] -= &t;
} else {
let (left, right) = a.split_at_mut(n / 2);
multicore::join(
|| recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
|| recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
);

// case when twiddle factor is one
let (a, left) = left.split_at_mut(1);
let (b, right) = right.split_at_mut(1);
let t = b[0];
b[0] = a[0];
a[0] += &t;
b[0] -= &t;

left.iter_mut()
.zip(right.iter_mut())
.enumerate()
.for_each(|(i, (a, b))| {
let mut t = *b;
t *= &twiddles[(i + 1) * twiddle_chunk];
*b = *a;
*a += &t;
*b -= &t;
});
}
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
mod arithmetic;
pub mod fft;
pub mod hash_to_curve;
pub mod msm;
pub mod multicore;
#[macro_use]
pub mod legendre;
pub mod serde;
Expand Down
Loading

0 comments on commit 8797915

Please sign in to comment.