Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement CIOS for ARM F::mul #134

Merged
merged 11 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ serde_arrays = { version = "0.1.0", optional = true }
hex = { version = "0.4", optional = true, default-features = false, features = ["alloc", "serde"] }
blake2b_simd = "1"
rayon = "1.8"
unroll = "0.1.5"

[features]
default = ["bits"]
Expand Down
24 changes: 24 additions & 0 deletions src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ pub(crate) const fn macx(a: u64, b: u64, c: u64) -> (u64, u64) {
(res as u64, (res >> 64) as u64)
}

/// Returns a >= b
#[inline(always)]
pub(crate) const fn bigint_geq(a: &[u64; 4], b: &[u64; 4]) -> bool {
if a[3] > b[3] {
return true;
} else if a[3] < b[3] {
return false;
}
if a[2] > b[2] {
return true;
} else if a[2] < b[2] {
return false;
}
if a[1] > b[1] {
return true;
} else if a[1] < b[1] {
return false;
}
if a[0] >= b[0] {
return true;
}
false
}

/// Compute a * b, returning the result.
#[inline(always)]
pub(crate) fn mul_512(a: [u64; 4], b: [u64; 4]) -> [u64; 8] {
Expand Down
2 changes: 1 addition & 1 deletion src/bn256/fq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::bn256::assembly::field_arithmetic_asm;
#[cfg(not(feature = "asm"))]
use crate::{arithmetic::macx, field_arithmetic, field_specific};

use crate::arithmetic::{adc, mac, sbb};
use crate::arithmetic::{adc, bigint_geq, mac, sbb};
use crate::extend_field_legendre;
use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
Expand Down
2 changes: 1 addition & 1 deletion src/bn256/fr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub use table::FR_TABLE;
#[cfg(not(feature = "bn256-table"))]
use crate::impl_from_u64;

use crate::arithmetic::{adc, mac, sbb};
use crate::arithmetic::{adc, bigint_geq, mac, sbb};
use crate::extend_field_legendre;
use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
Expand Down
233 changes: 136 additions & 97 deletions src/derive/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,73 +63,88 @@ macro_rules! field_common {
$crate::ff_ext::jacobi::jacobi::<5>(&self.0, &$modulus.0)
}

#[cfg(feature = "asm")]
const fn montgomery_form(val: [u64; 4], r: $field) -> $field {
// Converts a 4 64-bit limb value into its congruent field representation.
// If `val` represents a 256 bit value then `r` should be R^2,
// if `val` represents the 256 MSB of a 512 bit value, then `r` should be R^3.

let (r0, carry) = mac(0, val[0], r.0[0], 0);
let (r1, carry) = mac(0, val[0], r.0[1], carry);
let (r2, carry) = mac(0, val[0], r.0[2], carry);
let (r3, r4) = mac(0, val[0], r.0[3], carry);

let (r1, carry) = mac(r1, val[1], r.0[0], 0);
let (r2, carry) = mac(r2, val[1], r.0[1], carry);
let (r3, carry) = mac(r3, val[1], r.0[2], carry);
let (r4, r5) = mac(r4, val[1], r.0[3], carry);

let (r2, carry) = mac(r2, val[2], r.0[0], 0);
let (r3, carry) = mac(r3, val[2], r.0[1], carry);
let (r4, carry) = mac(r4, val[2], r.0[2], carry);
let (r5, r6) = mac(r5, val[2], r.0[3], carry);

let (r3, carry) = mac(r3, val[3], r.0[0], 0);
let (r4, carry) = mac(r4, val[3], r.0[1], carry);
let (r5, carry) = mac(r5, val[3], r.0[2], carry);
let (r6, r7) = mac(r6, val[3], r.0[3], carry);

// Montgomery reduction
let k = r0.wrapping_mul($inv);
let (_, carry) = mac(r0, k, $modulus.0[0], 0);
let (r1, carry) = mac(r1, k, $modulus.0[1], carry);
let (r2, carry) = mac(r2, k, $modulus.0[2], carry);
let (r3, carry) = mac(r3, k, $modulus.0[3], carry);
let (r4, carry2) = adc(r4, 0, carry);

let k = r1.wrapping_mul($inv);
let (_, carry) = mac(r1, k, $modulus.0[0], 0);
let (r2, carry) = mac(r2, k, $modulus.0[1], carry);
let (r3, carry) = mac(r3, k, $modulus.0[2], carry);
let (r4, carry) = mac(r4, k, $modulus.0[3], carry);
let (r5, carry2) = adc(r5, carry2, carry);

let k = r2.wrapping_mul($inv);
let (_, carry) = mac(r2, k, $modulus.0[0], 0);
let (r3, carry) = mac(r3, k, $modulus.0[1], carry);
let (r4, carry) = mac(r4, k, $modulus.0[2], carry);
let (r5, carry) = mac(r5, k, $modulus.0[3], carry);
let (r6, carry2) = adc(r6, carry2, carry);

let k = r3.wrapping_mul($inv);
let (_, carry) = mac(r3, k, $modulus.0[0], 0);
let (r4, carry) = mac(r4, k, $modulus.0[1], carry);
let (r5, carry) = mac(r5, k, $modulus.0[2], carry);
let (r6, carry) = mac(r6, k, $modulus.0[3], carry);
let (r7, carry2) = adc(r7, carry2, carry);

// Result may be within MODULUS of the correct value
let (d0, borrow) = sbb(r4, $modulus.0[0], 0);
let (d1, borrow) = sbb(r5, $modulus.0[1], borrow);
let (d2, borrow) = sbb(r6, $modulus.0[2], borrow);
let (d3, borrow) = sbb(r7, $modulus.0[3], borrow);
let (_, borrow) = sbb(carry2, 0, borrow);
let (d0, carry) = adc(d0, $modulus.0[0] & borrow, 0);
let (d1, carry) = adc(d1, $modulus.0[1] & borrow, carry);
let (d2, carry) = adc(d2, $modulus.0[2] & borrow, carry);
let (d3, _) = adc(d3, $modulus.0[3] & borrow, carry);
#[cfg(feature = "asm")]
{
let (r0, carry) = mac(0, val[0], r.0[0], 0);
let (r1, carry) = mac(0, val[0], r.0[1], carry);
let (r2, carry) = mac(0, val[0], r.0[2], carry);
let (r3, r4) = mac(0, val[0], r.0[3], carry);

let (r1, carry) = mac(r1, val[1], r.0[0], 0);
let (r2, carry) = mac(r2, val[1], r.0[1], carry);
let (r3, carry) = mac(r3, val[1], r.0[2], carry);
let (r4, r5) = mac(r4, val[1], r.0[3], carry);

let (r2, carry) = mac(r2, val[2], r.0[0], 0);
let (r3, carry) = mac(r3, val[2], r.0[1], carry);
let (r4, carry) = mac(r4, val[2], r.0[2], carry);
let (r5, r6) = mac(r5, val[2], r.0[3], carry);

let (r3, carry) = mac(r3, val[3], r.0[0], 0);
let (r4, carry) = mac(r4, val[3], r.0[1], carry);
let (r5, carry) = mac(r5, val[3], r.0[2], carry);
let (r6, r7) = mac(r6, val[3], r.0[3], carry);

// Montgomery reduction
let k = r0.wrapping_mul($inv);
let (_, carry) = mac(r0, k, $modulus.0[0], 0);
let (r1, carry) = mac(r1, k, $modulus.0[1], carry);
let (r2, carry) = mac(r2, k, $modulus.0[2], carry);
let (r3, carry) = mac(r3, k, $modulus.0[3], carry);
let (r4, carry2) = adc(r4, 0, carry);

let k = r1.wrapping_mul($inv);
let (_, carry) = mac(r1, k, $modulus.0[0], 0);
let (r2, carry) = mac(r2, k, $modulus.0[1], carry);
let (r3, carry) = mac(r3, k, $modulus.0[2], carry);
let (r4, carry) = mac(r4, k, $modulus.0[3], carry);
let (r5, carry2) = adc(r5, carry2, carry);

let k = r2.wrapping_mul($inv);
let (_, carry) = mac(r2, k, $modulus.0[0], 0);
let (r3, carry) = mac(r3, k, $modulus.0[1], carry);
let (r4, carry) = mac(r4, k, $modulus.0[2], carry);
let (r5, carry) = mac(r5, k, $modulus.0[3], carry);
let (r6, carry2) = adc(r6, carry2, carry);

let k = r3.wrapping_mul($inv);
let (_, carry) = mac(r3, k, $modulus.0[0], 0);
let (r4, carry) = mac(r4, k, $modulus.0[1], carry);
let (r5, carry) = mac(r5, k, $modulus.0[2], carry);
let (r6, carry) = mac(r6, k, $modulus.0[3], carry);
let (r7, carry2) = adc(r7, carry2, carry);

// Result may be within MODULUS of the correct value
let (d0, borrow) = sbb(r4, $modulus.0[0], 0);
let (d1, borrow) = sbb(r5, $modulus.0[1], borrow);
let (d2, borrow) = sbb(r6, $modulus.0[2], borrow);
let (d3, borrow) = sbb(r7, $modulus.0[3], borrow);
let (_, borrow) = sbb(carry2, 0, borrow);
let (d0, carry) = adc(d0, $modulus.0[0] & borrow, 0);
let (d1, carry) = adc(d1, $modulus.0[1] & borrow, carry);
let (d2, carry) = adc(d2, $modulus.0[2] & borrow, carry);
let (d3, _) = adc(d3, $modulus.0[3] & borrow, carry);

$field([d0, d1, d2, d3])
}

$field([d0, d1, d2, d3])
#[cfg(not(feature = "asm"))]
{
let mut val = val;
if bigint_geq(&val, &$modulus.0) {
let mut borrow = 0;
(val[0], borrow) = sbb(val[0], $modulus.0[0], borrow);
(val[1], borrow) = sbb(val[1], $modulus.0[1], borrow);
(val[2], borrow) = sbb(val[2], $modulus.0[2], borrow);
(val[3], _) = sbb(val[3], $modulus.0[3], borrow);
}
$field::mul(&$field(val), &r)
}
}

fn from_u512(limbs: [u64; 8]) -> $field {
Expand All @@ -150,27 +165,13 @@ macro_rules! field_common {
let lower_256 = [limbs[0], limbs[1], limbs[2], limbs[3]];
let upper_256 = [limbs[4], limbs[5], limbs[6], limbs[7]];

#[cfg(feature = "asm")]
{
Self::montgomery_form(lower_256, $r2) + Self::montgomery_form(upper_256, $r3)
}
#[cfg(not(feature = "asm"))]
{
$field(lower_256) * $r2 + $field(upper_256) * $r3
}
Self::montgomery_form(lower_256, $r2) + Self::montgomery_form(upper_256, $r3)
}

/// Converts from an integer represented in little endian
/// into its (congruent) `$field` representation.
pub const fn from_raw(val: [u64; 4]) -> Self {
#[cfg(feature = "asm")]
{
Self::montgomery_form(val, $r2)
}
#[cfg(not(feature = "asm"))]
{
(&$field(val)).mul(&$r2)
}
Self::montgomery_form(val, $r2)
}

/// Attempts to convert a little-endian byte representation of
Expand Down Expand Up @@ -429,31 +430,69 @@ macro_rules! field_arithmetic {
}

/// Multiplies `rhs` by `self`, returning the result.
#[inline]
pub const fn mul(&self, rhs: &Self) -> $field {
// Schoolbook multiplication
#[inline(always)]
#[unroll::unroll_for_loops]
#[allow(unused_assignments)]
pub const fn mul(&self, rhs: &Self) -> Self {
// Fast Coarsely Integrated Operand Scanning (CIOS) as described
// in Algorithm 2 of EdMSM: https://eprint.iacr.org/2022/1400.pdf
//
// Cannot use the fast version (algorithm 2) if
// modulus_high_word >= (WORD_SIZE - 1) / 2 - 1 = (2^64 - 1)/2 - 1

if $modulus.0[3] < (u64::MAX / 2) {
const N: usize = 4;
let mut t: [u64; N] = [0u64; N];
let mut c_2: u64;
for i in 0..4 {
let mut c: u64 = 0u64;
for j in 0..4 {
(t[j], c) = mac(t[j], self.0[j], rhs.0[i], c);
}
c_2 = c;

let m = t[0].wrapping_mul(INV);
(_, c) = macx(t[0], m, $modulus.0[0]);

for j in 1..4 {
(t[j - 1], c) = mac(t[j], m, $modulus.0[j], c);
}
(t[N - 1], _) = adc(c_2, c, 0);
}

if bigint_geq(&t, &$modulus.0) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the bigint_geq procedure and its usage may be suboptimal. First of all, you can notice that $t ≥ m ⇔ !(t &lt; m)$. You can implement the "less-than" as subtraction $t - m$. If it overflows (borrow is 1) then $t &lt; m$.
Then you can notice that you are actually computing the $t - m$ just after. The common pattern is to compute the subtraction speculatively and use its result in case it hasn't overflowed.

(tmp, borrow) = t - m
t = borrow ? t : tmp

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch!

I believe strategy would only save instructions in the case that bigint_geq == true. It would cost some instructions in the case that bigint_geq == false as sbb gets broken out into a few instructions on ARM where as the 4 u64 LTs should be a single instruction each.

let mut borrow = 0;
(t[0], borrow) = sbb(t[0], $modulus.0[0], borrow);
(t[1], borrow) = sbb(t[1], $modulus.0[1], borrow);
(t[2], borrow) = sbb(t[2], $modulus.0[2], borrow);
(t[3], borrow) = sbb(t[3], $modulus.0[3], borrow);
}
$field(t)
} else {
// Schoolbook multiplication

let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0);
let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry);
let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry);
let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry);
let (r0, carry) = mac(0, self.0[0], rhs.0[0], 0);
let (r1, carry) = mac(0, self.0[0], rhs.0[1], carry);
let (r2, carry) = mac(0, self.0[0], rhs.0[2], carry);
let (r3, r4) = mac(0, self.0[0], rhs.0[3], carry);

