Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wasm-friendly Field #2638

Draft
wants to merge 15 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions curves/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ license = "Apache-2.0"
[dependencies]
ark-ec.workspace = true
ark-ff.workspace = true
ark-serialize.workspace = true
wasm-bindgen.workspace = true
num-bigint.workspace = true
derivative = { version = "2.0", features = ["use_core"] }

[dev-dependencies]
rand.workspace = true
rand.workspace = true
ark-test-curves.workspace = true
ark-algebra-test-templates.workspace = true
ark-serialize.workspace = true
ark-std.workspace = true
1 change: 1 addition & 0 deletions curves/src/pasta/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod curves;
pub mod fields;
pub mod wasm_friendly;

pub use curves::{
pallas::{Pallas, PallasParameters, ProjectivePallas},
Expand Down
213 changes: 213 additions & 0 deletions curves/src/pasta/wasm_friendly/backend9.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
/**
* Implementation of `FpBackend` for N=9, using 29-bit limbs represented by `u32`s.
*/
use super::bigint32::BigInt;
use super::wasm_fp::{Fp, FpBackend};

type B = [u32; 9];
type B64 = [u64; 9];

const SHIFT: u32 = 29;
const MASK: u32 = (1 << SHIFT) - 1;

const SHIFT64: u64 = SHIFT as u64;
const MASK64: u64 = MASK as u64;

pub const fn from_64x4(pa: [u64; 4]) -> [u32; 9] {
let mut p = [0u32; 9];
p[0] = (pa[0] & MASK64) as u32;
p[1] = ((pa[0] >> 29) & MASK64) as u32;
p[2] = (((pa[0] >> 58) | (pa[1] << 6)) & MASK64) as u32;
p[3] = ((pa[1] >> 23) & MASK64) as u32;
p[4] = (((pa[1] >> 52) | (pa[2] << 12)) & MASK64) as u32;
p[5] = ((pa[2] >> 17) & MASK64) as u32;
p[6] = (((pa[2] >> 46) | (pa[3] << 18)) & MASK64) as u32;
p[7] = ((pa[3] >> 11) & MASK64) as u32;
p[8] = (pa[3] >> 40) as u32;
p
}
pub const fn to_64x4(pa: [u32; 9]) -> [u64; 4] {
let mut p = [0u64; 4];
p[0] = pa[0] as u64;
p[0] |= (pa[1] as u64) << 29;
p[0] |= (pa[2] as u64) << 58;
p[1] = (pa[2] as u64) >> 6;
p[1] |= (pa[3] as u64) << 23;
p[1] |= (pa[4] as u64) << 52;
p[2] = (pa[4] as u64) >> 12;
p[2] |= (pa[5] as u64) << 17;
p[2] |= (pa[6] as u64) << 46;
p[3] = (pa[6] as u64) >> 18;
p[3] |= (pa[7] as u64) << 11;
p[3] |= (pa[8] as u64) << 40;
p
}

pub trait FpConstants: Send + Sync + 'static + Sized {
const MODULUS: B;
const MODULUS64: B64 = {
let mut modulus64 = [0u64; 9];
let modulus = Self::MODULUS;
let mut i = 0;
while i < 9 {
modulus64[i] = modulus[i] as u64;
i += 1;
}
modulus64
};

/// montgomery params
/// TODO: compute these
const R: B; // R = 2^261 mod modulus
const R2: B; // R^2 mod modulus
const MINV: u64; // -modulus^(-1) mod 2^29, as a u64
}

#[inline]
fn gte_modulus<FpC: FpConstants>(x: &B) -> bool {
for i in (0..9).rev() {
// don't fix warning -- that makes it 15% slower!
#[allow(clippy::comparison_chain)]
if x[i] > FpC::MODULUS[i] {
return true;
} else if x[i] < FpC::MODULUS[i] {
return false;
}
}
true
}

// TODO performance ideas to test:
// - unroll loops
// - introduce locals for a[i] instead of accessing memory multiple times
// - only do 1 carry pass at the end, by proving properties of greater-than on uncarried result
// - use cheaper, approximate greater-than check a[8] > Fp::MODULUS[8]
pub fn add_assign<FpC: FpConstants>(x: &mut B, y: &B) {
let mut tmp: u32;
let mut carry: i32 = 0;

for i in 0..9 {
tmp = x[i] + y[i] + (carry as u32);
carry = (tmp as i32) >> SHIFT;
x[i] = tmp & MASK;
}

if gte_modulus::<FpC>(x) {
carry = 0;
#[allow(clippy::needless_range_loop)]
for i in 0..9 {
tmp = x[i].wrapping_sub(FpC::MODULUS[i]) + (carry as u32);
carry = (tmp as i32) >> SHIFT;
x[i] = tmp & MASK;
}
}
}

