Skip to content

Commit

Permalink
Add edge case handling for batch_add (#169)
Browse files Browse the repository at this point in the history
* feat: add edge case handling for batch_add

* feat: handle edge cases in msm + rename functions

* chore: generate test points in parallel

* chore: remove redundant cfg

* refactor: rename msm functions

* chore: remove batch_add w/o edge case handling

* fix: clippy

* fix: leftover comment
  • Loading branch information
davidnevadoc authored Jul 25, 2024
1 parent 3e7e7b8 commit 44e142f
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 31 deletions.
6 changes: 3 additions & 3 deletions benches/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use criterion::{BenchmarkId, Criterion};
use ff::{Field, PrimeField};
use group::prime::PrimeCurveAffine;
use halo2curves::bn256::{Fr as Scalar, G1Affine as Point};
use halo2curves::msm::{best_multiexp, multiexp_serial};
use halo2curves::msm::{msm_best, msm_serial};
use rand_core::{RngCore, SeedableRng};
use rand_xorshift::XorShiftRng;
use rayon::current_thread_index;
Expand Down Expand Up @@ -136,7 +136,7 @@ fn msm(c: &mut Criterion) {
assert!(k < 64);
let n: usize = 1 << k;
let mut acc = Point::identity().into();
b.iter(|| multiexp_serial(&coeffs[b_index][..n], &bases[..n], &mut acc));
b.iter(|| msm_serial(&coeffs[b_index][..n], &bases[..n], &mut acc));
})
.sample_size(10);
}
Expand All @@ -147,7 +147,7 @@ fn msm(c: &mut Criterion) {
assert!(k < 64);
let n: usize = 1 << k;
b.iter(|| {
best_multiexp(&coeffs[b_index][..n], &bases[..n]);
msm_best(&coeffs[b_index][..n], &bases[..n]);
})
})
.sample_size(SAMPLE_SIZE);
Expand Down
102 changes: 74 additions & 28 deletions src/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
// Booth encoding:
// * step by `window` size
// * slice by size of `window + 1`
// * each window overlap by 1 bit
// * append a zero bit to the least significant end
// * each window overlap by 1 bit * append a zero bit to the least significant end
// Indexing rule for example window size 3 where we slice by 4 bits:
// `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3, -2, -2, -1, -1, 0]`
// So we can reduce the bucket size without preprocessing scalars
Expand Down Expand Up @@ -54,14 +53,15 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
}
}

/// Batch addition.
fn batch_add<C: CurveAffine>(
size: usize,
buckets: &mut [BucketAffine<C>],
points: &[SchedulePoint],
bases: &[Affine<C>],
) {
let mut t = vec![C::Base::ZERO; size];
let mut z = vec![C::Base::ZERO; size];
let mut t = vec![C::Base::ZERO; size]; // Stores x2 - x1
let mut z = vec![C::Base::ZERO; size]; // Stores y2 - y1
let mut acc = C::Base::ONE;

for (
Expand All @@ -76,16 +76,42 @@ fn batch_add<C: CurveAffine>(
z,
) in points.iter().zip(t.iter_mut()).zip(z.iter_mut())
{
*z = buckets[*buck_idx].x() - bases[*base_idx].x;
if buckets[*buck_idx].is_inf() {
// We assume bases[*base_idx] != infinity always.
continue;
}

if buckets[*buck_idx].x() == bases[*base_idx].x {
// y-coordinate matches:
// 1. y1 == y2 and sign = false or
// 2. y1 != y2 and sign = true
// => ( y1 == y2) xor !sign
// (This uses the fact that x1 == x2 and both points satisfy the curve eq.)
if (buckets[*buck_idx].y() == bases[*base_idx].y) ^ !*sign {
// Doubling
let x_squared = bases[*base_idx].x.square();
*z = buckets[*buck_idx].y() + buckets[*buck_idx].y(); // 2y
*t = acc * (x_squared + x_squared + x_squared); // acc * 3x^2
acc *= *z;
continue;
}
// P + (-P)
buckets[*buck_idx].set_inf();
continue;
}
// Addition
*z = buckets[*buck_idx].x() - bases[*base_idx].x; // x2 - x1
if *sign {
*t = acc * (buckets[*buck_idx].y() - bases[*base_idx].y);
} else {
*t = acc * (buckets[*buck_idx].y() + bases[*base_idx].y);
}
} // y2 - y1
acc *= *z;
}

acc = acc.invert().unwrap();
acc = acc
.invert()
.expect("Some edge case has not been handled properly");

for (
(
Expand All @@ -99,15 +125,18 @@ fn batch_add<C: CurveAffine>(
z,
) in points.iter().zip(t.iter()).zip(z.iter()).rev()
{
if buckets[*buck_idx].is_inf() {
// We assume bases[*base_idx] != infinity always.
continue;
}
let lambda = acc * t;
acc *= z;

let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x);
acc *= z; // update acc
let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x); // x_result
if *sign {
buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) - bases[*base_idx].y));
} else {
buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) + bases[*base_idx].y));
}
} // y_result = lambda * (x1 - x_result) - y1
buckets[*buck_idx].set_x(&x);
}
}
Expand Down Expand Up @@ -207,6 +236,13 @@ impl<C: CurveAffine> BucketAffine<C> {
}
}

