Skip to content

Commit

Permalink
[feat] pre-compute a lookup table for bn256::scalarfield (#46)
Browse files Browse the repository at this point in the history
* [feat] use pre-computed table for bn256::scalarfield

* [chore] cargo fmt

* [feat] turn off bn256-table by default

---------

Co-authored-by: Han <[email protected]>
  • Loading branch information
zhenfeizhang and han0110 authored May 31, 2023
1 parent cf57ee9 commit e97adcb
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 14 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@
Cargo.lock
**/*.rs.bk
.vscode
**/*.html
**/*.html

# script generated source code
src/bn256/fr/table.rs
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ serde_arrays = { version = "0.1.0", optional = true }
[features]
default = ["reexport", "bits"]
asm = []
bits = ["ff/bits"]
bn256-table = []
derive_serde = ["serde/derive", "serde_arrays"]
prefetch = []
print-trace = ["ark-std/print-trace"]
derive_serde = ["serde/derive", "serde_arrays"]
reexport = []
bits = ["ff/bits"]

[profile.bench]
opt-level = 3
Expand Down
13 changes: 13 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,17 @@ fn main() {
eprintln!("Currently feature `asm` can only be enabled on x86_64 arch.");
std::process::exit(1);
}
#[cfg(feature = "bn256-table")]
{
if std::path::Path::new("src/bn256/fr/table.rs").exists() {
eprintln!("Pre-computed table for BN256 scalar field exists.");
eprintln!("Skip pre-computation\n");
} else {
eprintln!("Generating pre-computed table for BN256 scalar field\n");
std::process::Command::new("python3")
.args(["script/bn256.py"])
.output()
.expect("requires python 3 to build pre-computed table");
}
}
}
46 changes: 46 additions & 0 deletions script/bn256.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# This file generates the montogomary form integers for x in [0, 2^16) \intersect
# BN::ScalarField

verbose = False

modulus = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001
R = 2**256 % modulus
table_size = 1<<16

# @input: field element a
# @output: 4 u64 a0, a1, a2, a3 s.t.
# a = a3 * 2^192 + a2 * 2^128 + a1 * 2^64 + a0
def decompose_field_element(a):
a0 = a % 2**64
a = a // 2**64
a1 = a % 2**64
a = a // 2**64
a2 = a % 2**64
a = a // 2**64
a3 = a
return [a0, a1, a2, a3]


# @input: field element a
# @output: a rust format string that encodes
# 4 u64 a0, a1, a2, a3 s.t.
# a = a3 * 2^192 + a2 * 2^128 + a1 * 2^64 + a0
def format_field_element(a):
[a0, a1, a2, a3] = decompose_field_element(a);
return "Fr([" + hex(a0) + "," + hex(a1) + "," + hex(a2) + "," + hex(a3) + "]),\n"


f = open("src/bn256/fr/table.rs", "w")
f.write("//! auto generated file from scripts/bn256.sage, do not modify\n")
f.write("//! see src/bn256/fr.rs for more details\n")
f.write("use super::Fr;\n")
f.write("pub const FR_TABLE: &[Fr] = &[\n")

for i in range(table_size):
a = (i * R) % modulus
if verbose:
print (i, a, format_field_element(a))
f.write(format_field_element(a))

f.write("\n];")

3 changes: 2 additions & 1 deletion src/bn256/fq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::ff::{Field, FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
field_bits, field_common, impl_add_binop_specify_output, impl_binops_additive,
impl_binops_additive_specify_output, impl_binops_multiplicative,
impl_binops_multiplicative_mixed, impl_sub_binop_specify_output, impl_sum_prod,
impl_binops_multiplicative_mixed, impl_from_u64, impl_sub_binop_specify_output, impl_sum_prod,
};
use core::convert::TryInto;
use core::fmt;
Expand Down Expand Up @@ -134,6 +134,7 @@ field_common!(
R3
);
impl_sum_prod!(Fq);
impl_from_u64!(Fq, R2);

#[cfg(not(feature = "asm"))]
field_arithmetic!(Fq, MODULUS, INV, sparse);
Expand Down
52 changes: 50 additions & 2 deletions src/bn256/fr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,21 @@ use crate::bn256::assembly::field_arithmetic_asm;
#[cfg(not(feature = "asm"))]
use crate::{field_arithmetic, field_specific};

#[cfg(feature = "bn256-table")]
#[rustfmt::skip]
mod table;
#[cfg(feature = "bn256-table")]
#[cfg(test)]
mod table_tests;

#[cfg(feature = "bn256-table")]
// This table should have being generated by `build.rs`;
// and stored in `src/bn256/fr/table.rs`.
pub use table::FR_TABLE;

#[cfg(not(feature = "bn256-table"))]
use crate::impl_from_u64;

use crate::arithmetic::{adc, mac, sbb};
use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
Expand Down Expand Up @@ -152,6 +167,28 @@ field_common!(
);
impl_sum_prod!(Fr);

#[cfg(not(feature = "bn256-table"))]
impl_from_u64!(Fr, R2);
#[cfg(feature = "bn256-table")]
// A field element is represented in the montgomery form -- this allows for cheap mul_mod operations.
// The catch is, if we build an Fr element, regardless of its format, we need to perform one big integer multiplication:
//
// Fr([val, 0, 0, 0]) * R2
//
// When the "bn256-table" feature is enabled, we read the Fr element directly from the table.
// This avoids a big integer multiplication.
//
// We use a table with 2^16 entries when the element is smaller than 2^16.
impl From<u64> for Fr {
fn from(val: u64) -> Fr {
if val < 65536 {
FR_TABLE[val as usize]
} else {
Fr([val, 0, 0, 0]) * R2
}
}
}

#[cfg(not(feature = "asm"))]
field_arithmetic!(Fr, MODULUS, INV, sparse);
#[cfg(feature = "asm")]
Expand Down Expand Up @@ -402,8 +439,7 @@ mod test {
0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06,
0xbc, 0xe5,
]);
let _message = "serialization fr";
let start = start_timer!(|| _message);
let start = start_timer!(|| "serialize fr");
// failure check
for _ in 0..1000000 {
let rand_word = [(); 4].map(|_| rng.next_u64());
Expand All @@ -420,4 +456,16 @@ mod test {
}
end_timer!(start);
}

#[test]
fn bench_fr_from_u16() {
let repeat = 10000000;
let mut rng = ark_std::test_rng();
let base = (0..repeat).map(|_| (rng.next_u32() % (1 << 16)) as u64);

let timer = start_timer!(|| format!("generate {} Bn256 scalar field elements", repeat));
let _res: Vec<_> = base.map(|b| Fr::from(b)).collect();

end_timer!(timer);
}
}
8 changes: 8 additions & 0 deletions src/bn256/fr/table_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use crate::bn256::{Fr, FR_TABLE};

#[test]
fn test_table() {
for (i, e) in FR_TABLE.iter().enumerate() {
assert_eq!(Fr::from(i as u64), *e);
}
}
17 changes: 11 additions & 6 deletions src/derive/field.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
#[macro_export]
macro_rules! impl_from_u64 {
($field:ident, $r2:ident) => {
impl From<u64> for $field {
fn from(val: u64) -> $field {
$field([val, 0, 0, 0]) * $r2
}
}
};
}

#[macro_export]
macro_rules! field_common {
(
Expand Down Expand Up @@ -170,12 +181,6 @@ macro_rules! field_common {
}
}

impl From<u64> for $field {
fn from(val: u64) -> $field {
$field([val, 0, 0, 0]) * $r2
}
}

impl ConstantTimeEq for $field {
fn ct_eq(&self, other: &Self) -> Choice {
self.0[0].ct_eq(&other.0[0])
Expand Down
3 changes: 2 additions & 1 deletion src/secp256k1/fp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
field_arithmetic, field_bits, field_common, field_specific, impl_add_binop_specify_output,
impl_binops_additive, impl_binops_additive_specify_output, impl_binops_multiplicative,
impl_binops_multiplicative_mixed, impl_sub_binop_specify_output, impl_sum_prod,
impl_binops_multiplicative_mixed, impl_from_u64, impl_sub_binop_specify_output, impl_sum_prod,
};
use core::convert::TryInto;
use core::fmt;
Expand Down Expand Up @@ -126,6 +126,7 @@ field_common!(
R2,
R3
);
impl_from_u64!(Fp, R2);
field_arithmetic!(Fp, MODULUS, INV, dense);
impl_sum_prod!(Fp);

Expand Down
3 changes: 2 additions & 1 deletion src/secp256k1/fq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use crate::{
field_arithmetic, field_bits, field_common, field_specific, impl_add_binop_specify_output,
impl_binops_additive, impl_binops_additive_specify_output, impl_binops_multiplicative,
impl_binops_multiplicative_mixed, impl_sub_binop_specify_output, impl_sum_prod,
impl_binops_multiplicative_mixed, impl_from_u64, impl_sub_binop_specify_output, impl_sum_prod,
};
use core::convert::TryInto;
use core::fmt;
Expand Down Expand Up @@ -138,6 +138,7 @@ field_common!(
R2,
R3
);
impl_from_u64!(Fq, R2);
field_arithmetic!(Fq, MODULUS, INV, dense);
impl_sum_prod!(Fq);

Expand Down

0 comments on commit e97adcb

Please sign in to comment.