Skip to content

Commit

Permalink
Compute poseidon with par_iter
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeee committed Sep 16, 2024
1 parent 23e0787 commit 3f693c7
Showing 1 changed file with 107 additions and 105 deletions.
212 changes: 107 additions & 105 deletions poseidon/src/permutation.rs
Original file line number Diff line number Diff line change
@@ -1,105 +1,107 @@
//! The permutation module contains the function implementing the permutation used in Poseidon

use crate::constants::SpongeConstants;
use crate::poseidon::{sbox, ArithmeticSpongeParams};
use ark_ff::Field;

fn apply_mds_matrix<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &[F],
) -> Vec<F> {
if SC::PERM_FULL_MDS {
params
.mds
.iter()
.map(|m| {
state
.iter()
.zip(m.iter())
.fold(F::zero(), |x, (s, &m)| m * s + x)
})
.collect()
} else {
vec![
state[0] + state[2],
state[0] + state[1],
state[1] + state[2],
]
}
}

pub fn full_round<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut Vec<F>,
r: usize,
) {
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
*state = apply_mds_matrix::<F, SC>(params, state);
for (i, x) in params.round_constants[r].iter().enumerate() {
state[i].add_assign(x);
}
}

pub fn half_rounds<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut [F],
) {
for r in 0..SC::PERM_HALF_ROUNDS_FULL {
for (i, x) in params.round_constants[r].iter().enumerate() {
state[i].add_assign(x);
}
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
apply_mds_matrix::<F, SC>(params, state);
}

for r in 0..SC::PERM_ROUNDS_PARTIAL {
for (i, x) in params.round_constants[SC::PERM_HALF_ROUNDS_FULL + r]
.iter()
.enumerate()
{
state[i].add_assign(x);
}
state[0] = sbox::<F, SC>(state[0]);
apply_mds_matrix::<F, SC>(params, state);
}

for r in 0..SC::PERM_HALF_ROUNDS_FULL {
for (i, x) in params.round_constants
[SC::PERM_HALF_ROUNDS_FULL + SC::PERM_ROUNDS_PARTIAL + r]
.iter()
.enumerate()
{
state[i].add_assign(x);
}
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
apply_mds_matrix::<F, SC>(params, state);
}
}

pub fn poseidon_block_cipher<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut Vec<F>,
) {
if SC::PERM_HALF_ROUNDS_FULL == 0 {
if SC::PERM_INITIAL_ARK {
for (i, x) in params.round_constants[0].iter().enumerate() {
state[i].add_assign(x);
}
for r in 0..SC::PERM_ROUNDS_FULL {
full_round::<F, SC>(params, state, r + 1);
}
} else {
for r in 0..SC::PERM_ROUNDS_FULL {
full_round::<F, SC>(params, state, r);
}
}
} else {
half_rounds::<F, SC>(params, state);
}
}
//! The permutation module contains the function implementing the permutation used in Poseidon

use crate::constants::SpongeConstants;
use crate::poseidon::{sbox, ArithmeticSpongeParams};
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
use ark_ff::Field;

fn apply_mds_matrix<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &[F],
) -> Vec<F> {
if SC::PERM_FULL_MDS {
params
.mds
.par_iter()
.map(|m| {
state
.par_iter()
.zip(m.par_iter())
.map(|(s, &m)| m * s)
.sum::<F>()
})
.collect()
} else {
vec![
state[0] + state[2],
state[0] + state[1],
state[1] + state[2],
]
}
}

pub fn full_round<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut Vec<F>,
r: usize,
) {
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
*state = apply_mds_matrix::<F, SC>(params, state);
for (i, x) in params.round_constants[r].iter().enumerate() {
state[i].add_assign(x);
}
}

pub fn half_rounds<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut [F],
) {
for r in 0..SC::PERM_HALF_ROUNDS_FULL {
for (i, x) in params.round_constants[r].iter().enumerate() {
state[i].add_assign(x);
}
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
apply_mds_matrix::<F, SC>(params, state);
}

for r in 0..SC::PERM_ROUNDS_PARTIAL {
for (i, x) in params.round_constants[SC::PERM_HALF_ROUNDS_FULL + r]
.iter()
.enumerate()
{
state[i].add_assign(x);
}
state[0] = sbox::<F, SC>(state[0]);
apply_mds_matrix::<F, SC>(params, state);
}

for r in 0..SC::PERM_HALF_ROUNDS_FULL {
for (i, x) in params.round_constants
[SC::PERM_HALF_ROUNDS_FULL + SC::PERM_ROUNDS_PARTIAL + r]
.iter()
.enumerate()
{
state[i].add_assign(x);
}
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
apply_mds_matrix::<F, SC>(params, state);
}
}

pub fn poseidon_block_cipher<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut Vec<F>,
) {
if SC::PERM_HALF_ROUNDS_FULL == 0 {
if SC::PERM_INITIAL_ARK {
for (i, x) in params.round_constants[0].iter().enumerate() {
state[i].add_assign(x);
}
for r in 0..SC::PERM_ROUNDS_FULL {
full_round::<F, SC>(params, state, r + 1);
}
} else {
for r in 0..SC::PERM_ROUNDS_FULL {
full_round::<F, SC>(params, state, r);
}
}
} else {
half_rounds::<F, SC>(params, state);
}
}

0 comments on commit 3f693c7

Please sign in to comment.