From 3f693c7d110c2700961a0073d29062b9a12d64ea Mon Sep 17 00:00:00 2001 From: georgeee Date: Mon, 16 Sep 2024 05:01:09 +0000 Subject: [PATCH] Compute poseidon with par_iter --- poseidon/src/permutation.rs | 212 ++++++++++++++++++------------------ 1 file changed, 107 insertions(+), 105 deletions(-) diff --git a/poseidon/src/permutation.rs b/poseidon/src/permutation.rs index 1d1595666e..c30543f412 100644 --- a/poseidon/src/permutation.rs +++ b/poseidon/src/permutation.rs @@ -1,105 +1,107 @@ -//! The permutation module contains the function implementing the permutation used in Poseidon - -use crate::constants::SpongeConstants; -use crate::poseidon::{sbox, ArithmeticSpongeParams}; -use ark_ff::Field; - -fn apply_mds_matrix( - params: &ArithmeticSpongeParams, - state: &[F], -) -> Vec { - if SC::PERM_FULL_MDS { - params - .mds - .iter() - .map(|m| { - state - .iter() - .zip(m.iter()) - .fold(F::zero(), |x, (s, &m)| m * s + x) - }) - .collect() - } else { - vec![ - state[0] + state[2], - state[0] + state[1], - state[1] + state[2], - ] - } -} - -pub fn full_round( - params: &ArithmeticSpongeParams, - state: &mut Vec, - r: usize, -) { - for state_i in state.iter_mut() { - *state_i = sbox::(*state_i); - } - *state = apply_mds_matrix::(params, state); - for (i, x) in params.round_constants[r].iter().enumerate() { - state[i].add_assign(x); - } -} - -pub fn half_rounds( - params: &ArithmeticSpongeParams, - state: &mut [F], -) { - for r in 0..SC::PERM_HALF_ROUNDS_FULL { - for (i, x) in params.round_constants[r].iter().enumerate() { - state[i].add_assign(x); - } - for state_i in state.iter_mut() { - *state_i = sbox::(*state_i); - } - apply_mds_matrix::(params, state); - } - - for r in 0..SC::PERM_ROUNDS_PARTIAL { - for (i, x) in params.round_constants[SC::PERM_HALF_ROUNDS_FULL + r] - .iter() - .enumerate() - { - state[i].add_assign(x); - } - state[0] = sbox::(state[0]); - apply_mds_matrix::(params, state); - } - - for r in 0..SC::PERM_HALF_ROUNDS_FULL { - for (i, x) in params.round_constants - [SC::PERM_HALF_ROUNDS_FULL + SC::PERM_ROUNDS_PARTIAL + r] - .iter() - .enumerate() - { - state[i].add_assign(x); - } - for state_i in state.iter_mut() { - *state_i = sbox::(*state_i); - } - apply_mds_matrix::(params, state); - } -} - -pub fn poseidon_block_cipher( - params: &ArithmeticSpongeParams, - state: &mut Vec, -) { - if SC::PERM_HALF_ROUNDS_FULL == 0 { - if SC::PERM_INITIAL_ARK { - for (i, x) in params.round_constants[0].iter().enumerate() { - state[i].add_assign(x); - } - for r in 0..SC::PERM_ROUNDS_FULL { - full_round::(params, state, r + 1); - } - } else { - for r in 0..SC::PERM_ROUNDS_FULL { - full_round::(params, state, r); - } - } - } else { - half_rounds::(params, state); - } -} +//! The permutation module contains the function implementing the permutation used in Poseidon + +use crate::constants::SpongeConstants; +use crate::poseidon::{sbox, ArithmeticSpongeParams}; +use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use ark_ff::Field; + +fn apply_mds_matrix( + params: &ArithmeticSpongeParams, + state: &[F], +) -> Vec { + if SC::PERM_FULL_MDS { + params + .mds + .par_iter() + .map(|m| { + state + .par_iter() + .zip(m.par_iter()) + .map(|(s, &m)| m * s) + .sum::() + }) + .collect() + } else { + vec![ + state[0] + state[2], + state[0] + state[1], + state[1] + state[2], + ] + } +} + +pub fn full_round( + params: &ArithmeticSpongeParams, + state: &mut Vec, + r: usize, +) { + for state_i in state.iter_mut() { + *state_i = sbox::(*state_i); + } + *state = apply_mds_matrix::(params, state); + for (i, x) in params.round_constants[r].iter().enumerate() { + state[i].add_assign(x); + } +} + +pub fn half_rounds( + params: &ArithmeticSpongeParams, + state: &mut [F], +) { + for r in 0..SC::PERM_HALF_ROUNDS_FULL { + for (i, x) in params.round_constants[r].iter().enumerate() { + state[i].add_assign(x); + } + for state_i in state.iter_mut() { + *state_i = sbox::(*state_i); + } + apply_mds_matrix::(params, state); + } + + for r in 0..SC::PERM_ROUNDS_PARTIAL { + for (i, x) in params.round_constants[SC::PERM_HALF_ROUNDS_FULL + r] + .iter() + .enumerate() + { + state[i].add_assign(x); + } + state[0] = sbox::(state[0]); + apply_mds_matrix::(params, state); + } + + for r in 0..SC::PERM_HALF_ROUNDS_FULL { + for (i, x) in params.round_constants + [SC::PERM_HALF_ROUNDS_FULL + SC::PERM_ROUNDS_PARTIAL + r] + .iter() + .enumerate() + { + state[i].add_assign(x); + } + for state_i in state.iter_mut() { + *state_i = sbox::(*state_i); + } + apply_mds_matrix::(params, state); + } +} + +pub fn poseidon_block_cipher( + params: &ArithmeticSpongeParams, + state: &mut Vec, +) { + if SC::PERM_HALF_ROUNDS_FULL == 0 { + if SC::PERM_INITIAL_ARK { + for (i, x) in params.round_constants[0].iter().enumerate() { + state[i].add_assign(x); + } + for r in 0..SC::PERM_ROUNDS_FULL { + full_round::(params, state, r + 1); + } + } else { + for r in 0..SC::PERM_ROUNDS_FULL { + full_round::(params, state, r); + } + } + } else { + half_rounds::(params, state); + } +}