Skip to content

Commit

Permalink
fix: Ensure that input doesn't exceed the modulus, this time for real (
Browse files Browse the repository at this point in the history
…#37)

Apparently, `from_random_bytes` doesn't do it reliably. An another
alternative we considered from ark-ff is `F::from_bigint`, but it
can panic...

Therefore, it's just better if we convert an array to `BigUint`
(and then to `F::BigInt`) and do the modulus check ourselves.

Fixes: #36
  • Loading branch information
vadorovsky authored Oct 31, 2023
1 parent a06ab4f commit 9746e79
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 55 deletions.
1 change: 1 addition & 0 deletions light-poseidon/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ edition = "2021"
[dependencies]
ark-bn254 = "0.4.0"
ark-ff = "0.4.0"
num-bigint = "0.4.4"
thiserror = "1.0"

[dev-dependencies]
Expand Down
103 changes: 55 additions & 48 deletions light-poseidon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,16 @@ pub enum PoseidonError {
len: usize,
modulus_bytes_len: usize,
},
#[error("Failed to convert bytes {bytes:?} into a prime field element")]
BytesToPrimeFieldElement { bytes: Vec<u8> },
#[error("Input is larger than the modulus of the prime field.")]
InputLargerThanModulus,
#[error("Failed to convert a vector of bytes into an array.")]
VecToArray,
#[error("Failed to convert the number of inputs from u64 to u8.")]
U64Tou8,
#[error("Failed to convert bytes to BigInt")]
BytesToBigInt,
#[error("Invalid width: {width}. Choose a width between 2 and 16 for 1 to 15 inputs.")]
InvalidWidthCircom { width: usize, max_limit: usize },
}
Expand Down Expand Up @@ -413,49 +417,32 @@ impl<F: PrimeField> PoseidonHasher<F> for Poseidon<F> {
}
}

impl<F: PrimeField> PoseidonBytesHasher for Poseidon<F> {
fn hash_bytes_le(&mut self, inputs: &[&[u8]]) -> Result<[u8; HASH_LEN], PoseidonError> {
let inputs: Result<Vec<_>, _> = inputs
.iter()
.map(|input| validate_bytes_length::<F>(input))
.collect();
let inputs = inputs?;
let inputs: Result<Vec<_>, _> = inputs
.iter()
.map(|input| bytes_to_prime_field_element(input))
.collect();
let inputs = inputs?;
let hash = self.hash(&inputs)?;

hash.into_bigint()
.to_bytes_le()
.try_into()
.map_err(|_| PoseidonError::VecToArray)
}
macro_rules! impl_hash_bytes {
($fn_name:ident, $bytes_to_prime_field_element_fn:ident, $to_bytes_fn:ident) => {
fn $fn_name(&mut self, inputs: &[&[u8]]) -> Result<[u8; HASH_LEN], PoseidonError> {
let inputs: Result<Vec<_>, _> = inputs
.iter()
.map(|input| validate_bytes_length::<F>(input))
.collect();
let inputs = inputs?;
let inputs: Result<Vec<_>, _> = inputs
.iter()
.map(|input| $bytes_to_prime_field_element_fn(input))
.collect();
let inputs = inputs?;
let hash = self.hash(&inputs)?;

fn hash_bytes_be(&mut self, inputs: &[&[u8]]) -> Result<[u8; HASH_LEN], PoseidonError> {
let inputs: Result<Vec<_>, _> = inputs
.iter()
.map(|input| validate_bytes_length::<F>(input))
.collect();
let inputs = inputs?;
let inputs: Result<Vec<_>, _> = inputs
.iter()
.map(|input| {
let mut input = input.to_vec();
input.reverse();
input
})
.map(|input| bytes_to_prime_field_element(input.as_slice()))
.collect();
let inputs = inputs?;
let hash = self.hash(&inputs)?;
hash.into_bigint()
.$to_bytes_fn()
.try_into()
.map_err(|_| PoseidonError::VecToArray)
}
};
}

hash.into_bigint()
.to_bytes_be()
.try_into()
.map_err(|_| PoseidonError::VecToArray)
}
impl<F: PrimeField> PoseidonBytesHasher for Poseidon<F> {
impl_hash_bytes!(hash_bytes_le, bytes_to_prime_field_element_le, to_bytes_le);
impl_hash_bytes!(hash_bytes_be, bytes_to_prime_field_element_be, to_bytes_be);
}

/// Checks whether a slice of bytes is not empty or its length does not exceed
Expand Down Expand Up @@ -485,15 +472,35 @@ where
Ok(input)
}

