Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement pow-by-constant with NAF for FpVar #72

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

### Improvements

- [\#72](https://github.com/arkworks-rs/r1cs-std/pull/72) Implement `pow_by_constant` with NAF for `FpVar`.

### Bug Fixes

- [\#70](https://github.com/arkworks-rs/r1cs-std/pull/70) Fix soundness issues of `mul_by_inverse` for field gadgets.
Expand Down
232 changes: 231 additions & 1 deletion src/fields/fp/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use ark_ff::{BigInteger, FpParameters, PrimeField};
use ark_ff::{BigInteger, BitIteratorBE, FpParameters, PrimeField};
use ark_relations::r1cs::{
ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable,
};
Expand Down Expand Up @@ -764,6 +764,128 @@ impl<F: PrimeField> FieldVar<F, F> for FpVar<F> {
*self = self.frobenius_map(power)?;
Ok(self)
}

/// Computes `self^S`, where S is interpreted as an little-endian
/// u64-decomposition of an integer.
#[tracing::instrument(target = "r1cs", skip(exp))]
fn pow_by_constant<S: AsRef<[u64]>>(&self, exp: S) -> Result<Self, SynthesisError> {
use ark_ff::biginteger::arithmetic::find_wnaf;

// first check if exp = 0
let mut is_nonzero = false;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something we can move to a function in utils?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is tricky since that the optimization here is very specific to constraints (e.g., count inversion at a cost of 3 constraints).

for limb in exp.as_ref() {
if *limb != 0u64 {
is_nonzero = true;
}
}

// handle the case when exp = 0
if !is_nonzero {
return Ok(FpVar::Constant(F::one()));
}

// now we consider the case when exp != 0

// if `self` is constant, we compute it directly.
if self.is_constant() {
return Ok(FpVar::Constant(self.value()?.pow(exp)));
}

// now we consider the case when exp != 0 and `self` is not a constant

// obtain the NAF representation
let naf_be = find_wnaf(exp.as_ref());
let found_minus_one_in_naf = naf_be.contains(&-1i64);

// now discuss whether or not we should use NAF
let mut use_naf = true;
let mut standard_be = None;

// if the NAF does not contain `-1`, it cannot be faster than the square-and-multiply
if !found_minus_one_in_naf {
use_naf = false;
standard_be = Some(BitIteratorBE::without_leading_zeros(&exp).collect::<Vec<bool>>());
}

// since NAF needs to compute the inverse, which incurs additional overhead,
// it might not be better than the standard square-and-multiply

if use_naf {
// obtain the standard representation
let standard_be_bits =
BitIteratorBE::without_leading_zeros(&exp).collect::<Vec<bool>>();

// compute the cost of the NAF representation
let mut naf_cost = naf_be.len() + naf_be.iter().filter(|x| **x != 0i64).count();
if found_minus_one_in_naf {
// computing the inverse_or_any incurs additional overhead
// two for computing the inverse-or-any, one for ensuring 0 ^ exp = 0
naf_cost += 3;
}

// compute the cost of the standard representation
let standard_cost =
standard_be_bits.len() + standard_be_bits.iter().filter(|x| **x == true).count();
Comment on lines +827 to +828
Copy link
Member

@ValarDragon ValarDragon Jul 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to make a pow_by_const trait, with two impls, but eventually three, naf, standard, and windowed?

Then we can have four methods, pow, pow_r1cs, pow_native_cost, and pow_r1cs_cost?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or does that seem like a premature abstraction? If so, I at least recommend splitting out the pow and r1cs_cost into separate functions for code readability & reuse in the future.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got the idea. It seems better to have a separate function that simply returns what is the best sequence, and then this function in fp/mod should just execute it.

I agree and will make the change. Let me think about where to put this function since it is naturally r1cs...

(Note: later it might also be used for group elements)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me try to split it and will keep you posted.


if standard_cost <= naf_cost {
use_naf = false;
}

standard_be = Some(standard_be_bits);
}

if !use_naf {
// use simple square-and-multiple
let mut res = Self::one();
for i in standard_be.unwrap() {
res.square_in_place()?;
if i {
res *= self;
}
}
Ok(res)
} else {
// use NAF

// first compute `inverse_or_any`
// if `self` != 0, it implies that `self` * `inverse_or_any` = 1
// if `self` == 0, `inverse_or_any` can be any value
let self_inverse_or_any = {
let inverse_or_any = Self::new_witness(self.cs().clone(), || {
Ok(self.value()?.inverse().unwrap_or_else(F::zero))
})?;

// self * self = tmp
let tmp = self.square()?;

// tmp * inverse_or_any = self
tmp.mul_equals(&inverse_or_any, &self)?;

inverse_or_any
};

// the initial `res` = 1 if `self` != 0, or `res` = 0 if `self` == 0
let mut res = self * &self_inverse_or_any;

let mut found_non_zero = false;
for &value in naf_be.iter().rev() {
if found_non_zero {
res = res.square()?;
}

if value != 0 {
found_non_zero = true;
if value > 0 {
res *= self;
} else {
res *= &self_inverse_or_any;
}
}
}

Ok(res)
}
}
}

impl_ops!(
Expand Down Expand Up @@ -1091,3 +1213,111 @@ mod test {
assert_eq!(sum.value().unwrap(), sum_expected);
}
}

#[cfg(test)]
mod test_pow_by_constant {
use crate::alloc::AllocVar;
use crate::fields::fp::FpVar;
use crate::fields::FieldVar;
use crate::R1CSVar;
use ark_ff::Field;
use ark_relations::r1cs::ConstraintSystem;
use ark_std::{One, UniformRand, Zero};
use ark_test_curves::bls12_381::Fr;

#[test]
fn test_rand() {
let mut rng = ark_std::test_rng();
let cs = ConstraintSystem::new_ref();

let mut rand_base = Fr::rand(&mut rng);
// ensure that rand_base is not zero.
if rand_base == Fr::zero() {
rand_base = rand_base + &Fr::one();
}

let rand_exp = [
u64::rand(&mut rng),
u64::rand(&mut rng),
u64::rand(&mut rng),
u64::rand(&mut rng),
];

{
let rand_base_g = FpVar::<Fr>::new_witness(cs.clone(), || Ok(rand_base)).unwrap();
let res_expected = rand_base.pow(rand_exp);
let res = rand_base_g
.pow_by_constant(&rand_exp)
.unwrap()
.value()
.unwrap();
assert_eq!(res, res_expected);
}

{
let rand_base_g = FpVar::<Fr>::new_constant(cs.clone(), rand_base).unwrap();
let res_expected = rand_base.pow(rand_exp);
let res = rand_base_g
.pow_by_constant(&rand_exp)
.unwrap()
.value()
.unwrap();
assert_eq!(res, res_expected);
}

assert!(cs.is_satisfied().unwrap());
}

#[test]
fn test_zero_base() {
let cs = ConstraintSystem::new_ref();
let exp = [1u64, 2u64, 3u64, 4u64];

{
let base_g = FpVar::<Fr>::new_witness(cs.clone(), || Ok(Fr::zero())).unwrap();
let res_expected = Fr::zero();
let res = base_g.pow_by_constant(exp).unwrap().value().unwrap();
assert_eq!(res, res_expected);
}

{
let base_g = FpVar::<Fr>::new_constant(cs.clone(), Fr::zero()).unwrap();
let res_expected = Fr::zero();
let res = base_g.pow_by_constant(exp).unwrap().value().unwrap();
assert_eq!(res, res_expected);
}

assert!(cs.is_satisfied().unwrap());
}

#[test]
fn test_zero_exp() {
let mut rng = ark_std::test_rng();
let cs = ConstraintSystem::new_ref();
let exp = [0u64, 0u64, 0u64, 0u64];

let mut rand_base = Fr::rand(&mut rng);

// ensure that rand_base is not zero.
if rand_base == Fr::zero() {
rand_base = rand_base + &Fr::one();
}

{
let rand_base_g =
FpVar::<Fr>::new_witness(cs.clone(), || Ok(rand_base.clone())).unwrap();
let res_expected = Fr::one();
let res = rand_base_g.pow_by_constant(&exp).unwrap().value().unwrap();
assert_eq!(res, res_expected);
}

{
let rand_base_g = FpVar::<Fr>::new_constant(cs.clone(), rand_base).unwrap();
let res_expected = Fr::one();
let res = rand_base_g.pow_by_constant(&exp).unwrap().value().unwrap();
assert_eq!(res, res_expected);
}

assert!(cs.is_satisfied().unwrap());
}
}