From 67b23230cd1881c9ef0af7b2f9a5074b53c866e3 Mon Sep 17 00:00:00 2001 From: Gregor Date: Mon, 30 Sep 2024 13:20:23 +0200 Subject: [PATCH 01/15] minimize assumptions on field trait for poseidon impl --- poseidon/src/lib.rs | 1 + poseidon/src/minimal_field.rs | 44 +++++++++++++++++++++++++++++++++++ poseidon/src/permutation.rs | 10 ++++---- poseidon/src/poseidon.rs | 14 +++++------ 4 files changed, 57 insertions(+), 12 deletions(-) create mode 100644 poseidon/src/minimal_field.rs diff --git a/poseidon/src/lib.rs b/poseidon/src/lib.rs index 943d54cb87..c3dbfe3ef8 100644 --- a/poseidon/src/lib.rs +++ b/poseidon/src/lib.rs @@ -1,5 +1,6 @@ pub mod constants; pub mod dummy_values; +pub mod minimal_field; pub mod pasta; pub mod permutation; pub mod poseidon; diff --git a/poseidon/src/minimal_field.rs b/poseidon/src/minimal_field.rs new file mode 100644 index 0000000000..367b82b740 --- /dev/null +++ b/poseidon/src/minimal_field.rs @@ -0,0 +1,44 @@ +use ark_ff::{BitIteratorBE, One, Zero}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use std::ops::{Add, AddAssign, Mul, MulAssign}; + +/** + * Minimal Field trait needed to implement Poseidon + */ +pub trait MinimalField: + 'static + + Copy + + Clone + + CanonicalSerialize + + CanonicalDeserialize + + Zero + + One + + for<'a> Add<&'a Self, Output = Self> + + for<'a> Mul<&'a Self, Output = Self> + + for<'a> AddAssign<&'a Self> + + for<'a> MulAssign<&'a Self> +{ + /// Squares `self` in place. + fn square_in_place(&mut self) -> &mut Self; + + /// Returns `self^exp`, where `exp` is an integer represented with `u64` limbs, + /// least significant limb first. + fn pow>(&self, exp: S) -> Self { + let mut res = Self::one(); + + for i in BitIteratorBE::without_leading_zeros(exp) { + res.square_in_place(); + + if i { + res *= self; + } + } + res + } +} + +impl MinimalField for F { + fn square_in_place(&mut self) -> &mut Self { + self.square_in_place() + } +} diff --git a/poseidon/src/permutation.rs b/poseidon/src/permutation.rs index b3d58776a5..28223853f8 100644 --- a/poseidon/src/permutation.rs +++ b/poseidon/src/permutation.rs @@ -2,11 +2,11 @@ use crate::{ constants::SpongeConstants, + minimal_field::MinimalField, poseidon::{sbox, ArithmeticSpongeParams}, }; -use ark_ff::Field; -fn apply_mds_matrix( +fn apply_mds_matrix( params: &ArithmeticSpongeParams, state: &[F], ) -> Vec { @@ -30,7 +30,7 @@ fn apply_mds_matrix( } } -pub fn full_round( +pub fn full_round( params: &ArithmeticSpongeParams, state: &mut Vec, r: usize, @@ -44,7 +44,7 @@ pub fn full_round( } } -pub fn half_rounds( +pub fn half_rounds( params: &ArithmeticSpongeParams, state: &mut [F], ) { @@ -84,7 +84,7 @@ pub fn half_rounds( } } -pub fn poseidon_block_cipher( +pub fn poseidon_block_cipher( params: &ArithmeticSpongeParams, state: &mut Vec, ) { diff --git a/poseidon/src/poseidon.rs b/poseidon/src/poseidon.rs index 2ce1d1f3d3..c189a46780 100644 --- a/poseidon/src/poseidon.rs +++ b/poseidon/src/poseidon.rs @@ -2,16 +2,16 @@ use crate::{ constants::SpongeConstants, + minimal_field::MinimalField, permutation::{full_round, poseidon_block_cipher}, }; -use ark_ff::Field; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use serde::{Deserialize, Serialize}; use serde_with::serde_as; /// Cryptographic sponge interface - for hashing an arbitrary amount of /// data into one or more field elements -pub trait Sponge { +pub trait Sponge { /// Create a new cryptographic sponge using arithmetic sponge `params` fn new(params: &'static ArithmeticSpongeParams) -> Self; @@ -25,7 +25,7 @@ pub trait Sponge { fn reset(&mut self); } -pub fn sbox(x: F) -> F { +pub fn sbox(x: F) -> F { x.pow([SC::PERM_SBOX as u64]) } @@ -37,7 +37,7 @@ pub enum SpongeState { #[serde_as] #[derive(Clone, Serialize, Deserialize, Default, Debug)] -pub struct ArithmeticSpongeParams { +pub struct ArithmeticSpongeParams { #[serde_as(as = "Vec>")] pub round_constants: Vec>, #[serde_as(as = "Vec>")] @@ -45,7 +45,7 @@ pub struct ArithmeticSpongeParams { +pub struct ArithmeticSponge { pub sponge_state: SpongeState, rate: usize, // TODO(mimoo: an array enforcing the width is better no? or at least an assert somewhere) @@ -54,7 +54,7 @@ pub struct ArithmeticSponge { pub constants: std::marker::PhantomData, } -impl ArithmeticSponge { +impl ArithmeticSponge { pub fn full_round(&mut self, r: usize) { full_round::(self.params, &mut self.state, r); } @@ -64,7 +64,7 @@ impl ArithmeticSponge { } } -impl Sponge for ArithmeticSponge { +impl Sponge for ArithmeticSponge { fn new(params: &'static ArithmeticSpongeParams) -> ArithmeticSponge { let capacity = SC::SPONGE_CAPACITY; let rate = SC::SPONGE_RATE; From 25215e1fcca68c9bd9c892b4a0c8903cf7cb002b Mon Sep 17 00:00:00 2001 From: Gregor Date: Mon, 30 Sep 2024 13:38:02 +0200 Subject: [PATCH 02/15] move into src/pasta --- curves/Cargo.toml | 4 ++-- curves/src/pasta/mod.rs | 1 + .../src => curves/src/pasta/wasm_friendly}/minimal_field.rs | 0 curves/src/pasta/wasm_friendly/mod.rs | 2 ++ curves/src/pasta/wasm_friendly/wasm_fp.rs | 0 poseidon/src/lib.rs | 1 - poseidon/src/permutation.rs | 3 ++- poseidon/src/poseidon.rs | 2 +- 8 files changed, 8 insertions(+), 5 deletions(-) rename {poseidon/src => curves/src/pasta/wasm_friendly}/minimal_field.rs (100%) create mode 100644 curves/src/pasta/wasm_friendly/mod.rs create mode 100644 curves/src/pasta/wasm_friendly/wasm_fp.rs diff --git a/curves/Cargo.toml b/curves/Cargo.toml index 2116b66c86..e5374af2d6 100644 --- a/curves/Cargo.toml +++ b/curves/Cargo.toml @@ -12,10 +12,10 @@ license = "Apache-2.0" [dependencies] ark-ec.workspace = true ark-ff.workspace = true +ark-serialize.workspace = true [dev-dependencies] -rand.workspace = true +rand.workspace = true ark-test-curves.workspace = true ark-algebra-test-templates.workspace = true -ark-serialize.workspace = true ark-std.workspace = true diff --git a/curves/src/pasta/mod.rs b/curves/src/pasta/mod.rs index abd4207fa3..4d3eb86cfa 100644 --- a/curves/src/pasta/mod.rs +++ b/curves/src/pasta/mod.rs @@ -1,5 +1,6 @@ pub mod curves; pub mod fields; +pub mod wasm_friendly; pub use curves::{ pallas::{Pallas, PallasParameters, ProjectivePallas}, diff --git a/poseidon/src/minimal_field.rs b/curves/src/pasta/wasm_friendly/minimal_field.rs similarity index 100% rename from poseidon/src/minimal_field.rs rename to curves/src/pasta/wasm_friendly/minimal_field.rs diff --git a/curves/src/pasta/wasm_friendly/mod.rs b/curves/src/pasta/wasm_friendly/mod.rs new file mode 100644 index 0000000000..9cf62fcece --- /dev/null +++ b/curves/src/pasta/wasm_friendly/mod.rs @@ -0,0 +1,2 @@ +pub mod minimal_field; +pub mod wasm_fp; diff --git a/curves/src/pasta/wasm_friendly/wasm_fp.rs b/curves/src/pasta/wasm_friendly/wasm_fp.rs new file mode 100644 index 0000000000..e69de29bb2 diff --git a/poseidon/src/lib.rs b/poseidon/src/lib.rs index c3dbfe3ef8..943d54cb87 100644 --- a/poseidon/src/lib.rs +++ b/poseidon/src/lib.rs @@ -1,6 +1,5 @@ pub mod constants; pub mod dummy_values; -pub mod minimal_field; pub mod pasta; pub mod permutation; pub mod poseidon; diff --git a/poseidon/src/permutation.rs b/poseidon/src/permutation.rs index 28223853f8..468e2e1f0f 100644 --- a/poseidon/src/permutation.rs +++ b/poseidon/src/permutation.rs @@ -1,8 +1,9 @@ //! The permutation module contains the function implementing the permutation used in Poseidon +use mina_curves::pasta::wasm_friendly::minimal_field::MinimalField; + use crate::{ constants::SpongeConstants, - minimal_field::MinimalField, poseidon::{sbox, ArithmeticSpongeParams}, }; diff --git a/poseidon/src/poseidon.rs b/poseidon/src/poseidon.rs index c189a46780..8f35e735b6 100644 --- a/poseidon/src/poseidon.rs +++ b/poseidon/src/poseidon.rs @@ -2,10 +2,10 @@ use crate::{ constants::SpongeConstants, - minimal_field::MinimalField, permutation::{full_round, poseidon_block_cipher}, }; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use mina_curves::pasta::wasm_friendly::minimal_field::MinimalField; use serde::{Deserialize, Serialize}; use serde_with::serde_as; From e437e5fd922225e5d49ba718581cbc549bddc525 Mon Sep 17 00:00:00 2001 From: Gregor Date: Mon, 30 Sep 2024 16:09:50 +0200 Subject: [PATCH 03/15] reduce minimal field definition to 3 functions and 3 constants --- Cargo.lock | 1 + curves/Cargo.toml | 1 + curves/src/pasta/wasm_friendly/bigint32.rs | 52 ++++++ curves/src/pasta/wasm_friendly/mod.rs | 6 + curves/src/pasta/wasm_friendly/wasm_fp.rs | 190 +++++++++++++++++++++ poseidon/src/poseidon.rs | 2 +- 6 files changed, 251 insertions(+), 1 deletion(-) create mode 100644 curves/src/pasta/wasm_friendly/bigint32.rs diff --git a/Cargo.lock b/Cargo.lock index 1046a81957..45d69b89f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1313,6 +1313,7 @@ dependencies = [ "ark-serialize", "ark-std", "ark-test-curves", + "derivative", "rand", ] diff --git a/curves/Cargo.toml b/curves/Cargo.toml index e5374af2d6..027108f4ba 100644 --- a/curves/Cargo.toml +++ b/curves/Cargo.toml @@ -13,6 +13,7 @@ license = "Apache-2.0" ark-ec.workspace = true ark-ff.workspace = true ark-serialize.workspace = true +derivative = { version = "2.0", features = ["use_core"] } [dev-dependencies] rand.workspace = true diff --git a/curves/src/pasta/wasm_friendly/bigint32.rs b/curves/src/pasta/wasm_friendly/bigint32.rs new file mode 100644 index 0000000000..f7e2357795 --- /dev/null +++ b/curves/src/pasta/wasm_friendly/bigint32.rs @@ -0,0 +1,52 @@ +/** + * BigInt with 32-bit limbs + * + * Contains everything for wasm_fp which is unrelated to being a field + * + * Code is mostly copied from ark-ff::BigInt + */ +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate, + Write, +}; + +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +pub struct BigInt(pub [u32; N]); + +impl Default for BigInt { + fn default() -> Self { + Self([0u32; N]) + } +} + +impl CanonicalSerialize for BigInt { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.0.serialized_size(compress) + } +} + +impl Valid for BigInt { + fn check(&self) -> Result<(), SerializationError> { + self.0.check() + } +} + +impl CanonicalDeserialize for BigInt { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + Ok(BigInt::(<[u32; N]>::deserialize_with_mode( + reader, compress, validate, + )?)) + } +} diff --git a/curves/src/pasta/wasm_friendly/mod.rs b/curves/src/pasta/wasm_friendly/mod.rs index 9cf62fcece..3609e31e96 100644 --- a/curves/src/pasta/wasm_friendly/mod.rs +++ b/curves/src/pasta/wasm_friendly/mod.rs @@ -1,2 +1,8 @@ +pub mod bigint32; +pub use bigint32::BigInt; + pub mod minimal_field; +pub use minimal_field::MinimalField; + pub mod wasm_fp; +pub use wasm_fp::Fp; diff --git a/curves/src/pasta/wasm_friendly/wasm_fp.rs b/curves/src/pasta/wasm_friendly/wasm_fp.rs index e69de29bb2..d913a1c3e4 100644 --- a/curves/src/pasta/wasm_friendly/wasm_fp.rs +++ b/curves/src/pasta/wasm_friendly/wasm_fp.rs @@ -0,0 +1,190 @@ +/** + * MinimalField trait implementation `Fp` which only depends on an `FpConfig` trait + * + * Most of this code was copied over from ark_ff::Fp + */ +use crate::pasta::wasm_friendly::bigint32::BigInt; +use ark_ff::{One, Zero}; +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate, + Write, +}; +use derivative::Derivative; +use std::{ + marker::PhantomData, + ops::{Add, AddAssign, Mul, MulAssign}, +}; + +use super::minimal_field::MinimalField; + +pub trait FpConfig: Send + Sync + 'static + Sized { + /// The modulus of the field. + const MODULUS: BigInt; + + /// Additive identity of the field, i.e. the element `e` + /// such that, for all elements `f` of the field, `e + f = f`. + const ZERO: Fp; + + /// Multiplicative identity of the field, i.e. the element `e` + /// such that, for all elements `f` of the field, `e * f = f`. + const ONE: Fp; + + /// Set a += b. + fn add_assign(a: &mut Fp, b: &Fp); + + /// Set a *= b. + fn mul_assign(a: &mut Fp, b: &Fp); + + /// Construct a field element from an integer in the range + /// `0..(Self::MODULUS - 1)`. Returns `None` if the integer is outside + /// this range. + fn from_bigint(other: BigInt) -> Option>; +} + +/// Represents an element of the prime field F_p, where `p == P::MODULUS`. +/// This type can represent elements in any field of size at most N * 64 bits. +#[derive(Derivative)] +#[derivative( + Default(bound = ""), + Hash(bound = ""), + Clone(bound = ""), + Copy(bound = ""), + PartialEq(bound = ""), + Eq(bound = "") +)] +pub struct Fp, const N: usize>( + pub BigInt, + #[derivative(Debug = "ignore")] + #[doc(hidden)] + pub PhantomData

, +); + +impl, const N: usize> Fp { + #[inline] + fn from_bigint(r: BigInt) -> Option { + P::from_bigint(r) + } +} + +// field + +impl, const N: usize> MinimalField for Fp { + fn square_in_place(&mut self) -> &mut Self { + // implemented with mul_assign for now + let self_copy = *self; + self.mul_assign(&self_copy); + self + } +} + +// add, zero + +impl, const N: usize> Zero for Fp { + #[inline] + fn zero() -> Self { + P::ZERO + } + + #[inline] + fn is_zero(&self) -> bool { + *self == P::ZERO + } +} + +impl<'a, P: FpConfig, const N: usize> AddAssign<&'a Self> for Fp { + #[inline] + fn add_assign(&mut self, other: &Self) { + P::add_assign(self, other) + } +} +impl, const N: usize> Add for Fp { + type Output = Self; + + #[inline] + fn add(mut self, other: Self) -> Self { + self.add_assign(&other); + self + } +} +impl<'a, P: FpConfig, const N: usize> Add<&'a Fp> for Fp { + type Output = Self; + + #[inline] + fn add(mut self, other: &Self) -> Self { + self.add_assign(other); + self + } +} + +// mul, one + +impl, const N: usize> One for Fp { + #[inline] + fn one() -> Self { + P::ONE + } + + #[inline] + fn is_one(&self) -> bool { + *self == P::ONE + } +} +impl<'a, P: FpConfig, const N: usize> MulAssign<&'a Self> for Fp { + #[inline] + fn mul_assign(&mut self, other: &Self) { + P::mul_assign(self, other) + } +} +impl, const N: usize> Mul for Fp { + type Output = Self; + + #[inline] + fn mul(mut self, other: Self) -> Self { + self.mul_assign(&other); + self + } +} +impl<'a, P: FpConfig, const N: usize> Mul<&'a Fp> for Fp { + type Output = Self; + + #[inline] + fn mul(mut self, other: &Self) -> Self { + self.mul_assign(other); + self + } +} + +// (de)serialization + +impl, const N: usize> CanonicalSerialize for Fp { + #[inline] + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(writer, compress) + } + + #[inline] + fn serialized_size(&self, compress: Compress) -> usize { + self.0.serialized_size(compress) + } +} + +impl, const N: usize> Valid for Fp { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl, const N: usize> CanonicalDeserialize for Fp { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + Self::from_bigint(BigInt::deserialize_with_mode(reader, compress, validate)?) + .ok_or(SerializationError::InvalidData) + } +} diff --git a/poseidon/src/poseidon.rs b/poseidon/src/poseidon.rs index 8f35e735b6..7df61ae5fe 100644 --- a/poseidon/src/poseidon.rs +++ b/poseidon/src/poseidon.rs @@ -5,7 +5,7 @@ use crate::{ permutation::{full_round, poseidon_block_cipher}, }; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use mina_curves::pasta::wasm_friendly::minimal_field::MinimalField; +use mina_curves::pasta::wasm_friendly::MinimalField; use serde::{Deserialize, Serialize}; use serde_with::serde_as; From 9b380b0fa8cfd9b3b4e5088ad7036fd20236437f Mon Sep 17 00:00:00 2001 From: Gregor Date: Mon, 30 Sep 2024 17:27:46 +0200 Subject: [PATCH 04/15] rename, minor changes --- curves/src/pasta/wasm_friendly/wasm_fp.rs | 41 +++++++++-------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/curves/src/pasta/wasm_friendly/wasm_fp.rs b/curves/src/pasta/wasm_friendly/wasm_fp.rs index d913a1c3e4..a12ed2d691 100644 --- a/curves/src/pasta/wasm_friendly/wasm_fp.rs +++ b/curves/src/pasta/wasm_friendly/wasm_fp.rs @@ -1,5 +1,5 @@ /** - * MinimalField trait implementation `Fp` which only depends on an `FpConfig` trait + * MinimalField trait implementation `Fp` which only depends on an `FpBackend` trait * * Most of this code was copied over from ark_ff::Fp */ @@ -17,22 +17,13 @@ use std::{ use super::minimal_field::MinimalField; -pub trait FpConfig: Send + Sync + 'static + Sized { - /// The modulus of the field. +pub trait FpBackend: Send + Sync + 'static + Sized { const MODULUS: BigInt; - /// Additive identity of the field, i.e. the element `e` - /// such that, for all elements `f` of the field, `e + f = f`. const ZERO: Fp; - - /// Multiplicative identity of the field, i.e. the element `e` - /// such that, for all elements `f` of the field, `e * f = f`. const ONE: Fp; - /// Set a += b. fn add_assign(a: &mut Fp, b: &Fp); - - /// Set a *= b. fn mul_assign(a: &mut Fp, b: &Fp); /// Construct a field element from an integer in the range @@ -52,14 +43,14 @@ pub trait FpConfig: Send + Sync + 'static + Sized { PartialEq(bound = ""), Eq(bound = "") )] -pub struct Fp, const N: usize>( +pub struct Fp, const N: usize>( pub BigInt, #[derivative(Debug = "ignore")] #[doc(hidden)] pub PhantomData

, ); -impl, const N: usize> Fp { +impl, const N: usize> Fp { #[inline] fn from_bigint(r: BigInt) -> Option { P::from_bigint(r) @@ -68,7 +59,7 @@ impl, const N: usize> Fp { // field -impl, const N: usize> MinimalField for Fp { +impl, const N: usize> MinimalField for Fp { fn square_in_place(&mut self) -> &mut Self { // implemented with mul_assign for now let self_copy = *self; @@ -79,7 +70,7 @@ impl, const N: usize> MinimalField for Fp { // add, zero -impl, const N: usize> Zero for Fp { +impl, const N: usize> Zero for Fp { #[inline] fn zero() -> Self { P::ZERO @@ -91,13 +82,13 @@ impl, const N: usize> Zero for Fp { } } -impl<'a, P: FpConfig, const N: usize> AddAssign<&'a Self> for Fp { +impl<'a, P: FpBackend, const N: usize> AddAssign<&'a Self> for Fp { #[inline] fn add_assign(&mut self, other: &Self) { P::add_assign(self, other) } } -impl, const N: usize> Add for Fp { +impl, const N: usize> Add for Fp { type Output = Self; #[inline] @@ -106,7 +97,7 @@ impl, const N: usize> Add for Fp { self } } -impl<'a, P: FpConfig, const N: usize> Add<&'a Fp> for Fp { +impl<'a, P: FpBackend, const N: usize> Add<&'a Fp> for Fp { type Output = Self; #[inline] @@ -118,7 +109,7 @@ impl<'a, P: FpConfig, const N: usize> Add<&'a Fp> for Fp { // mul, one -impl, const N: usize> One for Fp { +impl, const N: usize> One for Fp { #[inline] fn one() -> Self { P::ONE @@ -129,13 +120,13 @@ impl, const N: usize> One for Fp { *self == P::ONE } } -impl<'a, P: FpConfig, const N: usize> MulAssign<&'a Self> for Fp { +impl<'a, P: FpBackend, const N: usize> MulAssign<&'a Self> for Fp { #[inline] fn mul_assign(&mut self, other: &Self) { P::mul_assign(self, other) } } -impl, const N: usize> Mul for Fp { +impl, const N: usize> Mul for Fp { type Output = Self; #[inline] @@ -144,7 +135,7 @@ impl, const N: usize> Mul for Fp { self } } -impl<'a, P: FpConfig, const N: usize> Mul<&'a Fp> for Fp { +impl<'a, P: FpBackend, const N: usize> Mul<&'a Fp> for Fp { type Output = Self; #[inline] @@ -156,7 +147,7 @@ impl<'a, P: FpConfig, const N: usize> Mul<&'a Fp> for Fp { // (de)serialization -impl, const N: usize> CanonicalSerialize for Fp { +impl, const N: usize> CanonicalSerialize for Fp { #[inline] fn serialize_with_mode( &self, @@ -172,13 +163,13 @@ impl, const N: usize> CanonicalSerialize for Fp { } } -impl, const N: usize> Valid for Fp { +impl, const N: usize> Valid for Fp { fn check(&self) -> Result<(), SerializationError> { Ok(()) } } -impl, const N: usize> CanonicalDeserialize for Fp { +impl, const N: usize> CanonicalDeserialize for Fp { fn deserialize_with_mode( reader: R, compress: Compress, From c508b59fb475f3903a752b4710fc1cb262538ae5 Mon Sep 17 00:00:00 2001 From: Gregor Date: Tue, 1 Oct 2024 10:16:41 +0200 Subject: [PATCH 05/15] start implementing one backend --- Cargo.lock | 1 + curves/Cargo.toml | 1 + curves/src/pasta/wasm_friendly/backend9.rs | 55 ++++++++++++++++++++++ curves/src/pasta/wasm_friendly/mod.rs | 2 + 4 files changed, 59 insertions(+) create mode 100644 curves/src/pasta/wasm_friendly/backend9.rs diff --git a/Cargo.lock b/Cargo.lock index 45d69b89f4..fd2fc76edc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1315,6 +1315,7 @@ dependencies = [ "ark-test-curves", "derivative", "rand", + "wasm-bindgen", ] [[package]] diff --git a/curves/Cargo.toml b/curves/Cargo.toml index 027108f4ba..da39e8d8da 100644 --- a/curves/Cargo.toml +++ b/curves/Cargo.toml @@ -13,6 +13,7 @@ license = "Apache-2.0" ark-ec.workspace = true ark-ff.workspace = true ark-serialize.workspace = true +wasm-bindgen.workspace = true derivative = { version = "2.0", features = ["use_core"] } [dev-dependencies] diff --git a/curves/src/pasta/wasm_friendly/backend9.rs b/curves/src/pasta/wasm_friendly/backend9.rs new file mode 100644 index 0000000000..68d8db9b48 --- /dev/null +++ b/curves/src/pasta/wasm_friendly/backend9.rs @@ -0,0 +1,55 @@ +/** + * Implementation of `FpBackend` for N=9, using 29-bit limbs represented by `u32`s. + */ +use super::wasm_fp::{Fp, FpBackend}; + +type B = [u32; 9]; + +const SHIFT: u64 = 29; +const MASK: u32 = (1 << SHIFT) - 1; +const MASK64: u64 = MASK as u64; +const TOTAL_BITS: u64 = 9 * SHIFT; // 261 + +pub trait FpConstants { + const MODULUS: B; + const R: B; // R = 2^261 % modulus + const R2: B; // R^2 % modulus + const MINV: u64; // -modulus^{-1} mod 2^32, as a u64 +} + +#[inline] +fn gt_modulus(a: B) -> bool { + for i in (0..9).rev() { + if a[i] > Fp::MODULUS[i] { + return true; + } else if a[i] < Fp::MODULUS[i] { + return false; + } + } + false +} + +/// TODO performance ideas to test: +/// - unroll loops +/// - introduce locals for a[i] instead of accessing memory multiple times +/// - only do 1 carry pass at the end, by proving properties of greater-than on uncarried result +/// - use cheaper, approximate greater-than check a[8] > Fp::MODULUS[8] +pub fn add_assign(mut a: B, b: B) { + let mut tmp: u32; + let mut carry: u32 = 0; + + for i in 0..9 { + tmp = a[i] + b[i] + carry; + carry = tmp >> SHIFT; + a[i] = tmp & MASK; + } + + if gt_modulus::(a) { + carry = 0; + for i in 0..9 { + tmp = a[i] - Fp::MODULUS[i] + carry; + carry = tmp >> SHIFT; + a[i] = tmp & MASK; + } + } +} diff --git a/curves/src/pasta/wasm_friendly/mod.rs b/curves/src/pasta/wasm_friendly/mod.rs index 3609e31e96..964de096dd 100644 --- a/curves/src/pasta/wasm_friendly/mod.rs +++ b/curves/src/pasta/wasm_friendly/mod.rs @@ -6,3 +6,5 @@ pub use minimal_field::MinimalField; pub mod wasm_fp; pub use wasm_fp::Fp; + +pub mod backend9; From 1a732d849fe5826fe957addd96b474762434dd47 Mon Sep 17 00:00:00 2001 From: Gregor Date: Tue, 1 Oct 2024 12:09:38 +0200 Subject: [PATCH 06/15] multiplication, start trying to satisfy backend trait --- curves/src/pasta/wasm_friendly/backend9.rs | 142 ++++++++++++++++++--- 1 file changed, 121 insertions(+), 21 deletions(-) diff --git a/curves/src/pasta/wasm_friendly/backend9.rs b/curves/src/pasta/wasm_friendly/backend9.rs index 68d8db9b48..f4842a18be 100644 --- a/curves/src/pasta/wasm_friendly/backend9.rs +++ b/curves/src/pasta/wasm_friendly/backend9.rs @@ -1,32 +1,48 @@ /** * Implementation of `FpBackend` for N=9, using 29-bit limbs represented by `u32`s. */ +use super::bigint32::BigInt; use super::wasm_fp::{Fp, FpBackend}; -type B = [u32; 9]; +type B = [i32; 9]; +type B64 = [i64; 9]; -const SHIFT: u64 = 29; -const MASK: u32 = (1 << SHIFT) - 1; -const MASK64: u64 = MASK as u64; -const TOTAL_BITS: u64 = 9 * SHIFT; // 261 +const SHIFT: i64 = 29; +const MASK: i32 = (1 << SHIFT) - 1; -pub trait FpConstants { +const SHIFT64: i64 = SHIFT as i64; +const MASK64: i64 = MASK as i64; + +const TOTAL_BITS: i64 = 9 * SHIFT; // 261 + +pub trait FpConstants: Send + Sync + 'static + Sized { const MODULUS: B; - const R: B; // R = 2^261 % modulus - const R2: B; // R^2 % modulus - const MINV: u64; // -modulus^{-1} mod 2^32, as a u64 + const MODULUS64: B64 = { + let mut modulus64 = [0i64; 9]; + let modulus = Self::MODULUS; + let mut i = 0; + while i < 9 { + modulus64[i] = modulus[i] as i64; + i += 1; + } + modulus64 + }; + + const MINV: i64; // -modulus^(-1) mod 2^29, as a u64 + + const R: B = [1, 0, 0, 0, 0, 0, 0, 0, 0]; } #[inline] -fn gt_modulus(a: B) -> bool { +fn gte_modulus(x: &B) -> bool { for i in (0..9).rev() { - if a[i] > Fp::MODULUS[i] { + if x[i] > FpC::MODULUS[i] { return true; - } else if a[i] < Fp::MODULUS[i] { + } else if x[i] < FpC::MODULUS[i] { return false; } } - false + true } /// TODO performance ideas to test: @@ -34,22 +50,106 @@ fn gt_modulus(a: B) -> bool { /// - introduce locals for a[i] instead of accessing memory multiple times /// - only do 1 carry pass at the end, by proving properties of greater-than on uncarried result /// - use cheaper, approximate greater-than check a[8] > Fp::MODULUS[8] -pub fn add_assign(mut a: B, b: B) { - let mut tmp: u32; - let mut carry: u32 = 0; +pub fn add_assign(x: &mut B, y: &B) { + let mut tmp: i32; + let mut carry: i32 = 0; for i in 0..9 { - tmp = a[i] + b[i] + carry; + tmp = x[i] + y[i] + carry; carry = tmp >> SHIFT; - a[i] = tmp & MASK; + x[i] = tmp & MASK; } - if gt_modulus::(a) { + if gte_modulus::(x) { carry = 0; for i in 0..9 { - tmp = a[i] - Fp::MODULUS[i] + carry; + tmp = x[i] - FpC::MODULUS[i] + carry; carry = tmp >> SHIFT; - a[i] = tmp & MASK; + x[i] = tmp & MASK; + } + } +} + +#[inline] +fn conditional_reduce(x: &mut B) { + if gte_modulus::(x) { + for i in 0..9 { + x[i] -= FpC::MODULUS[i]; + } + for i in 1..9 { + x[i] += x[i - 1] >> SHIFT; + } + for i in 0..8 { + x[i] &= MASK; + } + } +} + +/// Montgomery multiplication +pub fn mul_assign(x: &mut B, y: &B) { + // load y[i] into local i64s + // TODO make sure these are locals + let mut y_local = [0i64; 9]; + for i in 0..9 { + y_local[i] = y[i] as i64; + } + + // locals for result + let mut z = [0i64; 8]; + let mut tmp: i64; + + // main loop, without intermediate carries except for z0 + for i in 0..9 { + let xi = x[i] as i64; + + // compute qi and carry z0 result to z1 before discarding z0 + tmp = xi * y_local[0]; + let qi = ((tmp & MASK64) * FpC::MINV) & MASK64; + z[1] += (tmp + qi * FpC::MODULUS64[0]) >> SHIFT64; + + // compute zi and shift in one step + for j in 1..8 { + z[j - 1] = z[j] + xi * y_local[j] + qi * FpC::MODULUS64[j]; + } + // for j=8 we save an addition since z[8] is never needed + z[7] = xi * y_local[8] + qi * FpC::MODULUS64[8]; + } + + // final carry pass, store result back into x + x[0] = (z[0] & MASK64) as i32; + for i in 1..8 { + x[i] = (((z[i - 1] >> SHIFT64) + z[i]) & MASK64) as i32; + } + x[8] = (z[7] >> SHIFT64) as i32; + + // at this point, x is guaranteed to be less than 2*MODULUS + // conditionally subtract the modulus to bring it back into the canonical range + conditional_reduce::(x); +} + +// implement FpBackend given an FpConstants + +impl FpBackend<9> for FpC { + const MODULUS: BigInt<9> = BigInt(FpC::MODULUS); + const ZERO: [i32; 9] = BigInt([0; 9]); + const ONE: [i32; 9] = BigInt(Self::R); + + fn add_assign(x: &mut [i32; 9], y: &[i32; 9]) { + add_assign::(x, y); + } + + fn mul_assign(x: &mut [i32; 9], y: &[i32; 9]) { + mul_assign::(x, y); + } + + fn from_bigint(other: BigInt<9>) -> Option> { + let mut r = [0; 9]; + for i in 0..9 { + r[i] = other.0[i] as i32; + } + if gte_modulus::(&r) { + return None; } + panic!("todo") } } From 891b6146d3b1c88f5490021215f6cb817e4d6fa8 Mon Sep 17 00:00:00 2001 From: Gregor Date: Tue, 1 Oct 2024 12:45:08 +0200 Subject: [PATCH 07/15] make it compile --- curves/src/pasta/wasm_friendly/backend9.rs | 77 +++++++++++----------- curves/src/pasta/wasm_friendly/wasm_fp.rs | 12 ++-- 2 files changed, 44 insertions(+), 45 deletions(-) diff --git a/curves/src/pasta/wasm_friendly/backend9.rs b/curves/src/pasta/wasm_friendly/backend9.rs index f4842a18be..665f5faebf 100644 --- a/curves/src/pasta/wasm_friendly/backend9.rs +++ b/curves/src/pasta/wasm_friendly/backend9.rs @@ -4,33 +4,32 @@ use super::bigint32::BigInt; use super::wasm_fp::{Fp, FpBackend}; -type B = [i32; 9]; -type B64 = [i64; 9]; +type B = [u32; 9]; +type B64 = [u64; 9]; -const SHIFT: i64 = 29; -const MASK: i32 = (1 << SHIFT) - 1; +const SHIFT: u64 = 29; +const MASK: u32 = (1 << SHIFT) - 1; -const SHIFT64: i64 = SHIFT as i64; -const MASK64: i64 = MASK as i64; - -const TOTAL_BITS: i64 = 9 * SHIFT; // 261 +const SHIFT64: u64 = SHIFT as u64; +const MASK64: u64 = MASK as u64; pub trait FpConstants: Send + Sync + 'static + Sized { const MODULUS: B; const MODULUS64: B64 = { - let mut modulus64 = [0i64; 9]; + let mut modulus64 = [0u64; 9]; let modulus = Self::MODULUS; let mut i = 0; while i < 9 { - modulus64[i] = modulus[i] as i64; + modulus64[i] = modulus[i] as u64; i += 1; } modulus64 }; - const MINV: i64; // -modulus^(-1) mod 2^29, as a u64 - - const R: B = [1, 0, 0, 0, 0, 0, 0, 0, 0]; + /// montgomery params + /// TODO: compute these + const R: B; // R = 2^261 mod modulus + const MINV: u64; // -modulus^(-1) mod 2^29, as a u64 } #[inline] @@ -51,20 +50,20 @@ fn gte_modulus(x: &B) -> bool { /// - only do 1 carry pass at the end, by proving properties of greater-than on uncarried result /// - use cheaper, approximate greater-than check a[8] > Fp::MODULUS[8] pub fn add_assign(x: &mut B, y: &B) { - let mut tmp: i32; + let mut tmp: u32; let mut carry: i32 = 0; for i in 0..9 { - tmp = x[i] + y[i] + carry; - carry = tmp >> SHIFT; + tmp = x[i] + y[i] + (carry as u32); + carry = (tmp as i32) >> SHIFT; x[i] = tmp & MASK; } if gte_modulus::(x) { carry = 0; for i in 0..9 { - tmp = x[i] - FpC::MODULUS[i] + carry; - carry = tmp >> SHIFT; + tmp = x[i].wrapping_sub(FpC::MODULUS[i]) + (carry as u32); + carry = (tmp as i32) >> SHIFT; x[i] = tmp & MASK; } } @@ -74,10 +73,10 @@ pub fn add_assign(x: &mut B, y: &B) { fn conditional_reduce(x: &mut B) { if gte_modulus::(x) { for i in 0..9 { - x[i] -= FpC::MODULUS[i]; + x[i] = x[i].wrapping_sub(FpC::MODULUS[i]); } for i in 1..9 { - x[i] += x[i - 1] >> SHIFT; + x[i] = x[i] + (((x[i - 1] as i32) >> SHIFT) as u32); } for i in 0..8 { x[i] &= MASK; @@ -87,20 +86,20 @@ fn conditional_reduce(x: &mut B) { /// Montgomery multiplication pub fn mul_assign(x: &mut B, y: &B) { - // load y[i] into local i64s + // load y[i] into local u64s // TODO make sure these are locals - let mut y_local = [0i64; 9]; + let mut y_local = [0u64; 9]; for i in 0..9 { - y_local[i] = y[i] as i64; + y_local[i] = y[i] as u64; } // locals for result - let mut z = [0i64; 8]; - let mut tmp: i64; + let mut z = [0u64; 8]; + let mut tmp: u64; // main loop, without intermediate carries except for z0 for i in 0..9 { - let xi = x[i] as i64; + let xi = x[i] as u64; // compute qi and carry z0 result to z1 before discarding z0 tmp = xi * y_local[0]; @@ -109,47 +108,47 @@ pub fn mul_assign(x: &mut B, y: &B) { // compute zi and shift in one step for j in 1..8 { - z[j - 1] = z[j] + xi * y_local[j] + qi * FpC::MODULUS64[j]; + z[j - 1] = z[j] + (xi * y_local[j]) + (qi * FpC::MODULUS64[j]); } // for j=8 we save an addition since z[8] is never needed z[7] = xi * y_local[8] + qi * FpC::MODULUS64[8]; } // final carry pass, store result back into x - x[0] = (z[0] & MASK64) as i32; + x[0] = (z[0] & MASK64) as u32; for i in 1..8 { - x[i] = (((z[i - 1] >> SHIFT64) + z[i]) & MASK64) as i32; + x[i] = (((z[i - 1] >> SHIFT64) + z[i]) & MASK64) as u32; } - x[8] = (z[7] >> SHIFT64) as i32; + x[8] = (z[7] >> SHIFT64) as u32; // at this point, x is guaranteed to be less than 2*MODULUS // conditionally subtract the modulus to bring it back into the canonical range conditional_reduce::(x); } -// implement FpBackend given an FpConstants +// implement FpBackend given FpConstants impl FpBackend<9> for FpC { const MODULUS: BigInt<9> = BigInt(FpC::MODULUS); - const ZERO: [i32; 9] = BigInt([0; 9]); - const ONE: [i32; 9] = BigInt(Self::R); + const ZERO: BigInt<9> = BigInt([0; 9]); + const ONE: BigInt<9> = BigInt(FpC::R); - fn add_assign(x: &mut [i32; 9], y: &[i32; 9]) { - add_assign::(x, y); + fn add_assign(x: &mut Fp, y: &Fp) { + add_assign::(&mut x.0 .0, &y.0 .0); } - fn mul_assign(x: &mut [i32; 9], y: &[i32; 9]) { - mul_assign::(x, y); + fn mul_assign(x: &mut Fp, y: &Fp) { + mul_assign::(&mut x.0 .0, &y.0 .0); } fn from_bigint(other: BigInt<9>) -> Option> { let mut r = [0; 9]; for i in 0..9 { - r[i] = other.0[i] as i32; + r[i] = other.0[i] as u32; } if gte_modulus::(&r) { return None; } - panic!("todo") + Some(Fp(BigInt(r), Default::default())) } } diff --git a/curves/src/pasta/wasm_friendly/wasm_fp.rs b/curves/src/pasta/wasm_friendly/wasm_fp.rs index a12ed2d691..7d490836f5 100644 --- a/curves/src/pasta/wasm_friendly/wasm_fp.rs +++ b/curves/src/pasta/wasm_friendly/wasm_fp.rs @@ -20,8 +20,8 @@ use super::minimal_field::MinimalField; pub trait FpBackend: Send + Sync + 'static + Sized { const MODULUS: BigInt; - const ZERO: Fp; - const ONE: Fp; + const ZERO: BigInt; + const ONE: BigInt; fn add_assign(a: &mut Fp, b: &Fp); fn mul_assign(a: &mut Fp, b: &Fp); @@ -73,12 +73,12 @@ impl, const N: usize> MinimalField for Fp { impl, const N: usize> Zero for Fp { #[inline] fn zero() -> Self { - P::ZERO + Fp(P::ZERO, Default::default()) } #[inline] fn is_zero(&self) -> bool { - *self == P::ZERO + *self == Self::zero() } } @@ -112,12 +112,12 @@ impl<'a, P: FpBackend, const N: usize> Add<&'a Fp> for Fp { impl, const N: usize> One for Fp { #[inline] fn one() -> Self { - P::ONE + Fp(P::ONE, Default::default()) } #[inline] fn is_one(&self) -> bool { - *self == P::ONE + *self == Self::one() } } impl<'a, P: FpBackend, const N: usize> MulAssign<&'a Self> for Fp { From e8fcb6241f60483ccc3104245807184c6b938588 Mon Sep 17 00:00:00 2001 From: Gregor Date: Tue, 1 Oct 2024 13:18:58 +0200 Subject: [PATCH 08/15] export fp9 type --- curves/src/pasta/wasm_friendly/backend9.rs | 20 +++++++++++++++++--- curves/src/pasta/wasm_friendly/mod.rs | 2 ++ curves/src/pasta/wasm_friendly/pasta.rs | 13 +++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) create mode 100644 curves/src/pasta/wasm_friendly/pasta.rs diff --git a/curves/src/pasta/wasm_friendly/backend9.rs b/curves/src/pasta/wasm_friendly/backend9.rs index 665f5faebf..c31b5935cb 100644 --- a/curves/src/pasta/wasm_friendly/backend9.rs +++ b/curves/src/pasta/wasm_friendly/backend9.rs @@ -13,6 +13,20 @@ const MASK: u32 = (1 << SHIFT) - 1; const SHIFT64: u64 = SHIFT as u64; const MASK64: u64 = MASK as u64; +pub const fn from_64x4(pa: [u64; 4]) -> [u32; 9] { + let mut p = [0u32; 9]; + p[0] = (pa[0] & MASK64) as u32; + p[1] = ((pa[0] >> 29) & MASK64) as u32; + p[2] = (((pa[0] >> 58) | (pa[1] << 6)) & MASK64) as u32; + p[3] = ((pa[1] >> 23) & MASK64) as u32; + p[4] = (((pa[1] >> 52) | (pa[2] << 12)) & MASK64) as u32; + p[5] = ((pa[2] >> 17) & MASK64) as u32; + p[6] = (((pa[2] >> 46) | (pa[3] << 18)) & MASK64) as u32; + p[7] = ((pa[3] >> 11) & MASK64) as u32; + p[8] = (pa[3] >> 40) as u32; + p +} + pub trait FpConstants: Send + Sync + 'static + Sized { const MODULUS: B; const MODULUS64: B64 = { @@ -129,12 +143,12 @@ pub fn mul_assign(x: &mut B, y: &B) { // implement FpBackend given FpConstants impl FpBackend<9> for FpC { - const MODULUS: BigInt<9> = BigInt(FpC::MODULUS); + const MODULUS: BigInt<9> = BigInt(Self::MODULUS); const ZERO: BigInt<9> = BigInt([0; 9]); - const ONE: BigInt<9> = BigInt(FpC::R); + const ONE: BigInt<9> = BigInt(Self::R); fn add_assign(x: &mut Fp, y: &Fp) { - add_assign::(&mut x.0 .0, &y.0 .0); + add_assign::(&mut x.0 .0, &y.0 .0); } fn mul_assign(x: &mut Fp, y: &Fp) { diff --git a/curves/src/pasta/wasm_friendly/mod.rs b/curves/src/pasta/wasm_friendly/mod.rs index 964de096dd..15800465ec 100644 --- a/curves/src/pasta/wasm_friendly/mod.rs +++ b/curves/src/pasta/wasm_friendly/mod.rs @@ -8,3 +8,5 @@ pub mod wasm_fp; pub use wasm_fp::Fp; pub mod backend9; +pub mod pasta; +pub use pasta::Fp9; diff --git a/curves/src/pasta/wasm_friendly/pasta.rs b/curves/src/pasta/wasm_friendly/pasta.rs new file mode 100644 index 0000000000..96abb67f35 --- /dev/null +++ b/curves/src/pasta/wasm_friendly/pasta.rs @@ -0,0 +1,13 @@ +use super::backend9; +use super::wasm_fp; +use crate::pasta::Fp; +use ark_ff::PrimeField; + +pub struct Fp9Parameters; + +impl backend9::FpConstants for Fp9Parameters { + const MODULUS: [u32; 9] = backend9::from_64x4(Fp::MODULUS.0); + const R: [u32; 9] = backend9::from_64x4(Fp::R.0); + const MINV: u64 = Fp::INV; +} +pub type Fp9 = wasm_fp::Fp; From 26fb3fe78ca2b74309da137a610e591aa3a2441a Mon Sep 17 00:00:00 2001 From: Gregor Date: Tue, 1 Oct 2024 13:43:50 +0200 Subject: [PATCH 09/15] add coercion from fp to fp9 --- curves/src/pasta/wasm_friendly/pasta.rs | 12 ++++++++++++ curves/src/pasta/wasm_friendly/wasm_fp.rs | 18 ++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/curves/src/pasta/wasm_friendly/pasta.rs b/curves/src/pasta/wasm_friendly/pasta.rs index 96abb67f35..d8a4417fe5 100644 --- a/curves/src/pasta/wasm_friendly/pasta.rs +++ b/curves/src/pasta/wasm_friendly/pasta.rs @@ -11,3 +11,15 @@ impl backend9::FpConstants for Fp9Parameters { const MINV: u64 = Fp::INV; } pub type Fp9 = wasm_fp::Fp; + +impl Fp9 { + pub fn from_fp(fp: Fp) -> Self { + backend9::from_64x4(fp.0 .0).into() + } +} + +impl From for Fp9 { + fn from(fp: Fp) -> Self { + Fp9::from_fp(fp) + } +} diff --git a/curves/src/pasta/wasm_friendly/wasm_fp.rs b/curves/src/pasta/wasm_friendly/wasm_fp.rs index 7d490836f5..c9e4d74731 100644 --- a/curves/src/pasta/wasm_friendly/wasm_fp.rs +++ b/curves/src/pasta/wasm_friendly/wasm_fp.rs @@ -51,12 +51,30 @@ pub struct Fp, const N: usize>( ); impl, const N: usize> Fp { + fn new(bigint: BigInt) -> Self { + Fp(bigint, Default::default()) + } + #[inline] fn from_bigint(r: BigInt) -> Option { P::from_bigint(r) } } +// coerce into Fp from either BigInt or [u32; N] + +impl, const N: usize> From> for Fp { + fn from(val: BigInt) -> Self { + Fp::new(val) + } +} + +impl, const N: usize> From<[u32; N]> for Fp { + fn from(val: [u32; N]) -> Self { + Fp::new(BigInt(val)) + } +} + // field impl, const N: usize> MinimalField for Fp { From 31fb3130a73238130d655feac786ee255a86edfb Mon Sep 17 00:00:00 2001 From: Gregor Date: Tue, 1 Oct 2024 13:44:00 +0200 Subject: [PATCH 10/15] add fp9 poseidon benchmark --- poseidon/benches/poseidon_bench.rs | 40 +++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/poseidon/benches/poseidon_bench.rs b/poseidon/benches/poseidon_bench.rs index 81a14d20fb..a504323ae9 100644 --- a/poseidon/benches/poseidon_bench.rs +++ b/poseidon/benches/poseidon_bench.rs @@ -1,10 +1,13 @@ +use ark_ff::Zero; use criterion::{criterion_group, criterion_main, Criterion}; +use mina_curves::pasta::wasm_friendly::Fp9; use mina_curves::pasta::Fp; use mina_poseidon::{ constants::PlonkSpongeConstantsKimchi, pasta::fp_kimchi as SpongeParametersKimchi, - poseidon::{ArithmeticSponge as Poseidon, Sponge}, + poseidon::{ArithmeticSponge as Poseidon, ArithmeticSpongeParams, Sponge}, }; +use once_cell::sync::Lazy; pub fn bench_poseidon_kimchi(c: &mut Criterion) { let mut group = c.benchmark_group("Poseidon"); @@ -23,8 +26,43 @@ pub fn bench_poseidon_kimchi(c: &mut Criterion) { }) }); + // same as above but with Fp9 + group.bench_function("poseidon_hash_kimchi_fp9", |b| { + let mut hash: Fp9 = Fp9::zero(); + let mut poseidon = Poseidon::::new(fp9_static_params()); + + b.iter(|| { + poseidon.absorb(&[hash]); + hash = poseidon.squeeze(); + }) + }); + group.finish(); } criterion_group!(benches, bench_poseidon_kimchi); criterion_main!(benches); + +// sponge params for Fp9 + +fn fp9_sponge_params() -> ArithmeticSpongeParams { + let params = SpongeParametersKimchi::params(); + + // leverage .into() to convert from Fp to Fp9 + ArithmeticSpongeParams:: { + round_constants: params + .round_constants + .into_iter() + .map(|x| x.into_iter().map(Fp9::from).collect()) + .collect(), + mds: params + .mds + .into_iter() + .map(|x| x.into_iter().map(Fp9::from).collect()) + .collect(), + } +} +fn fp9_static_params() -> &'static ArithmeticSpongeParams { + static PARAMS: Lazy> = Lazy::new(fp9_sponge_params); + &PARAMS +} From 953fad2b248529df494934af308d61e0b0797325 Mon Sep 17 00:00:00 2001 From: Gregor Date: Tue, 1 Oct 2024 15:47:43 +0200 Subject: [PATCH 11/15] fix conversion from arkworks fp, and hard-code field constants --- curves/src/pasta/wasm_friendly/backend9.rs | 21 +++++++++++++-------- curves/src/pasta/wasm_friendly/pasta.rs | 17 +++++++++++++---- curves/src/pasta/wasm_friendly/wasm_fp.rs | 16 ++++++++-------- 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/curves/src/pasta/wasm_friendly/backend9.rs b/curves/src/pasta/wasm_friendly/backend9.rs index c31b5935cb..ac6734ad14 100644 --- a/curves/src/pasta/wasm_friendly/backend9.rs +++ b/curves/src/pasta/wasm_friendly/backend9.rs @@ -43,6 +43,7 @@ pub trait FpConstants: Send + Sync + 'static + Sized { /// montgomery params /// TODO: compute these const R: B; // R = 2^261 mod modulus + const R2: B; // R^2 mod modulus const MINV: u64; // -modulus^(-1) mod 2^29, as a u64 } @@ -142,6 +143,13 @@ pub fn mul_assign(x: &mut B, y: &B) { // implement FpBackend given FpConstants +pub fn from_bigint_unsafe(x: BigInt<9>) -> Fp { + let mut r = x.0.clone(); + // convert to montgomery form + mul_assign::(&mut r, &FpC::R2); + Fp(BigInt(r), Default::default()) +} + impl FpBackend<9> for FpC { const MODULUS: BigInt<9> = BigInt(Self::MODULUS); const ZERO: BigInt<9> = BigInt([0; 9]); @@ -155,14 +163,11 @@ impl FpBackend<9> for FpC { mul_assign::(&mut x.0 .0, &y.0 .0); } - fn from_bigint(other: BigInt<9>) -> Option> { - let mut r = [0; 9]; - for i in 0..9 { - r[i] = other.0[i] as u32; - } - if gte_modulus::(&r) { - return None; + fn from_bigint(x: BigInt<9>) -> Option> { + if gte_modulus::(&x.0) { + None + } else { + Some(from_bigint_unsafe(x)) } - Some(Fp(BigInt(r), Default::default())) } } diff --git a/curves/src/pasta/wasm_friendly/pasta.rs b/curves/src/pasta/wasm_friendly/pasta.rs index d8a4417fe5..572be38440 100644 --- a/curves/src/pasta/wasm_friendly/pasta.rs +++ b/curves/src/pasta/wasm_friendly/pasta.rs @@ -6,15 +6,24 @@ use ark_ff::PrimeField; pub struct Fp9Parameters; impl backend9::FpConstants for Fp9Parameters { - const MODULUS: [u32; 9] = backend9::from_64x4(Fp::MODULUS.0); - const R: [u32; 9] = backend9::from_64x4(Fp::R.0); - const MINV: u64 = Fp::INV; + const MODULUS: [u32; 9] = [ + 0x1, 0x9698768, 0x133e46e6, 0xd31f812, 0x224, 0x0, 0x0, 0x0, 0x400000, + ]; + const R: [u32; 9] = [ + 0x1fffff81, 0x14a5d367, 0x141ad3c0, 0x1435eec5, 0x1ffeefef, 0x1fffffff, 0x1fffffff, + 0x1fffffff, 0x3fffff, + ]; + const R2: [u32; 9] = [ + 0x3b6a, 0x19c10910, 0x1a6a0188, 0x12a4fd88, 0x634b36d, 0x178792ba, 0x7797a99, 0x1dce5b8a, + 0x3506bd, + ]; + const MINV: u64 = 0x1fffffff; } pub type Fp9 = wasm_fp::Fp; impl Fp9 { pub fn from_fp(fp: Fp) -> Self { - backend9::from_64x4(fp.0 .0).into() + backend9::from_bigint_unsafe(super::BigInt(backend9::from_64x4(fp.into_bigint().0))) } } diff --git a/curves/src/pasta/wasm_friendly/wasm_fp.rs b/curves/src/pasta/wasm_friendly/wasm_fp.rs index c9e4d74731..981a933575 100644 --- a/curves/src/pasta/wasm_friendly/wasm_fp.rs +++ b/curves/src/pasta/wasm_friendly/wasm_fp.rs @@ -19,7 +19,6 @@ use super::minimal_field::MinimalField; pub trait FpBackend: Send + Sync + 'static + Sized { const MODULUS: BigInt; - const ZERO: BigInt; const ONE: BigInt; @@ -41,7 +40,8 @@ pub trait FpBackend: Send + Sync + 'static + Sized { Clone(bound = ""), Copy(bound = ""), PartialEq(bound = ""), - Eq(bound = "") + Eq(bound = ""), + Debug(bound = "") )] pub struct Fp, const N: usize>( pub BigInt, @@ -51,12 +51,12 @@ pub struct Fp, const N: usize>( ); impl, const N: usize> Fp { - fn new(bigint: BigInt) -> Self { + pub fn new(bigint: BigInt) -> Self { Fp(bigint, Default::default()) } #[inline] - fn from_bigint(r: BigInt) -> Option { + pub fn from_bigint(r: BigInt) -> Option { P::from_bigint(r) } } @@ -65,13 +65,13 @@ impl, const N: usize> Fp { impl, const N: usize> From> for Fp { fn from(val: BigInt) -> Self { - Fp::new(val) + Fp::from_bigint(val).unwrap() } } impl, const N: usize> From<[u32; N]> for Fp { fn from(val: [u32; N]) -> Self { - Fp::new(BigInt(val)) + Fp::from_bigint(BigInt(val)).unwrap() } } @@ -91,7 +91,7 @@ impl, const N: usize> MinimalField for Fp { impl, const N: usize> Zero for Fp { #[inline] fn zero() -> Self { - Fp(P::ZERO, Default::default()) + Fp::new(P::ZERO) } #[inline] @@ -130,7 +130,7 @@ impl<'a, P: FpBackend, const N: usize> Add<&'a Fp> for Fp { impl, const N: usize> One for Fp { #[inline] fn one() -> Self { - Fp(P::ONE, Default::default()) + Fp::new(P::ONE) } #[inline] From cf4f1995ff3a2fff325148dbd5da6f57a4218cba Mon Sep 17 00:00:00 2001 From: Gregor Date: Tue, 1 Oct 2024 16:32:30 +0200 Subject: [PATCH 12/15] print field elements, poseidon result doesn't match yet --- Cargo.lock | 1 + curves/Cargo.toml | 1 + curves/src/pasta/wasm_friendly/backend9.rs | 33 +++++++++++++++++++++ curves/src/pasta/wasm_friendly/wasm_fp.rs | 34 +++++++++++++++++++++- poseidon/benches/poseidon_bench.rs | 6 ++++ 5 files changed, 74 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index fd2fc76edc..4cc68dba46 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1314,6 +1314,7 @@ dependencies = [ "ark-std", "ark-test-curves", "derivative", + "num-bigint", "rand", "wasm-bindgen", ] diff --git a/curves/Cargo.toml b/curves/Cargo.toml index da39e8d8da..f97a60cbae 100644 --- a/curves/Cargo.toml +++ b/curves/Cargo.toml @@ -14,6 +14,7 @@ ark-ec.workspace = true ark-ff.workspace = true ark-serialize.workspace = true wasm-bindgen.workspace = true +num-bigint.workspace = true derivative = { version = "2.0", features = ["use_core"] } [dev-dependencies] diff --git a/curves/src/pasta/wasm_friendly/backend9.rs b/curves/src/pasta/wasm_friendly/backend9.rs index ac6734ad14..f01c34b996 100644 --- a/curves/src/pasta/wasm_friendly/backend9.rs +++ b/curves/src/pasta/wasm_friendly/backend9.rs @@ -26,6 +26,22 @@ pub const fn from_64x4(pa: [u64; 4]) -> [u32; 9] { p[8] = (pa[3] >> 40) as u32; p } +pub const fn to_64x4(pa: [u32; 9]) -> [u64; 4] { + let mut p = [0u64; 4]; + p[0] = pa[0] as u64; + p[0] |= (pa[1] as u64) << 29; + p[0] |= (pa[2] as u64) << 58; + p[1] = (pa[2] as u64) >> 6; + p[1] |= (pa[3] as u64) << 23; + p[1] |= (pa[4] as u64) << 52; + p[2] = (pa[4] as u64) >> 12; + p[2] |= (pa[5] as u64) << 17; + p[2] |= (pa[6] as u64) << 46; + p[3] = (pa[6] as u64) >> 18; + p[3] |= (pa[7] as u64) << 11; + p[3] |= (pa[8] as u64) << 40; + p +} pub trait FpConstants: Send + Sync + 'static + Sized { const MODULUS: B; @@ -170,4 +186,21 @@ impl FpBackend<9> for FpC { Some(from_bigint_unsafe(x)) } } + fn to_bigint(x: Fp) -> BigInt<9> { + let one = [1, 0, 0, 0, 0, 0, 0, 0, 0]; + let mut r = x.0 .0.clone(); + // convert back from montgomery form + mul_assign::(&mut r, &one); + BigInt(r) + } + + fn pack(x: Fp) -> Vec { + let x = Self::to_bigint(x).0; + let x64 = to_64x4(x); + let mut res = Vec::with_capacity(4); + for limb in x64.iter() { + res.push(*limb); + } + res + } } diff --git a/curves/src/pasta/wasm_friendly/wasm_fp.rs b/curves/src/pasta/wasm_friendly/wasm_fp.rs index 981a933575..5970015057 100644 --- a/curves/src/pasta/wasm_friendly/wasm_fp.rs +++ b/curves/src/pasta/wasm_friendly/wasm_fp.rs @@ -10,6 +10,7 @@ use ark_serialize::{ Write, }; use derivative::Derivative; +use num_bigint::BigUint; use std::{ marker::PhantomData, ops::{Add, AddAssign, Mul, MulAssign}, @@ -28,7 +29,10 @@ pub trait FpBackend: Send + Sync + 'static + Sized { /// Construct a field element from an integer in the range /// `0..(Self::MODULUS - 1)`. Returns `None` if the integer is outside /// this range. - fn from_bigint(other: BigInt) -> Option>; + fn from_bigint(x: BigInt) -> Option>; + fn to_bigint(x: Fp) -> BigInt; + + fn pack(x: Fp) -> Vec; } /// Represents an element of the prime field F_p, where `p == P::MODULUS`. @@ -59,6 +63,19 @@ impl, const N: usize> Fp { pub fn from_bigint(r: BigInt) -> Option { P::from_bigint(r) } + #[inline] + pub fn into_bigint(self) -> BigInt { + P::to_bigint(self) + } + + pub fn to_bytes_le(self) -> Vec { + let chunks = P::pack(self).into_iter().map(|x| x.to_le_bytes()); + let mut bytes = Vec::with_capacity(chunks.len() * 8); + for chunk in chunks { + bytes.extend_from_slice(&chunk); + } + bytes + } } // coerce into Fp from either BigInt or [u32; N] @@ -197,3 +214,18 @@ impl, const N: usize> CanonicalDeserialize for Fp { .ok_or(SerializationError::InvalidData) } } + +// display + +impl, const N: usize> From> for BigUint { + #[inline] + fn from(val: Fp) -> BigUint { + BigUint::from_bytes_le(&val.to_bytes_le()) + } +} + +impl, const N: usize> std::fmt::Display for Fp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + BigUint::from(*self).fmt(f) + } +} diff --git a/poseidon/benches/poseidon_bench.rs b/poseidon/benches/poseidon_bench.rs index a504323ae9..24d8514d84 100644 --- a/poseidon/benches/poseidon_bench.rs +++ b/poseidon/benches/poseidon_bench.rs @@ -20,6 +20,9 @@ pub fn bench_poseidon_kimchi(c: &mut Criterion) { SpongeParametersKimchi::static_params(), ); + poseidon.absorb(&[Fp::zero()]); + println!("{}", poseidon.squeeze().to_string()); + b.iter(|| { poseidon.absorb(&[hash]); hash = poseidon.squeeze(); @@ -31,6 +34,9 @@ pub fn bench_poseidon_kimchi(c: &mut Criterion) { let mut hash: Fp9 = Fp9::zero(); let mut poseidon = Poseidon::::new(fp9_static_params()); + poseidon.absorb(&[Fp9::zero()]); + println!("{}", poseidon.squeeze().to_string()); + b.iter(|| { poseidon.absorb(&[hash]); hash = poseidon.squeeze(); From b0f8f5ee9cd505dfbc53e77fb416cd0140d5fed9 Mon Sep 17 00:00:00 2001 From: Gregor Date: Tue, 1 Oct 2024 16:36:27 +0200 Subject: [PATCH 13/15] cargo fmt --- curves/src/pasta/wasm_friendly/pasta.rs | 3 +-- poseidon/benches/poseidon_bench.rs | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/curves/src/pasta/wasm_friendly/pasta.rs b/curves/src/pasta/wasm_friendly/pasta.rs index 572be38440..f6cdb802b8 100644 --- a/curves/src/pasta/wasm_friendly/pasta.rs +++ b/curves/src/pasta/wasm_friendly/pasta.rs @@ -1,5 +1,4 @@ -use super::backend9; -use super::wasm_fp; +use super::{backend9, wasm_fp}; use crate::pasta::Fp; use ark_ff::PrimeField; diff --git a/poseidon/benches/poseidon_bench.rs b/poseidon/benches/poseidon_bench.rs index 24d8514d84..fb240cde88 100644 --- a/poseidon/benches/poseidon_bench.rs +++ b/poseidon/benches/poseidon_bench.rs @@ -1,7 +1,6 @@ use ark_ff::Zero; use criterion::{criterion_group, criterion_main, Criterion}; -use mina_curves::pasta::wasm_friendly::Fp9; -use mina_curves::pasta::Fp; +use mina_curves::pasta::{wasm_friendly::Fp9, Fp}; use mina_poseidon::{ constants::PlonkSpongeConstantsKimchi, pasta::fp_kimchi as SpongeParametersKimchi, From 8b972dc17d3a323f767ed480169fd2909d247734 Mon Sep 17 00:00:00 2001 From: Gregor Date: Tue, 1 Oct 2024 17:02:08 +0200 Subject: [PATCH 14/15] fix some clippy warnings, ignore others --- curves/src/pasta/wasm_friendly/backend9.rs | 15 +++++++++++---- curves/src/pasta/wasm_friendly/wasm_fp.rs | 7 ++++++- poseidon/benches/poseidon_bench.rs | 8 ++++---- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/curves/src/pasta/wasm_friendly/backend9.rs b/curves/src/pasta/wasm_friendly/backend9.rs index f01c34b996..73d83556a9 100644 --- a/curves/src/pasta/wasm_friendly/backend9.rs +++ b/curves/src/pasta/wasm_friendly/backend9.rs @@ -7,7 +7,7 @@ use super::wasm_fp::{Fp, FpBackend}; type B = [u32; 9]; type B64 = [u64; 9]; -const SHIFT: u64 = 29; +const SHIFT: u32 = 29; const MASK: u32 = (1 << SHIFT) - 1; const SHIFT64: u64 = SHIFT as u64; @@ -66,6 +66,8 @@ pub trait FpConstants: Send + Sync + 'static + Sized { #[inline] fn gte_modulus(x: &B) -> bool { for i in (0..9).rev() { + // don't fix warning -- that makes it 15% slower! + #[allow(clippy::comparison_chain)] if x[i] > FpC::MODULUS[i] { return true; } else if x[i] < FpC::MODULUS[i] { @@ -92,6 +94,7 @@ pub fn add_assign(x: &mut B, y: &B) { if gte_modulus::(x) { carry = 0; + #[allow(clippy::needless_range_loop)] for i in 0..9 { tmp = x[i].wrapping_sub(FpC::MODULUS[i]) + (carry as u32); carry = (tmp as i32) >> SHIFT; @@ -103,12 +106,15 @@ pub fn add_assign(x: &mut B, y: &B) { #[inline] fn conditional_reduce(x: &mut B) { if gte_modulus::(x) { + #[allow(clippy::needless_range_loop)] for i in 0..9 { x[i] = x[i].wrapping_sub(FpC::MODULUS[i]); } + #[allow(clippy::needless_range_loop)] for i in 1..9 { - x[i] = x[i] + (((x[i - 1] as i32) >> SHIFT) as u32); + x[i] += ((x[i - 1] as i32) >> SHIFT) as u32; } + #[allow(clippy::needless_range_loop)] for i in 0..8 { x[i] &= MASK; } @@ -129,6 +135,7 @@ pub fn mul_assign(x: &mut B, y: &B) { let mut tmp: u64; // main loop, without intermediate carries except for z0 + #[allow(clippy::needless_range_loop)] for i in 0..9 { let xi = x[i] as u64; @@ -160,7 +167,7 @@ pub fn mul_assign(x: &mut B, y: &B) { // implement FpBackend given FpConstants pub fn from_bigint_unsafe(x: BigInt<9>) -> Fp { - let mut r = x.0.clone(); + let mut r = x.0; // convert to montgomery form mul_assign::(&mut r, &FpC::R2); Fp(BigInt(r), Default::default()) @@ -188,7 +195,7 @@ impl FpBackend<9> for FpC { } fn to_bigint(x: Fp) -> BigInt<9> { let one = [1, 0, 0, 0, 0, 0, 0, 0, 0]; - let mut r = x.0 .0.clone(); + let mut r = x.0 .0; // convert back from montgomery form mul_assign::(&mut r, &one); BigInt(r) diff --git a/curves/src/pasta/wasm_friendly/wasm_fp.rs b/curves/src/pasta/wasm_friendly/wasm_fp.rs index 5970015057..7c0d45d34a 100644 --- a/curves/src/pasta/wasm_friendly/wasm_fp.rs +++ b/curves/src/pasta/wasm_friendly/wasm_fp.rs @@ -41,7 +41,6 @@ pub trait FpBackend: Send + Sync + 'static + Sized { #[derivative( Default(bound = ""), Hash(bound = ""), - Clone(bound = ""), Copy(bound = ""), PartialEq(bound = ""), Eq(bound = ""), @@ -54,6 +53,12 @@ pub struct Fp, const N: usize>( pub PhantomData

, ); +impl, const N: usize> Clone for Fp { + fn clone(&self) -> Self { + *self + } +} + impl, const N: usize> Fp { pub fn new(bigint: BigInt) -> Self { Fp(bigint, Default::default()) diff --git a/poseidon/benches/poseidon_bench.rs b/poseidon/benches/poseidon_bench.rs index fb240cde88..4f44ad1f99 100644 --- a/poseidon/benches/poseidon_bench.rs +++ b/poseidon/benches/poseidon_bench.rs @@ -19,8 +19,8 @@ pub fn bench_poseidon_kimchi(c: &mut Criterion) { SpongeParametersKimchi::static_params(), ); - poseidon.absorb(&[Fp::zero()]); - println!("{}", poseidon.squeeze().to_string()); + // poseidon.absorb(&[Fp::zero()]); + // println!("{}", poseidon.squeeze()); b.iter(|| { poseidon.absorb(&[hash]); @@ -33,8 +33,8 @@ pub fn bench_poseidon_kimchi(c: &mut Criterion) { let mut hash: Fp9 = Fp9::zero(); let mut poseidon = Poseidon::::new(fp9_static_params()); - poseidon.absorb(&[Fp9::zero()]); - println!("{}", poseidon.squeeze().to_string()); + // poseidon.absorb(&[Fp9::zero()]); + // println!("{}", poseidon.squeeze()); b.iter(|| { poseidon.absorb(&[hash]); From d727c5f161f51fb23c6ca3c940edec27e0f2fcfc Mon Sep 17 00:00:00 2001 From: Gregor Date: Tue, 1 Oct 2024 17:14:09 +0200 Subject: [PATCH 15/15] fix comments --- curves/src/pasta/wasm_friendly/backend9.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/curves/src/pasta/wasm_friendly/backend9.rs b/curves/src/pasta/wasm_friendly/backend9.rs index 73d83556a9..9295fa8471 100644 --- a/curves/src/pasta/wasm_friendly/backend9.rs +++ b/curves/src/pasta/wasm_friendly/backend9.rs @@ -77,11 +77,11 @@ fn gte_modulus(x: &B) -> bool { true } -/// TODO performance ideas to test: -/// - unroll loops -/// - introduce locals for a[i] instead of accessing memory multiple times -/// - only do 1 carry pass at the end, by proving properties of greater-than on uncarried result -/// - use cheaper, approximate greater-than check a[8] > Fp::MODULUS[8] +// TODO performance ideas to test: +// - unroll loops +// - introduce locals for a[i] instead of accessing memory multiple times +// - only do 1 carry pass at the end, by proving properties of greater-than on uncarried result +// - use cheaper, approximate greater-than check a[8] > Fp::MODULUS[8] pub fn add_assign(x: &mut B, y: &B) { let mut tmp: u32; let mut carry: i32 = 0;