Skip to content

Commit

Permalink
feat: implement svdw hash-to-curve for bn254 and grumpkin
Browse files Browse the repository at this point in the history
  • Loading branch information
han0110 committed Jun 11, 2023
1 parent 21def8d commit 3ba1790
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 5 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ num-traits = "0.2"
paste = "1.0.11"
serde = { version = "1.0", default-features = false, optional = true }
serde_arrays = { version = "0.1.0", optional = true }
blake2b_simd = "1"

[features]
default = ["reexport", "bits"]
Expand Down
8 changes: 8 additions & 0 deletions src/bn256/curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::ff::WithSmallOrderMulGroup;
use crate::ff::{Field, PrimeField};
use crate::group::Curve;
use crate::group::{cofactor::CofactorGroup, prime::PrimeCurveAffine, Group, GroupEncoding};
use crate::hash_to_curve::svdw_map_to_curve;
use crate::{
batch_add, impl_add_binop_specify_output, impl_binops_additive,
impl_binops_additive_specify_output, impl_binops_multiplicative,
Expand Down Expand Up @@ -37,6 +38,7 @@ new_curve_impl!(
(G1_GENERATOR_X,G1_GENERATOR_Y),
G1_B,
"bn256_g1",
|curve_id, domain_prefix| svdw_map_to_curve(curve_id, domain_prefix, Fq::ONE),
);

new_curve_impl!(
Expand All @@ -49,6 +51,7 @@ new_curve_impl!(
(G2_GENERATOR_X, G2_GENERATOR_Y),
G2_B,
"bn256_g2",
|_, _| unimplemented!(),
);

impl CurveAffineExt for G1Affine {
Expand Down Expand Up @@ -215,6 +218,11 @@ mod tests {
use ff::WithSmallOrderMulGroup;
use rand_core::OsRng;

#[test]
fn test_hash_to_curve() {
crate::tests::curve::hash_to_curve_test::<G1>();
}

#[test]
fn test_curve() {
crate::tests::curve::curve_tests::<G1>();
Expand Down
5 changes: 3 additions & 2 deletions src/derive/curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ macro_rules! new_curve_impl {
$generator:expr,
$constant_b:expr,
$curve_id:literal,
$hash_to_curve:expr,
) => {

macro_rules! impl_compressed {
Expand Down Expand Up @@ -606,8 +607,8 @@ macro_rules! new_curve_impl {
}


fn hash_to_curve<'a>(_: &'a str) -> Box<dyn Fn(&[u8]) -> Self + 'a> {
unimplemented!();
fn hash_to_curve<'a>(domain_prefix: &'a str) -> Box<dyn Fn(&[u8]) -> Self + 'a> {
$hash_to_curve($curve_id, domain_prefix)
}

fn is_on_curve(&self) -> Choice {
Expand Down
7 changes: 7 additions & 0 deletions src/grumpkin/curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::group::Curve;
use crate::group::{prime::PrimeCurveAffine, Group, GroupEncoding};
use crate::grumpkin::Fq;
use crate::grumpkin::Fr;
use crate::hash_to_curve::svdw_map_to_curve;
use crate::{
batch_add, impl_add_binop_specify_output, impl_binops_additive,
impl_binops_additive_specify_output, impl_binops_multiplicative,
Expand All @@ -30,6 +31,7 @@ new_curve_impl!(
(G1_GENERATOR_X, G1_GENERATOR_Y),
G1_B,
"grumpkin_g1",
|curve_id, domain_prefix| svdw_map_to_curve(curve_id, domain_prefix, Fq::ONE),
);

impl CurveAffineExt for G1Affine {
Expand Down Expand Up @@ -78,6 +80,11 @@ mod tests {
use crate::CurveExt;
use ff::WithSmallOrderMulGroup;

#[test]
fn test_hash_to_curve() {
crate::tests::curve::hash_to_curve_test::<G1>();
}

#[test]
fn test_curve() {
crate::tests::curve::curve_tests::<G1>();
Expand Down
198 changes: 198 additions & 0 deletions src/hash_to_curve.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
use ff::{Field, FromUniformBytes, PrimeField};
use pasta_curves::arithmetic::CurveExt;
use static_assertions::const_assert;
use subtle::{ConditionallySelectable, ConstantTimeEq};

/// Hashes over a message and writes the output to all of `buf`.
/// Modified from https://github.com/zcash/pasta_curves/blob/7e3fc6a4919f6462a32b79dd226cb2587b7961eb/src/hashtocurve.rs#L11.
fn hash_to_field<F: FromUniformBytes<64>>(
method: &str,
curve_id: &str,
domain_prefix: &str,
message: &[u8],
buf: &mut [F; 2],
) {
assert!(domain_prefix.len() < 256);
assert!((18 + method.len() + curve_id.len() + domain_prefix.len()) < 256);

// Assume that the field size is 32 bytes and k is 256, where k is defined in
// <https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-10.html#name-security-considerations-3>.
const CHUNKLEN: usize = 64;
const_assert!(CHUNKLEN * 2 < 256);

// Input block size of BLAKE2b.
const R_IN_BYTES: usize = 128;

let personal = [0u8; 16];
let empty_hasher = blake2b_simd::Params::new()
.hash_length(CHUNKLEN)
.personal(&personal)
.to_state();

let b_0 = empty_hasher
.clone()
.update(&[0; R_IN_BYTES])
.update(message)
.update(&[0, (CHUNKLEN * 2) as u8, 0])
.update(domain_prefix.as_bytes())
.update(b"-")
.update(curve_id.as_bytes())
.update(b"_XMD:BLAKE2b_")
.update(method.as_bytes())
.update(b"_RO_")
.update(&[(18 + method.len() + curve_id.len() + domain_prefix.len()) as u8])
.finalize();

let b_1 = empty_hasher
.clone()
.update(b_0.as_array())
.update(&[1])
.update(domain_prefix.as_bytes())
.update(b"-")
.update(curve_id.as_bytes())
.update(b"_XMD:BLAKE2b_")
.update(method.as_bytes())
.update(b"_RO_")
.update(&[(18 + method.len() + curve_id.len() + domain_prefix.len()) as u8])
.finalize();

let b_2 = {
let mut empty_hasher = empty_hasher;
for (l, r) in b_0.as_array().iter().zip(b_1.as_array().iter()) {
empty_hasher.update(&[*l ^ *r]);
}
empty_hasher
.update(&[2])
.update(domain_prefix.as_bytes())
.update(b"-")
.update(curve_id.as_bytes())
.update(b"_XMD:BLAKE2b_")
.update(method.as_bytes())
.update(b"_RO_")
.update(&[(18 + method.len() + curve_id.len() + domain_prefix.len()) as u8])
.finalize()
};

for (big, buf) in [b_1, b_2].iter().zip(buf.iter_mut()) {
let mut little = [0u8; CHUNKLEN];
little.copy_from_slice(big.as_array());
little.reverse();
*buf = F::from_uniform_bytes(&little);
}
}

/// Implementation of https://www.ietf.org/id/draft-irtf-cfrg-hash-to-curve-16.html#name-shallue-van-de-woestijne-met
#[allow(clippy::type_complexity)]
pub(crate) fn svdw_map_to_curve<'a, C>(
curve_id: &'static str,
domain_prefix: &'a str,
z: C::Base,
) -> Box<dyn Fn(&[u8]) -> C + 'a>
where
C: CurveExt,
C::Base: FromUniformBytes<64>,
{
let one = C::Base::ONE;
let three = one + one + one;
let four = three + one;
let a = C::a();
let b = C::b();
let tmp = three * z.square() + four * a;

// Precomputed constants:
// 1. c1 = g(Z)
let c1 = (z.square() + a) * z + b;
// 2. c2 = -Z / 2
let c2 = -z * C::Base::TWO_INV;
// 3. c3 = sqrt(-g(Z) * (3 * Z^2 + 4 * A)) # sgn0(c3) MUST equal 0
let c3 = {
let c3 = (-c1 * tmp).sqrt().unwrap();
C::Base::conditional_select(&c3, &-c3, c3.is_odd())
};
// 4. c4 = -4 * g(Z) / (3 * Z^2 + 4 * A)
let c4 = -four * c1 * tmp.invert().unwrap();

Box::new(move |message| {
let mut us = [C::Base::ZERO; 2];
hash_to_field("SVDW", curve_id, domain_prefix, message, &mut us);

let [q0, q1] = us.map(|u| {
// 1. tv1 = u^2
let tv1 = u.square();
// 2. tv1 = tv1 * c1
let tv1 = tv1 * c1;
// 3. tv2 = 1 + tv1
let tv2 = one + tv1;
// 4. tv1 = 1 - tv1
let tv1 = one - tv1;
// 5. tv3 = tv1 * tv2
let tv3 = tv1 * tv2;
// 6. tv3 = inv0(tv3)
let tv3 = tv3.invert().unwrap_or(C::Base::ZERO);
// 7. tv4 = u * tv1
let tv4 = u * tv1;
// 8. tv4 = tv4 * tv3
let tv4 = tv4 * tv3;
// 9. tv4 = tv4 * c3
let tv4 = tv4 * c3;
// 10. x1 = c2 - tv4
let x1 = c2 - tv4;
// 11. gx1 = x1^2
let gx1 = x1.square();
// 12. gx1 = gx1 + A
let gx1 = gx1 + a;
// 13. gx1 = gx1 * x1
let gx1 = gx1 * x1;
// 14. gx1 = gx1 + B
let gx1 = gx1 + b;
// 15. e1 = is_square(gx1)
let e1 = gx1.sqrt().is_some();
// 16. x2 = c2 + tv4
let x2 = c2 + tv4;
// 17. gx2 = x2^2
let gx2 = x2.square();
// 18. gx2 = gx2 + A
let gx2 = gx2 + a;
// 19. gx2 = gx2 * x2
let gx2 = gx2 * x2;
// 20. gx2 = gx2 + B
let gx2 = gx2 + b;
// 21. e2 = is_square(gx2) AND NOT e1 # Avoid short-circuit logic ops
let e2 = gx2.sqrt().is_some() & (!e1);
// 22. x3 = tv2^2
let x3 = tv2.square();
// 23. x3 = x3 * tv3
let x3 = x3 * tv3;
// 24. x3 = x3^2
let x3 = x3.square();
// 25. x3 = x3 * c4
let x3 = x3 * c4;
// 26. x3 = x3 + Z
let x3 = x3 + z;
// 27. x = CMOV(x3, x1, e1) # x = x1 if gx1 is square, else x = x3
let x = C::Base::conditional_select(&x3, &x1, e1);
// 28. x = CMOV(x, x2, e2) # x = x2 if gx2 is square and gx1 is not
let x = C::Base::conditional_select(&x, &x2, e2);
// 29. gx = x^2
let gx = x.square();
// 30. gx = gx + A
let gx = gx + a;
// 31. gx = gx * x
let gx = gx * x;
// 32. gx = gx + B
let gx = gx + b;
// 33. y = sqrt(gx)
let y = gx.sqrt().unwrap();
// 34. e3 = sgn0(u) == sgn0(y)
let e3 = u.is_odd().ct_eq(&y.is_odd());
// 35. y = CMOV(-y, y, e3) # Select correct sign of y
let y = C::Base::conditional_select(&-y, &y, e3);
// 36. return (x, y)
C::new_jacobian(x, y, one).unwrap()
});

let r = q0 + &q1;
debug_assert!(bool::from(r.is_on_curve()));
r
})
}
6 changes: 4 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#![cfg_attr(feature = "asm", feature(asm_const))]
#![allow(clippy::op_ref)]

mod arithmetic;
pub mod hash_to_curve;
pub mod pairing;
pub mod serde;

pub mod bn256;
pub mod grumpkin;
pub mod pairing;
pub mod pasta;
pub mod secp256k1;
pub mod serde;

#[macro_use]
mod derive;
Expand Down
1 change: 1 addition & 0 deletions src/secp256k1/curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ new_curve_impl!(
(SECP_GENERATOR_X,SECP_GENERATOR_Y),
SECP_B,
"secp256k1",
|_, _| unimplemented!(),
);

impl CurveAffineExt for Secp256k1Affine {
Expand Down
15 changes: 14 additions & 1 deletion src/tests/curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use crate::ff::Field;
use crate::group::prime::PrimeCurveAffine;
use crate::{group::GroupEncoding, serde::SerdeObject};
use crate::{CurveAffine, CurveExt};
use rand_core::OsRng;
use rand_core::{OsRng, RngCore};
use std::iter;

#[cfg(feature = "derive_serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -314,3 +315,15 @@ fn multiplication<G: CurveExt>() {
assert_eq!(t0, t1);
}
}

pub fn hash_to_curve_test<G: CurveExt>() {
let hasher = G::hash_to_curve("test");
let mut rng = OsRng;
for _ in 0..1000 {
let message = iter::repeat_with(|| rng.next_u32().to_be_bytes())
.take(32)
.flatten()
.collect::<Vec<_>>();
assert!(bool::from(hasher(&message).is_on_curve()));
}
}

0 comments on commit 3ba1790

Please sign in to comment.