let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0);
let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry);
let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry);
let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry);
let (r1, carry) = mac(r1, self.0[1], rhs.0[0], 0);
let (r2, carry) = mac(r2, self.0[1], rhs.0[1], carry);
let (r3, carry) = mac(r3, self.0[1], rhs.0[2], carry);
let (r4, r5) = mac(r4, self.0[1], rhs.0[3], carry);

let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0);
let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry);
let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry);
let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry);
let (r2, carry) = mac(r2, self.0[2], rhs.0[0], 0);
let (r3, carry) = mac(r3, self.0[2], rhs.0[1], carry);
let (r4, carry) = mac(r4, self.0[2], rhs.0[2], carry);
let (r5, r6) = mac(r5, self.0[2], rhs.0[3], carry);

let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0);
let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry);
let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry);
let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry);
let (r3, carry) = mac(r3, self.0[3], rhs.0[0], 0);
let (r4, carry) = mac(r4, self.0[3], rhs.0[1], carry);
let (r5, carry) = mac(r5, self.0[3], rhs.0[2], carry);
let (r6, r7) = mac(r6, self.0[3], rhs.0[3], carry);

$field::montgomery_reduce(&[r0, r1, r2, r3, r4, r5, r6, r7])
$field::montgomery_reduce(&[r0, r1, r2, r3, r4, r5, r6, r7])
}
}

/// Subtracts `rhs` from `self`, returning the result.
Expand Down
2 changes: 1 addition & 1 deletion src/secp256k1/fp.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::arithmetic::{adc, mac, macx, sbb};
use crate::arithmetic::{adc, bigint_geq, mac, macx, sbb};
use crate::extend_field_legendre;
use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
Expand Down
2 changes: 1 addition & 1 deletion src/secp256k1/fq.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::arithmetic::{adc, mac, macx, sbb};
use crate::arithmetic::{adc, bigint_geq, mac, macx, sbb};
use crate::extend_field_legendre;
use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
Expand Down
2 changes: 1 addition & 1 deletion src/secp256r1/fp.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::arithmetic::{adc, mac, macx, sbb};
use crate::arithmetic::{adc, bigint_geq, mac, macx, sbb};
use crate::extend_field_legendre;
use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
Expand Down
2 changes: 1 addition & 1 deletion src/secp256r1/fq.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::arithmetic::{adc, mac, macx, sbb};
use crate::arithmetic::{adc, bigint_geq, mac, macx, sbb};
use crate::extend_field_legendre;
use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use core::fmt;
Expand Down
Loading