#[inline]
fn conditional_reduce<FpC: FpConstants>(x: &mut B) {
if gte_modulus::<FpC>(x) {
#[allow(clippy::needless_range_loop)]
for i in 0..9 {
x[i] = x[i].wrapping_sub(FpC::MODULUS[i]);
}
#[allow(clippy::needless_range_loop)]
for i in 1..9 {
x[i] += ((x[i - 1] as i32) >> SHIFT) as u32;
}
#[allow(clippy::needless_range_loop)]
for i in 0..8 {
x[i] &= MASK;
}
}
}

/// Montgomery multiplication
pub fn mul_assign<FpC: FpConstants>(x: &mut B, y: &B) {
// load y[i] into local u64s
// TODO make sure these are locals
let mut y_local = [0u64; 9];
for i in 0..9 {
y_local[i] = y[i] as u64;
}

// locals for result
let mut z = [0u64; 8];
let mut tmp: u64;

// main loop, without intermediate carries except for z0
#[allow(clippy::needless_range_loop)]
for i in 0..9 {
let xi = x[i] as u64;

// compute qi and carry z0 result to z1 before discarding z0
tmp = xi * y_local[0];
let qi = ((tmp & MASK64) * FpC::MINV) & MASK64;
z[1] += (tmp + qi * FpC::MODULUS64[0]) >> SHIFT64;

// compute zi and shift in one step
for j in 1..8 {
z[j - 1] = z[j] + (xi * y_local[j]) + (qi * FpC::MODULUS64[j]);
}
// for j=8 we save an addition since z[8] is never needed
z[7] = xi * y_local[8] + qi * FpC::MODULUS64[8];
}

// final carry pass, store result back into x
x[0] = (z[0] & MASK64) as u32;
for i in 1..8 {
x[i] = (((z[i - 1] >> SHIFT64) + z[i]) & MASK64) as u32;
}
x[8] = (z[7] >> SHIFT64) as u32;

// at this point, x is guaranteed to be less than 2*MODULUS
// conditionally subtract the modulus to bring it back into the canonical range
conditional_reduce::<FpC>(x);
}

// implement FpBackend given FpConstants

pub fn from_bigint_unsafe<FpC: FpConstants>(x: BigInt<9>) -> Fp<FpC, 9> {
let mut r = x.0;
// convert to montgomery form
mul_assign::<FpC>(&mut r, &FpC::R2);
Fp(BigInt(r), Default::default())
}

impl<FpC: FpConstants> FpBackend<9> for FpC {
const MODULUS: BigInt<9> = BigInt(Self::MODULUS);
const ZERO: BigInt<9> = BigInt([0; 9]);
const ONE: BigInt<9> = BigInt(Self::R);

fn add_assign(x: &mut Fp<Self, 9>, y: &Fp<Self, 9>) {
add_assign::<Self>(&mut x.0 .0, &y.0 .0);
}

fn mul_assign(x: &mut Fp<Self, 9>, y: &Fp<Self, 9>) {
mul_assign::<Self>(&mut x.0 .0, &y.0 .0);
}

fn from_bigint(x: BigInt<9>) -> Option<Fp<Self, 9>> {
if gte_modulus::<Self>(&x.0) {
None
} else {
Some(from_bigint_unsafe(x))
}
}
fn to_bigint(x: Fp<Self, 9>) -> BigInt<9> {
let one = [1, 0, 0, 0, 0, 0, 0, 0, 0];
let mut r = x.0 .0;
// convert back from montgomery form
mul_assign::<Self>(&mut r, &one);
BigInt(r)
}

fn pack(x: Fp<Self, 9>) -> Vec<u64> {
let x = Self::to_bigint(x).0;
let x64 = to_64x4(x);
let mut res = Vec::with_capacity(4);
for limb in x64.iter() {
res.push(*limb);
}
res
}
}
52 changes: 52 additions & 0 deletions curves/src/pasta/wasm_friendly/bigint32.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/**
* BigInt with 32-bit limbs
*
* Contains everything for wasm_fp which is unrelated to being a field
*
* Code is mostly copied from ark-ff::BigInt
*/
use ark_serialize::{
CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate,
Write,
};