/// Reports whether this bucket is empty, i.e. holds the `None` variant
/// that stands for the point at infinity.
///
/// Returns `true` for `Self::None` and `false` for any `Self::Point(_)`.
fn is_inf(&self) -> bool {
    matches!(self, Self::None)
}

fn set_x(&mut self, x: &C::Base) {
match self {
Self::None => panic!("::set_x None"),
Expand All @@ -220,6 +256,13 @@ impl<C: CurveAffine> BucketAffine<C> {
Self::Point(ref mut a) => a.y = *y,
}
}

/// Resets this bucket to the `None` variant (the point at infinity).
///
/// Calling this on a bucket that is already `None` leaves it unchanged.
fn set_inf(&mut self) {
    // Overwriting `None` with `None` is a no-op, so no variant check is needed.
    *self = Self::None;
}
}

struct Schedule<C: CurveAffine> {
Expand Down Expand Up @@ -286,7 +329,10 @@ impl<C: CurveAffine> Schedule<C> {
}
}

pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
/// Performs a multi-scalar multiplication operation.
///
/// This function will panic if coeffs and bases have a different length.
pub fn msm_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();

let c = if bases.len() < 4 {
Expand All @@ -303,7 +349,7 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &
let mut acc_or = vec![0; field_byte_size];
for coeff in &coeffs {
for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) {
*acc_limb = *acc_limb | *limb;
*acc_limb |= *limb;
}
}
let max_byte_size = field_byte_size
Expand All @@ -315,7 +361,7 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &
if max_byte_size == 0 {
return;
}
let number_of_windows = max_byte_size * 8 as usize / c + 1;
let number_of_windows = max_byte_size * 8_usize / c + 1;

for current_window in (0..number_of_windows).rev() {
for _ in 0..c {
Expand Down Expand Up @@ -377,12 +423,12 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &
}
}

/// Performs a multi-exponentiation operation.
/// Performs a multi-scalar multiplication operation.
///
/// This function will panic if coeffs and bases have a different length.
///
/// This will use multithreading if beneficial.
pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
pub fn msm_parallel<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
assert_eq!(coeffs.len(), bases.len());

let num_threads = rayon::current_num_threads();
Expand All @@ -399,25 +445,22 @@ pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu
.zip(results.iter_mut())
{
scope.spawn(move |_| {
multiexp_serial(coeffs, bases, acc);
msm_serial(coeffs, bases, acc);
});
}
});
results.iter().fold(C::Curve::identity(), |a, b| a + b)
} else {
let mut acc = C::Curve::identity();
multiexp_serial(coeffs, bases, &mut acc);
msm_serial(coeffs, bases, &mut acc);
acc
}
}
///

/// This function will panic if coeffs and bases have a different length.
///
/// This will use multithreading if beneficial.
pub fn best_multiexp_independent_points<C: CurveAffine>(
coeffs: &[C::Scalar],
bases: &[C],
) -> C::Curve {
pub fn msm_best<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
assert_eq!(coeffs.len(), bases.len());

// TODO: consider adjusting it with empirical data?
Expand All @@ -430,7 +473,7 @@ pub fn best_multiexp_independent_points<C: CurveAffine>(
};

if c < 10 {
return best_multiexp(coeffs, bases);
return msm_parallel(coeffs, bases);
}

// coeffs to byte representation
Expand Down Expand Up @@ -491,7 +534,6 @@ pub fn best_multiexp_independent_points<C: CurveAffine>(

#[cfg(test)]
mod test {

use std::ops::Neg;

use crate::bn256::{Fr, G1Affine, G1};
Expand Down Expand Up @@ -548,27 +590,31 @@ mod test {
}

fn run_msm_cross<C: CurveAffine>(min_k: usize, max_k: usize) {
use rayon::iter::{IntoParallelIterator, ParallelIterator};

let points = (0..1 << max_k)
.into_par_iter()
.map(|_| C::Curve::random(OsRng))
.collect::<Vec<_>>();
let mut affine_points = vec![C::identity(); 1 << max_k];
C::Curve::batch_normalize(&points[..], &mut affine_points[..]);
let points = affine_points;

let scalars = (0..1 << max_k)
.into_par_iter()
.map(|_| C::Scalar::random(OsRng))
.collect::<Vec<_>>();

for k in min_k..=max_k {
let points = &points[..1 << k];
let scalars = &scalars[..1 << k];

let t0 = start_timer!(|| format!("cyclone k={}", k));
let e0 = super::best_multiexp_independent_points(scalars, points);
let t0 = start_timer!(|| format!("cyclone indep k={}", k));
let e0 = super::msm_best(scalars, points);
end_timer!(t0);

let t1 = start_timer!(|| format!("older k={}", k));
let e1 = super::best_multiexp(scalars, points);
let e1 = super::msm_parallel(scalars, points);
end_timer!(t1);
assert_eq!(e0, e1);
}
Expand Down

0 comments on commit 44e142f

Please sign in to comment.