refactor: Refactor DlogGroup trait and optimize batch operations
- the DlogGroup trait is now aware of the group crate and states its requirements in terms of that crate's traits,
- these requirements will be further streamlined when zkcrypto/group#48 merges,
- simplified the declaration boilerplate in the halo2curves & pasta macros,
- removed duplicated macro boilerplate for grumpkin_msm.
huitseeker committed Dec 23, 2023
1 parent 1585c84 commit ae4c556
Showing 12 changed files with 134 additions and 411 deletions.
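
The refactored trait definition itself lives in src/provider/traits.rs, which is not among the diffs loaded below; the sketch that follows is only an inference from the bounds appearing later in this commit (DlogGroup<ScalarExt = E::Fr, AffineExt = E::G1Affine>) and from call sites switching to to_affine() and identity(). The names and exact supertraits are assumptions, not the crate's actual definition.

// Hypothetical sketch of the group-crate-aware trait shape; the real definition
// in src/provider/traits.rs is not shown in this diff and may differ.
use group::{prime::PrimeCurve, Group, GroupEncoding};

pub trait DlogGroup:
  Group<Scalar = Self::ScalarExt> + PrimeCurve<Affine = Self::AffineExt> + GroupEncoding
{
  /// Scalar type, tied to the group-crate Scalar of Self.
  type ScalarExt;
  /// Affine representation used for preprocessed bases, tied to the group-crate Affine of Self.
  type AffineExt;

  /// Variable-time multi-scalar multiplication over affine bases.
  fn vartime_multiscalar_mul(scalars: &[Self::ScalarExt], bases: &[Self::AffineExt]) -> Self;
  /// Deterministically derive n generators from a label.
  fn from_label(label: &'static [u8], n: usize) -> Vec<Self::AffineExt>;
}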
1 change: 1 addition & 0 deletions Cargo.toml
@@ -45,6 +45,7 @@ once_cell = "1.18.0"
itertools = "0.12.0"
rand = "0.8.5"
ref-cast = "1.0.20"
derive_more = "0.99.17"

[target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies]
pasta-msm = { git = "https://github.com/lurk-lab/pasta-msm", branch = "dev", version = "0.1.4" }
202 changes: 15 additions & 187 deletions src/provider/bn256_grumpkin.rs
@@ -1,14 +1,13 @@
//! This module implements the Nova traits for `bn256::Point`, `bn256::Scalar`, `grumpkin::Point`, `grumpkin::Scalar`.
use crate::{
provider::{
traits::{CompressedGroup, DlogGroup},
util::msm::cpu_best_msm,
},
impl_traits,
provider::{traits::DlogGroup, util::msm::cpu_best_msm},
traits::{Group, PrimeFieldExt, TranscriptReprTrait},
};
use digest::{ExtendableOutput, Update};
use ff::{FromUniformBytes, PrimeField};
use group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup, GroupEncoding};
use group::{cofactor::CofactorCurveAffine, Curve, Group as AnotherGroup};
use grumpkin_msm::{bn256 as bn256_msm, grumpkin as grumpkin_msm};
use num_bigint::BigInt;
use num_traits::Num;
// Remove this when https://github.com/zcash/pasta_curves/issues/41 resolves
@@ -17,203 +16,32 @@ use rayon::prelude::*;
use sha3::Shake256;
use std::io::Read;

use halo2curves::bn256::{
G1Affine as Bn256Affine, G1Compressed as Bn256Compressed, G1 as Bn256Point,
};
use halo2curves::grumpkin::{
G1Affine as GrumpkinAffine, G1Compressed as GrumpkinCompressed, G1 as GrumpkinPoint,
};

/// Re-exports that give access to the standard aliases used in the code base, for bn256
pub mod bn256 {
pub use halo2curves::bn256::{Fq as Base, Fr as Scalar, G1Affine as Affine, G1 as Point};
pub use halo2curves::bn256::{
Fq as Base, Fr as Scalar, G1Affine as Affine, G1Compressed as Compressed, G1 as Point,
};
}

/// Re-exports that give access to the standard aliases used in the code base, for grumpkin
pub mod grumpkin {
pub use halo2curves::grumpkin::{Fq as Base, Fr as Scalar, G1Affine as Affine, G1 as Point};
}

macro_rules! impl_traits {
(
$name:ident,
$name_compressed:ident,
$name_curve:ident,
$name_curve_affine:ident,
$order_str:literal,
$base_str:literal
) => {
impl Group for $name::Point {
type Base = $name::Base;
type Scalar = $name::Scalar;

fn group_params() -> (Self::Base, Self::Base, BigInt, BigInt) {
let A = $name::Point::a();
let B = $name::Point::b();
let order = BigInt::from_str_radix($order_str, 16).unwrap();
let base = BigInt::from_str_radix($base_str, 16).unwrap();

(A, B, order, base)
}
}

impl DlogGroup for $name::Point {
type CompressedGroupElement = $name_compressed;
type PreprocessedGroupElement = $name::Affine;

#[tracing::instrument(
skip_all,
level = "trace",
name = "<_ as Group>::vartime_multiscalar_mul"
)]
fn vartime_multiscalar_mul(
scalars: &[Self::Scalar],
bases: &[Self::PreprocessedGroupElement],
) -> Self {
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
if scalars.len() >= 128 {
grumpkin_msm::$name(bases, scalars)
} else {
cpu_best_msm(scalars, bases)
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
cpu_best_msm(scalars, bases)
}
fn preprocessed(&self) -> Self::PreprocessedGroupElement {
self.to_affine()
}

fn compress(&self) -> Self::CompressedGroupElement {
self.to_bytes()
}

fn from_label(label: &'static [u8], n: usize) -> Vec<Self::PreprocessedGroupElement> {
let mut shake = Shake256::default();
shake.update(label);
let mut reader = shake.finalize_xof();
let mut uniform_bytes_vec = Vec::new();
for _ in 0..n {
let mut uniform_bytes = [0u8; 32];
reader.read_exact(&mut uniform_bytes).unwrap();
uniform_bytes_vec.push(uniform_bytes);
}
let gens_proj: Vec<$name_curve> = (0..n)
.into_par_iter()
.map(|i| {
let hash = $name_curve::hash_to_curve("from_uniform_bytes");
hash(&uniform_bytes_vec[i])
})
.collect();

let num_threads = rayon::current_num_threads();
if gens_proj.len() > num_threads {
let chunk = (gens_proj.len() as f64 / num_threads as f64).ceil() as usize;
(0..num_threads)
.into_par_iter()
.flat_map(|i| {
let start = i * chunk;
let end = if i == num_threads - 1 {
gens_proj.len()
} else {
core::cmp::min((i + 1) * chunk, gens_proj.len())
};
if end > start {
let mut gens = vec![$name_curve_affine::identity(); end - start];
<Self as Curve>::batch_normalize(&gens_proj[start..end], &mut gens);
gens
} else {
vec![]
}
})
.collect()
} else {
let mut gens = vec![$name_curve_affine::identity(); n];
<Self as Curve>::batch_normalize(&gens_proj, &mut gens);
gens
}
}

fn zero() -> Self {
$name::Point::identity()
}

fn to_coordinates(&self) -> (Self::Base, Self::Base, bool) {
let coordinates = self.to_affine().coordinates();
if coordinates.is_some().unwrap_u8() == 1
&& ($name_curve_affine::identity() != self.to_affine())
{
(*coordinates.unwrap().x(), *coordinates.unwrap().y(), false)
} else {
(Self::Base::zero(), Self::Base::zero(), true)
}
}
}

impl PrimeFieldExt for $name::Scalar {
fn from_uniform(bytes: &[u8]) -> Self {
let bytes_arr: [u8; 64] = bytes.try_into().unwrap();
$name::Scalar::from_uniform_bytes(&bytes_arr)
}
}

impl<G: DlogGroup> TranscriptReprTrait<G> for $name_compressed {
fn to_transcript_bytes(&self) -> Vec<u8> {
self.as_ref().to_vec()
}
}

impl CompressedGroup for $name_compressed {
type GroupElement = $name::Point;

fn decompress(&self) -> Option<$name::Point> {
Some($name_curve::from_bytes(&self).unwrap())
}
}

impl<G: Group> TranscriptReprTrait<G> for $name::Scalar {
fn to_transcript_bytes(&self) -> Vec<u8> {
self.to_repr().to_vec()
}
}

impl<G: DlogGroup> TranscriptReprTrait<G> for $name::Affine {
fn to_transcript_bytes(&self) -> Vec<u8> {
let (x, y, is_infinity_byte) = {
let coordinates = self.coordinates();
if coordinates.is_some().unwrap_u8() == 1 && ($name_curve_affine::identity() != *self) {
let c = coordinates.unwrap();
(*c.x(), *c.y(), u8::from(false))
} else {
($name::Base::zero(), $name::Base::zero(), u8::from(false))
}
};

x.to_repr()
.into_iter()
.chain(y.to_repr().into_iter())
.chain(std::iter::once(is_infinity_byte))
.collect()
}
}
pub use halo2curves::grumpkin::{
Fq as Base, Fr as Scalar, G1Affine as Affine, G1Compressed as Compressed, G1 as Point,
};
}

impl_traits!(
bn256,
Bn256Compressed,
Bn256Point,
Bn256Affine,
"30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
"30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47"
"30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47",
bn256_msm
);

impl_traits!(
grumpkin,
GrumpkinCompressed,
GrumpkinPoint,
GrumpkinAffine,
"30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47",
"30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"
"30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
grumpkin_msm
);

#[cfg(test)]
@@ -237,7 +65,7 @@ mod tests {
.map(|_| bn256::Scalar::random(&mut rng))
.collect::<Vec<_>>();

let cpu_msm = cpu_best_msm(&scalars, &points);
let cpu_msm = cpu_best_msm(&points, &scalars);
let gpu_msm = bn256::Point::vartime_multiscalar_mul(&scalars, &points);

assert_eq!(cpu_msm, gpu_msm);
@@ -253,7 +81,7 @@ mod tests {
.map(|_| grumpkin::Scalar::random(&mut rng))
.collect::<Vec<_>>();

let cpu_msm = cpu_best_msm(&scalars, &points);
let cpu_msm = cpu_best_msm(&points, &scalars);
let gpu_msm = grumpkin::Point::vartime_multiscalar_mul(&scalars, &points);

assert_eq!(cpu_msm, gpu_msm);
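
The per-curve copy of impl_traits! shown as removed above is replaced by a shared macro (imported at the top of the file) that additionally takes the curve's MSM backend as a parameter, as the new bn256_msm / grumpkin_msm arguments in the invocations show. The shared macro body is not part of the loaded diff; the following is only a sketch of the size-based dispatch it presumably retains, with hypothetical helper names.

// Sketch of the MSM dispatch; function and parameter names are illustrative,
// not the crate's actual code.
fn msm_dispatch_sketch<Scalar, Affine, Point>(
  scalars: &[Scalar],
  bases: &[Affine],
  backend_msm: impl Fn(&[Affine], &[Scalar]) -> Point, // e.g. bn256_msm / grumpkin_msm, passed into impl_traits!
  cpu_msm: impl Fn(&[Affine], &[Scalar]) -> Point,     // cpu_best_msm, now taking (bases, scalars)
) -> Point {
  // On x86_64/aarch64 the optimized backend is used once the input is large enough;
  // small inputs and other targets fall back to the CPU implementation.
  if cfg!(any(target_arch = "x86_64", target_arch = "aarch64")) && scalars.len() >= 128 {
    backend_msm(bases, scalars)
  } else {
    cpu_msm(bases, scalars)
  }
}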
2 changes: 1 addition & 1 deletion src/provider/kzg_commitment.rs
@@ -30,7 +30,7 @@ pub struct KZGCommitmentEngine<E: Engine> {
impl<E: Engine, NE: NovaEngine<GE = E::G1, Scalar = E::Fr>> CommitmentEngineTrait<NE>
for KZGCommitmentEngine<E>
where
E::G1: DlogGroup<PreprocessedGroupElement = E::G1Affine, Scalar = E::Fr>,
E::G1: DlogGroup<ScalarExt = E::Fr, AffineExt = E::G1Affine>,
E::G1Affine: Serialize + for<'de> Deserialize<'de>,
E::G2Affine: Serialize + for<'de> Deserialize<'de>,
E::Fr: PrimeFieldBits, // TODO due to use of gen_srs_for_testing, make optional
16 changes: 8 additions & 8 deletions src/provider/mlkzg.rs
@@ -49,7 +49,7 @@ impl<E, NE> EvaluationEngine<E, NE>
where
E: Engine,
NE: NovaEngine<GE = E::G1, Scalar = E::Fr>,
E::G1: DlogGroup<PreprocessedGroupElement = E::G1Affine, Scalar = E::Fr>,
E::G1: DlogGroup,
E::Fr: TranscriptReprTrait<E::G1>,
E::G1Affine: TranscriptReprTrait<E::G1>, // TODO: this bound on DlogGroup is really unusable!
{
@@ -104,7 +104,7 @@ where
E::Fr: Serialize + DeserializeOwned,
E::G1Affine: Serialize + DeserializeOwned,
E::G2Affine: Serialize + DeserializeOwned,
E::G1: DlogGroup<PreprocessedGroupElement = E::G1Affine, Scalar = E::Fr>,
E::G1: DlogGroup<ScalarExt = E::Fr, AffineExt = E::G1Affine>,
<E::G1 as Group>::Base: TranscriptReprTrait<E::G1>, // Note: due to the move of the bound TranscriptReprTrait<G> on G::Base from Group to Engine
E::Fr: PrimeFieldBits, // TODO due to use of gen_srs_for_testing, make optional
E::Fr: TranscriptReprTrait<E::G1>,
@@ -161,7 +161,7 @@ where

<NE::CE as CommitmentEngineTrait<NE>>::commit(ck, &h)
.comm
.preprocessed()
.to_affine()
};

let kzg_open_batch = |C: &[E::G1Affine],
@@ -253,7 +253,7 @@ where
.map(|i| {
<NE::CE as CommitmentEngineTrait<NE>>::commit(ck, &polys[i])
.comm
.preprocessed()
.to_affine()
})
.collect();

@@ -265,7 +265,7 @@ where

// Phase 3 -- create response
let mut com_all = comms.clone();
com_all.insert(0, C.comm.preprocessed());
com_all.insert(0, C.comm.to_affine());
let (w, evals) = kzg_open_batch(&com_all, &polys, &u, transcript);

Ok(EvaluationArgument { comms, w, evals })
@@ -300,7 +300,7 @@ where

// Compute the commitment to the batched polynomial B(X)
let c_0: E::G1 = C[0].into();
let C_B = (c_0 + NE::GE::vartime_multiscalar_mul(&q_powers[1..k], &C[1..k])).preprocessed();
let C_B = (c_0 + NE::GE::vartime_multiscalar_mul(&q_powers[1..k], &C[1..k])).to_affine();

// Compute the batched openings
// compute B(u_i) = v[i][0] + q*v[i][1] + ... + q^(t-1) * v[i][t-1]
@@ -356,10 +356,10 @@ where
// obtained from the transcript
let r = Self::compute_challenge(&com, transcript);

if r == E::Fr::ZERO || C.comm == E::G1::zero() {
if r == E::Fr::ZERO || C.comm == E::G1::identity() {
return Err(NovaError::ProofVerifyError);
}
com.insert(0, C.comm.preprocessed()); // set com_0 = C, shifts other commitments to the right
com.insert(0, C.comm.to_affine()); // set com_0 = C, shifts other commitments to the right

let u = vec![r, -r, r * r];

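
The replacements above (preprocessed() becoming to_affine(), zero() becoming identity()) follow directly from the group-crate-aware trait: callers now use the standard group::Curve and group::Group methods instead of bespoke DlogGroup helpers. A minimal stand-alone illustration, using pasta's pallas curve purely as an example:

// Illustration only (not from the diff): the standard group-crate calls that replace
// the former preprocessed()/zero() helpers.
use group::{Curve, Group};
use pasta_curves::pallas;

fn group_crate_calls() {
  let p = pallas::Point::generator();
  let affine = p.to_affine();            // previously: p.preprocessed()
  let id = pallas::Point::identity();    // previously: the DlogGroup::zero() helper
  assert_ne!(affine, id.to_affine());
}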
2 changes: 1 addition & 1 deletion src/provider/mod.rs
@@ -205,7 +205,7 @@ mod tests {
acc + *base * coeff
});

assert_eq!(naive, cpu_best_msm(&coeffs, &bases))
assert_eq!(naive, cpu_best_msm(&bases, &coeffs))
}

#[test]
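
Note that cpu_best_msm now takes the bases first and the scalars second; the naive fold it is checked against is unchanged. A self-contained version of that reference computation (using pallas as a stand-in; the test's actual curve is not shown here):

// Reference MSM mirroring the fold in the test above: sum_i bases[i] * coeffs[i].
use group::Group;
use pasta_curves::pallas;

fn naive_msm(bases: &[pallas::Affine], coeffs: &[pallas::Scalar]) -> pallas::Point {
  bases
    .iter()
    .zip(coeffs)
    .fold(pallas::Point::identity(), |acc, (base, coeff)| acc + *base * *coeff)
}
// With the new argument order this should agree with cpu_best_msm(&bases, &coeffs).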
4 changes: 2 additions & 2 deletions src/provider/non_hiding_kzg.rs
@@ -234,7 +234,7 @@ pub struct UVKZGPCS<E> {

impl<E: MultiMillerLoop> UVKZGPCS<E>
where
E::G1: DlogGroup<PreprocessedGroupElement = E::G1Affine, Scalar = E::Fr>,
E::G1: DlogGroup<ScalarExt = E::Fr, AffineExt = E::G1Affine>,
{
/// Generate a commitment for a polynomial
/// Note that the scheme is not hiding
@@ -327,7 +327,7 @@ mod tests {
fn end_to_end_test_template<E>() -> Result<(), NovaError>
where
E: MultiMillerLoop,
E::G1: DlogGroup<PreprocessedGroupElement = E::G1Affine, Scalar = E::Fr>,
E::G1: DlogGroup<ScalarExt = E::Fr, AffineExt = E::G1Affine>,
E::Fr: PrimeFieldBits,
{
for _ in 0..100 {
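
For context on the UVKZGPCS bound above (not part of this diff): a non-hiding univariate KZG commitment is a multi-scalar multiplication of the polynomial's coefficients against the powers-of-tau G1 bases, which is exactly what the DlogGroup MSM requirement on E::G1 provides. A sketch with assumed names, over bn256:

// Sketch only: commit(p) = sum_i coeffs[i] * [tau^i]_1. The SRS slice name is an
// assumption; the crate's real commit goes through the DlogGroup MSM rather than a fold.
use group::Group;
use halo2curves::bn256::{Fr, G1Affine, G1};

fn kzg_commit_sketch(powers_of_tau_g1: &[G1Affine], coeffs: &[Fr]) -> G1 {
  coeffs
    .iter()
    .zip(powers_of_tau_g1)
    .fold(G1::identity(), |acc, (c, base)| acc + *base * *c)
}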