From 9746e79ce547c45ef137d6767e8931b4c233a68d Mon Sep 17 00:00:00 2001
From: vadorovsky <vadorovsky@protonmail.com>
Date: Tue, 31 Oct 2023 20:58:27 +0100
Subject: [PATCH] fix: Ensure that input doesn't exceed the modulus, this time
 for real (#37)

Apparently, `from_random_bytes` doesn't reliably reject inputs that are
not smaller than the modulus. Another alternative we considered from
ark-ff is `F::from_bigint`, but it can panic...

Therefore, it's just better if we convert an array to `BigUint`
(and then to `F::BigInt`) and do the modulus check ourselves.

Fixes: #36
---
 light-poseidon/Cargo.toml           |   1 +
 light-poseidon/src/lib.rs           | 103 +++++++++++++++-------------
 light-poseidon/tests/bn254_fq_x5.rs |  48 +++++++++++--
 3 files changed, 97 insertions(+), 55 deletions(-)

diff --git a/light-poseidon/Cargo.toml b/light-poseidon/Cargo.toml
index eb9fcfc..4ef4f21 100644
--- a/light-poseidon/Cargo.toml
+++ b/light-poseidon/Cargo.toml
@@ -12,6 +12,7 @@ edition = "2021"
 [dependencies]
 ark-bn254 = "0.4.0"
 ark-ff = "0.4.0"
+num-bigint = "0.4.4"
 thiserror = "1.0"
 
 [dev-dependencies]
diff --git a/light-poseidon/src/lib.rs b/light-poseidon/src/lib.rs
index 5219911..74a150a 100644
--- a/light-poseidon/src/lib.rs
+++ b/light-poseidon/src/lib.rs
@@ -156,12 +156,16 @@ pub enum PoseidonError {
         len: usize,
         modulus_bytes_len: usize,
     },
+    #[error("Failed to convert bytes {bytes:?} into a prime field element")]
+    BytesToPrimeFieldElement { bytes: Vec<u8> },
     #[error("Input is larger than the modulus of the prime field.")]
     InputLargerThanModulus,
     #[error("Failed to convert a vector of bytes into an array.")]
     VecToArray,
     #[error("Failed to convert the number of inputs from u64 to u8.")]
     U64Tou8,
+    #[error("Failed to convert bytes to BigInt")]
+    BytesToBigInt,
     #[error("Invalid width: {width}. Choose a width between 2 and 16 for 1 to 15 inputs.")]
     InvalidWidthCircom { width: usize, max_limit: usize },
 }
@@ -413,49 +417,32 @@ impl<F: PrimeField> PoseidonHasher<F> for Poseidon<F> {
     }
 }
 