#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct BigInt<const N: usize>(pub [u32; N]);

impl<const N: usize> Default for BigInt<N> {
fn default() -> Self {
Self([0u32; N])
}
}

impl<const N: usize> CanonicalSerialize for BigInt<N> {
fn serialize_with_mode<W: Write>(
&self,
writer: W,
compress: Compress,
) -> Result<(), SerializationError> {
self.0.serialize_with_mode(writer, compress)
}

fn serialized_size(&self, compress: Compress) -> usize {
self.0.serialized_size(compress)
}
}

impl<const N: usize> Valid for BigInt<N> {
fn check(&self) -> Result<(), SerializationError> {
self.0.check()
}
}

impl<const N: usize> CanonicalDeserialize for BigInt<N> {
fn deserialize_with_mode<R: Read>(
reader: R,
compress: Compress,
validate: Validate,
) -> Result<Self, SerializationError> {
Ok(BigInt::<N>(<[u32; N]>::deserialize_with_mode(
reader, compress, validate,
)?))
}
}
44 changes: 44 additions & 0 deletions curves/src/pasta/wasm_friendly/minimal_field.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use ark_ff::{BitIteratorBE, One, Zero};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use std::ops::{Add, AddAssign, Mul, MulAssign};

/**
* Minimal Field trait needed to implement Poseidon
*/
pub trait MinimalField:
'static
+ Copy
+ Clone
+ CanonicalSerialize
+ CanonicalDeserialize
+ Zero
+ One
+ for<'a> Add<&'a Self, Output = Self>
+ for<'a> Mul<&'a Self, Output = Self>
+ for<'a> AddAssign<&'a Self>
+ for<'a> MulAssign<&'a Self>
{
/// Squares `self` in place.
fn square_in_place(&mut self) -> &mut Self;

/// Returns `self^exp`, where `exp` is an integer represented with `u64` limbs,
/// least significant limb first.
fn pow<S: AsRef<[u64]>>(&self, exp: S) -> Self {
let mut res = Self::one();

for i in BitIteratorBE::without_leading_zeros(exp) {
res.square_in_place();

if i {
res *= self;
}
}
res
}
}

impl<F: ark_ff::Field> MinimalField for F {
fn square_in_place(&mut self) -> &mut Self {
self.square_in_place()
}
}
12 changes: 12 additions & 0 deletions curves/src/pasta/wasm_friendly/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
pub mod bigint32;
pub use bigint32::BigInt;

pub mod minimal_field;
pub use minimal_field::MinimalField;

pub mod wasm_fp;
pub use wasm_fp::Fp;

pub mod backend9;
pub mod pasta;
pub use pasta::Fp9;
33 changes: 33 additions & 0 deletions curves/src/pasta/wasm_friendly/pasta.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use super::{backend9, wasm_fp};
use crate::pasta::Fp;
use ark_ff::PrimeField;

pub struct Fp9Parameters;

impl backend9::FpConstants for Fp9Parameters {
const MODULUS: [u32; 9] = [
0x1, 0x9698768, 0x133e46e6, 0xd31f812, 0x224, 0x0, 0x0, 0x0, 0x400000,
];
const R: [u32; 9] = [
0x1fffff81, 0x14a5d367, 0x141ad3c0, 0x1435eec5, 0x1ffeefef, 0x1fffffff, 0x1fffffff,
0x1fffffff, 0x3fffff,
];
const R2: [u32; 9] = [
0x3b6a, 0x19c10910, 0x1a6a0188, 0x12a4fd88, 0x634b36d, 0x178792ba, 0x7797a99, 0x1dce5b8a,
0x3506bd,
];
const MINV: u64 = 0x1fffffff;
}
pub type Fp9 = wasm_fp::Fp<Fp9Parameters, 9>;

impl Fp9 {
pub fn from_fp(fp: Fp) -> Self {
backend9::from_bigint_unsafe(super::BigInt(backend9::from_64x4(fp.into_bigint().0)))
}
}

impl From<Fp> for Fp9 {
fn from(fp: Fp) -> Self {
Fp9::from_fp(fp)
}
}
Loading