/// Converts a slice of bytes into a prime field element, represented by the
/// [`ark_ff::PrimeField`](ark_ff::PrimeField)) trait.
fn bytes_to_prime_field_element<F>(input: &[u8]) -> Result<F, PoseidonError>
where
F: PrimeField,
{
F::from_random_bytes(input).ok_or(PoseidonError::InputLargerThanModulus)
macro_rules! impl_bytes_to_prime_field_element {
($name:ident, $from_bytes_method:ident, $endianess:expr) => {
#[doc = "Converts a slice of "]
#[doc = $endianess]
#[doc = "-endian bytes into a prime field element, \
represented by the [`ark_ff::PrimeField`](ark_ff::PrimeField) trait."]
fn $name<F>(input: &[u8]) -> Result<F, PoseidonError>
where
F: PrimeField,
{
let element = num_bigint::BigUint::$from_bytes_method(input);
let element = F::BigInt::try_from(element).map_err(|_| PoseidonError::BytesToBigInt)?;

// In theory, `F::from_bigint` should also perform a check whether input is
// larger than modulus (and return `None` if it is), but it's not reliable...
// To be sure, we check it ourselves.
if element >= F::MODULUS {
return Err(PoseidonError::InputLargerThanModulus);
}
let element = F::from_bigint(element).ok_or(PoseidonError::InputLargerThanModulus)?;

Ok(element)
}
};
}

impl_bytes_to_prime_field_element!(bytes_to_prime_field_element_le, from_bytes_le, "little");
impl_bytes_to_prime_field_element!(bytes_to_prime_field_element_be, from_bytes_be, "big");

impl<F: PrimeField> Poseidon<F> {
pub fn new_circom(nr_inputs: usize) -> Result<Poseidon<Fr>, PoseidonError> {
Self::with_domain_tag_circom(nr_inputs, Fr::zero())
Expand Down
48 changes: 41 additions & 7 deletions light-poseidon/tests/bn254_fq_x5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,13 @@ test_invalid_input_length!(
hash_bytes_le
);

macro_rules! test_fuzz_input_gt_field_size {
macro_rules! test_fuzz_input_gte_field_size {
($name:ident, $method:ident, $to_bytes_method:ident) => {
#[test]
fn $name() {
let mut greater_than_field_size = Fr::MODULUS;
let mut rng = rand::thread_rng();
let random_number = rng.gen_range(1u64..1_000_000u64);
let random_number = rng.gen_range(0u64..1_000_000u64);
greater_than_field_size.add_with_carry(&BigInteger256::from(random_number));
let greater_than_field_size = greater_than_field_size.$to_bytes_method();

Expand All @@ -251,19 +251,19 @@ macro_rules! test_fuzz_input_gt_field_size {
};
}

test_fuzz_input_gt_field_size!(
test_fuzz_input_gte_field_size!(
test_fuzz_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size,
hash_bytes_be,
to_bytes_be
);

test_fuzz_input_gt_field_size!(
test_fuzz_input_gte_field_size!(
test_fuzz_poseidon_bn254_fq_hash_bytes_le_input_gt_field_size,
hash_bytes_le,
to_bytes_le
);

macro_rules! test_input_gt_field_size {
macro_rules! test_input_gte_field_size {
($name:ident, $method:ident, $greater_than_field_size:expr) => {
#[test]
fn $name() {
Expand All @@ -282,7 +282,25 @@ macro_rules! test_input_gt_field_size {
};
}

test_input_gt_field_size!(
test_input_gte_field_size!(
test_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size_our_check,
hash_bytes_be,
[
216, 137, 85, 159, 239, 194, 107, 138, 254, 68, 21, 16, 165, 41, 64, 148, 208, 198, 201,
59, 220, 102, 142, 81, 49, 251, 174, 183, 183, 182, 4, 32,
]
);

test_input_gte_field_size!(
test_poseidon_bn254_fq_hash_bytes_le_input_gt_field_size_our_check,
hash_bytes_le,
[
32, 4, 182, 183, 183, 174, 251, 49, 81, 142, 102, 220, 59, 201, 198, 208, 148, 64, 41, 165,
16, 21, 68, 254, 138, 107, 194, 239, 159, 85, 137, 216,
]
);

test_input_gte_field_size!(
test_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size,
hash_bytes_be,
[
Expand All @@ -291,7 +309,7 @@ test_input_gt_field_size!(
]
);

test_input_gt_field_size!(
test_input_gte_field_size!(
test_poseidon_bn254_fq_hash_bytes_le_input_gt_field_size,
hash_bytes_le,
[
Expand All @@ -300,6 +318,22 @@ test_input_gt_field_size!(
]
);

#[test]
fn test_input_eq_field_size_be() {
let mut hasher = Poseidon::<Fr>::new_circom(1).unwrap();
let input = Fr::MODULUS.to_bytes_be();
let hash = hasher.hash_bytes_be(&[&input]);
assert_eq!(hash, Err(PoseidonError::InputLargerThanModulus));
}

#[test]
fn test_input_eq_field_size_le() {
let mut hasher = Poseidon::<Fr>::new_circom(1).unwrap();
let input = Fr::MODULUS.to_bytes_le();
let hash = hasher.hash_bytes_le(&[&input]);
assert_eq!(hash, Err(PoseidonError::InputLargerThanModulus));
}

#[test]
fn test_endianness() {
let mut hasher = Poseidon::<Fr>::new_circom(2).unwrap();
Expand Down

0 comments on commit 9746e79

Please sign in to comment.