-impl<F: PrimeField> PoseidonBytesHasher for Poseidon<F> {
-    fn hash_bytes_le(&mut self, inputs: &[&[u8]]) -> Result<[u8; HASH_LEN], PoseidonError> {
-        let inputs: Result<Vec<_>, _> = inputs
-            .iter()
-            .map(|input| validate_bytes_length::<F>(input))
-            .collect();
-        let inputs = inputs?;
-        let inputs: Result<Vec<_>, _> = inputs
-            .iter()
-            .map(|input| bytes_to_prime_field_element(input))
-            .collect();
-        let inputs = inputs?;
-        let hash = self.hash(&inputs)?;
-
-        hash.into_bigint()
-            .to_bytes_le()
-            .try_into()
-            .map_err(|_| PoseidonError::VecToArray)
-    }
+macro_rules! impl_hash_bytes {
+    ($fn_name:ident, $bytes_to_prime_field_element_fn:ident, $to_bytes_fn:ident) => {
+        fn $fn_name(&mut self, inputs: &[&[u8]]) -> Result<[u8; HASH_LEN], PoseidonError> {
+            let inputs: Result<Vec<_>, _> = inputs
+                .iter()
+                .map(|input| validate_bytes_length::<F>(input))
+                .collect();
+            let inputs = inputs?;
+            let inputs: Result<Vec<_>, _> = inputs
+                .iter()
+                .map(|input| $bytes_to_prime_field_element_fn(input))
+                .collect();
+            let inputs = inputs?;
+            let hash = self.hash(&inputs)?;
 
-    fn hash_bytes_be(&mut self, inputs: &[&[u8]]) -> Result<[u8; HASH_LEN], PoseidonError> {
-        let inputs: Result<Vec<_>, _> = inputs
-            .iter()
-            .map(|input| validate_bytes_length::<F>(input))
-            .collect();
-        let inputs = inputs?;
-        let inputs: Result<Vec<_>, _> = inputs
-            .iter()
-            .map(|input| {
-                let mut input = input.to_vec();
-                input.reverse();
-                input
-            })
-            .map(|input| bytes_to_prime_field_element(input.as_slice()))
-            .collect();
-        let inputs = inputs?;
-        let hash = self.hash(&inputs)?;
+            hash.into_bigint()
+                .$to_bytes_fn()
+                .try_into()
+                .map_err(|_| PoseidonError::VecToArray)
+        }
+    };
+}
 
-        hash.into_bigint()
-            .to_bytes_be()
-            .try_into()
-            .map_err(|_| PoseidonError::VecToArray)
-    }
+impl<F: PrimeField> PoseidonBytesHasher for Poseidon<F> {
+    impl_hash_bytes!(hash_bytes_le, bytes_to_prime_field_element_le, to_bytes_le);
+    impl_hash_bytes!(hash_bytes_be, bytes_to_prime_field_element_be, to_bytes_be);
 }
 
 /// Checks whether a slice of bytes is not empty or its length does not exceed
@@ -485,15 +472,35 @@ where
     Ok(input)
 }
 
