diff --git a/light-poseidon/Cargo.toml b/light-poseidon/Cargo.toml index eb9fcfc..4ef4f21 100644 --- a/light-poseidon/Cargo.toml +++ b/light-poseidon/Cargo.toml @@ -12,6 +12,7 @@ edition = "2021" [dependencies] ark-bn254 = "0.4.0" ark-ff = "0.4.0" +num-bigint = "0.4.4" thiserror = "1.0" [dev-dependencies] diff --git a/light-poseidon/src/lib.rs b/light-poseidon/src/lib.rs index 5219911..74a150a 100644 --- a/light-poseidon/src/lib.rs +++ b/light-poseidon/src/lib.rs @@ -156,12 +156,16 @@ pub enum PoseidonError { len: usize, modulus_bytes_len: usize, }, + #[error("Failed to convert bytes {bytes:?} into a prime field element")] + BytesToPrimeFieldElement { bytes: Vec }, #[error("Input is larger than the modulus of the prime field.")] InputLargerThanModulus, #[error("Failed to convert a vector of bytes into an array.")] VecToArray, #[error("Failed to convert the number of inputs from u64 to u8.")] U64Tou8, + #[error("Failed to convert bytes to BigInt")] + BytesToBigInt, #[error("Invalid width: {width}. Choose a width between 2 and 16 for 1 to 15 inputs.")] InvalidWidthCircom { width: usize, max_limit: usize }, } @@ -413,49 +417,32 @@ impl PoseidonHasher for Poseidon { } } -impl PoseidonBytesHasher for Poseidon { - fn hash_bytes_le(&mut self, inputs: &[&[u8]]) -> Result<[u8; HASH_LEN], PoseidonError> { - let inputs: Result, _> = inputs - .iter() - .map(|input| validate_bytes_length::(input)) - .collect(); - let inputs = inputs?; - let inputs: Result, _> = inputs - .iter() - .map(|input| bytes_to_prime_field_element(input)) - .collect(); - let inputs = inputs?; - let hash = self.hash(&inputs)?; - - hash.into_bigint() - .to_bytes_le() - .try_into() - .map_err(|_| PoseidonError::VecToArray) - } +macro_rules! impl_hash_bytes { + ($fn_name:ident, $bytes_to_prime_field_element_fn:ident, $to_bytes_fn:ident) => { + fn $fn_name(&mut self, inputs: &[&[u8]]) -> Result<[u8; HASH_LEN], PoseidonError> { + let inputs: Result, _> = inputs + .iter() + .map(|input| validate_bytes_length::(input)) + .collect(); + let inputs = inputs?; + let inputs: Result, _> = inputs + .iter() + .map(|input| $bytes_to_prime_field_element_fn(input)) + .collect(); + let inputs = inputs?; + let hash = self.hash(&inputs)?; - fn hash_bytes_be(&mut self, inputs: &[&[u8]]) -> Result<[u8; HASH_LEN], PoseidonError> { - let inputs: Result, _> = inputs - .iter() - .map(|input| validate_bytes_length::(input)) - .collect(); - let inputs = inputs?; - let inputs: Result, _> = inputs - .iter() - .map(|input| { - let mut input = input.to_vec(); - input.reverse(); - input - }) - .map(|input| bytes_to_prime_field_element(input.as_slice())) - .collect(); - let inputs = inputs?; - let hash = self.hash(&inputs)?; + hash.into_bigint() + .$to_bytes_fn() + .try_into() + .map_err(|_| PoseidonError::VecToArray) + } + }; +} - hash.into_bigint() - .to_bytes_be() - .try_into() - .map_err(|_| PoseidonError::VecToArray) - } +impl PoseidonBytesHasher for Poseidon { + impl_hash_bytes!(hash_bytes_le, bytes_to_prime_field_element_le, to_bytes_le); + impl_hash_bytes!(hash_bytes_be, bytes_to_prime_field_element_be, to_bytes_be); } /// Checks whether a slice of bytes is not empty or its length does not exceed @@ -485,15 +472,35 @@ where Ok(input) } -/// Converts a slice of bytes into a prime field element, represented by the -/// [`ark_ff::PrimeField`](ark_ff::PrimeField)) trait. -fn bytes_to_prime_field_element(input: &[u8]) -> Result -where - F: PrimeField, -{ - F::from_random_bytes(input).ok_or(PoseidonError::InputLargerThanModulus) +macro_rules! impl_bytes_to_prime_field_element { + ($name:ident, $from_bytes_method:ident, $endianess:expr) => { + #[doc = "Converts a slice of "] + #[doc = $endianess] + #[doc = "-endian bytes into a prime field element, \ + represented by the [`ark_ff::PrimeField`](ark_ff::PrimeField) trait."] + fn $name(input: &[u8]) -> Result + where + F: PrimeField, + { + let element = num_bigint::BigUint::$from_bytes_method(input); + let element = F::BigInt::try_from(element).map_err(|_| PoseidonError::BytesToBigInt)?; + + // In theory, `F::from_bigint` should also perform a check whether input is + // larger than modulus (and return `None` if it is), but it's not reliable... + // To be sure, we check it ourselves. + if element >= F::MODULUS { + return Err(PoseidonError::InputLargerThanModulus); + } + let element = F::from_bigint(element).ok_or(PoseidonError::InputLargerThanModulus)?; + + Ok(element) + } + }; } +impl_bytes_to_prime_field_element!(bytes_to_prime_field_element_le, from_bytes_le, "little"); +impl_bytes_to_prime_field_element!(bytes_to_prime_field_element_be, from_bytes_be, "big"); + impl Poseidon { pub fn new_circom(nr_inputs: usize) -> Result, PoseidonError> { Self::with_domain_tag_circom(nr_inputs, Fr::zero()) diff --git a/light-poseidon/tests/bn254_fq_x5.rs b/light-poseidon/tests/bn254_fq_x5.rs index 14c79eb..f8b791d 100644 --- a/light-poseidon/tests/bn254_fq_x5.rs +++ b/light-poseidon/tests/bn254_fq_x5.rs @@ -224,13 +224,13 @@ test_invalid_input_length!( hash_bytes_le ); -macro_rules! test_fuzz_input_gt_field_size { +macro_rules! test_fuzz_input_gte_field_size { ($name:ident, $method:ident, $to_bytes_method:ident) => { #[test] fn $name() { let mut greater_than_field_size = Fr::MODULUS; let mut rng = rand::thread_rng(); - let random_number = rng.gen_range(1u64..1_000_000u64); + let random_number = rng.gen_range(0u64..1_000_000u64); greater_than_field_size.add_with_carry(&BigInteger256::from(random_number)); let greater_than_field_size = greater_than_field_size.$to_bytes_method(); @@ -251,19 +251,19 @@ macro_rules! test_fuzz_input_gt_field_size { }; } -test_fuzz_input_gt_field_size!( +test_fuzz_input_gte_field_size!( test_fuzz_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size, hash_bytes_be, to_bytes_be ); -test_fuzz_input_gt_field_size!( +test_fuzz_input_gte_field_size!( test_fuzz_poseidon_bn254_fq_hash_bytes_le_input_gt_field_size, hash_bytes_le, to_bytes_le ); -macro_rules! test_input_gt_field_size { +macro_rules! test_input_gte_field_size { ($name:ident, $method:ident, $greater_than_field_size:expr) => { #[test] fn $name() { @@ -282,7 +282,25 @@ macro_rules! test_input_gt_field_size { }; } -test_input_gt_field_size!( +test_input_gte_field_size!( + test_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size_our_check, + hash_bytes_be, + [ + 216, 137, 85, 159, 239, 194, 107, 138, 254, 68, 21, 16, 165, 41, 64, 148, 208, 198, 201, + 59, 220, 102, 142, 81, 49, 251, 174, 183, 183, 182, 4, 32, + ] +); + +test_input_gte_field_size!( + test_poseidon_bn254_fq_hash_bytes_le_input_gt_field_size_our_check, + hash_bytes_le, + [ + 32, 4, 182, 183, 183, 174, 251, 49, 81, 142, 102, 220, 59, 201, 198, 208, 148, 64, 41, 165, + 16, 21, 68, 254, 138, 107, 194, 239, 159, 85, 137, 216, + ] +); + +test_input_gte_field_size!( test_poseidon_bn254_fq_hash_bytes_be_input_gt_field_size, hash_bytes_be, [ @@ -291,7 +309,7 @@ test_input_gt_field_size!( ] ); -test_input_gt_field_size!( +test_input_gte_field_size!( test_poseidon_bn254_fq_hash_bytes_le_input_gt_field_size, hash_bytes_le, [ @@ -300,6 +318,22 @@ test_input_gt_field_size!( ] ); +#[test] +fn test_input_eq_field_size_be() { + let mut hasher = Poseidon::::new_circom(1).unwrap(); + let input = Fr::MODULUS.to_bytes_be(); + let hash = hasher.hash_bytes_be(&[&input]); + assert_eq!(hash, Err(PoseidonError::InputLargerThanModulus)); +} + +#[test] +fn test_input_eq_field_size_le() { + let mut hasher = Poseidon::::new_circom(1).unwrap(); + let input = Fr::MODULUS.to_bytes_le(); + let hash = hasher.hash_bytes_le(&[&input]); + assert_eq!(hash, Err(PoseidonError::InputLargerThanModulus)); +} + #[test] fn test_endianness() { let mut hasher = Poseidon::::new_circom(2).unwrap();