-/// Converts a slice of bytes into a prime field element, represented by the
-/// [`ark_ff::PrimeField`](ark_ff::PrimeField)) trait.
-fn bytes_to_prime_field_element<F>(input: &[u8]) -> Result<F, PoseidonError>
-where
-    F: PrimeField,
-{
-    F::from_random_bytes(input).ok_or(PoseidonError::InputLargerThanModulus)
+macro_rules! impl_bytes_to_prime_field_element {
+    ($name:ident, $from_bytes_method:ident, $endianess:expr) => {
+        #[doc = "Converts a slice of "]
+        #[doc = $endianess]
+        #[doc = "-endian bytes into a prime field element, \
+                 represented by the [`ark_ff::PrimeField`](ark_ff::PrimeField) trait."]
+        fn $name<F>(input: &[u8]) -> Result<F, PoseidonError>
+        where
+            F: PrimeField,
+        {
+            let element = num_bigint::BigUint::$from_bytes_method(input);
+            let element = F::BigInt::try_from(element).map_err(|_| PoseidonError::BytesToBigInt)?;
+
+            // In theory, `F::from_bigint` should also check whether the input is greater than
+            // or equal to the modulus (and return `None` if it is), but it's not reliable...
+            // To be sure, we check it ourselves.
+            if element >= F::MODULUS {
+                return Err(PoseidonError::InputLargerThanModulus);
+            }
+            let element = F::from_bigint(element).ok_or(PoseidonError::InputLargerThanModulus)?;
+
+            Ok(element)
+        }
+    };
 }
 
+impl_bytes_to_prime_field_element!(bytes_to_prime_field_element_le, from_bytes_le, "little");
+impl_bytes_to_prime_field_element!(bytes_to_prime_field_element_be, from_bytes_be, "big");
+
 impl<F: PrimeField> Poseidon<F> {
     pub fn new_circom(nr_inputs: usize) -> Result<Poseidon<Fr>, PoseidonError> {
         Self::with_domain_tag_circom(nr_inputs, Fr::zero())
diff --git a/light-poseidon/tests/bn254_fq_x5.rs b/light-poseidon/tests/bn254_fq_x5.rs
index 14c79eb..f8b791d 100644
--- a/light-poseidon/tests/bn254_fq_x5.rs
+++ b/light-poseidon/tests/bn254_fq_x5.rs
@@ -224,13 +224,13 @@ test_invalid_input_length!(
     hash_bytes_le
 );
 
-macro_rules! test_fuzz_input_gt_field_size {
+macro_rules! test_fuzz_input_gte_field_size {
     ($name:ident, $method:ident, $to_bytes_method:ident) => {
         #[test]
         fn $name() {
             let mut greater_than_field_size = Fr::MODULUS;
             let mut rng = rand::thread_rng();
-            let random_number = rng.gen_range(1u64..1_000_000u64);
+            let random_number = rng.gen_range(0u64..1_000_000u64);
             greater_than_field_size.add_with_carry(&BigInteger256::from(random_number));
             let greater_than_field_size = greater_than_field_size.$to_bytes_method();
 
@@ -251,19 +251,19 @@ macro_rules! test_fuzz_input_gt_field_size {
     };
 }
 
-test_fuzz_input_gt_field_size!(
+test_fuzz_input_gte_field_size!(
     test_fuzz_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size,
     hash_bytes_be,
     to_bytes_be
 );
 
-test_fuzz_input_gt_field_size!(
+test_fuzz_input_gte_field_size!(
     test_fuzz_poseidon_bn254_fq_hash_bytes_le_input_gt_field_size,
     hash_bytes_le,
     to_bytes_le
 );
 
-macro_rules! test_input_gt_field_size {
+macro_rules! test_input_gte_field_size {
     ($name:ident, $method:ident, $greater_than_field_size:expr) => {
         #[test]
         fn $name() {
@@ -282,7 +282,25 @@ macro_rules! test_input_gt_field_size {
     };
 }
 
-test_input_gt_field_size!(
+test_input_gte_field_size!(
+    test_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size_our_check,
+    hash_bytes_be,
+    [
+        216, 137, 85, 159, 239, 194, 107, 138, 254, 68, 21, 16, 165, 41, 64, 148, 208, 198, 201,
+        59, 220, 102, 142, 81, 49, 251, 174, 183, 183, 182, 4, 32,
+    ]
+);
+
+test_input_gte_field_size!(
+    test_poseidon_bn254_fq_hash_bytes_le_input_gt_field_size_our_check,
+    hash_bytes_le,
+    [
+        32, 4, 182, 183, 183, 174, 251, 49, 81, 142, 102, 220, 59, 201, 198, 208, 148, 64, 41, 165,
+        16, 21, 68, 254, 138, 107, 194, 239, 159, 85, 137, 216,
+    ]
+);
+
+test_input_gte_field_size!(
     test_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size,
     hash_bytes_be,
     [
@@ -291,7 +309,7 @@ test_input_gt_field_size!(
     ]
 );
 
-test_input_gt_field_size!(
+test_input_gte_field_size!(
     test_poseidon_bn254_fq_hash_bytes_le_input_gt_field_size,
     hash_bytes_le,
     [
@@ -300,6 +318,22 @@ test_input_gt_field_size!(
     ]
 );
 
+#[test]
+fn test_input_eq_field_size_be() {
+    let mut hasher = Poseidon::<Fr>::new_circom(1).unwrap();
+    let input = Fr::MODULUS.to_bytes_be();
+    let hash = hasher.hash_bytes_be(&[&input]);
+    assert_eq!(hash, Err(PoseidonError::InputLargerThanModulus));
+}
+
+#[test]
+fn test_input_eq_field_size_le() {
+    let mut hasher = Poseidon::<Fr>::new_circom(1).unwrap();
+    let input = Fr::MODULUS.to_bytes_le();
+    let hash = hasher.hash_bytes_le(&[&input]);
+    assert_eq!(hash, Err(PoseidonError::InputLargerThanModulus));
+}
+
 #[test]
 fn test_endianness() {
     let mut hasher = Poseidon::<Fr>::new_circom(2).unwrap();