Skip to content

Commit

Permalink
Manually unroll MDS calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeee committed Sep 16, 2024
1 parent 513d045 commit 1c43ab5
Showing 1 changed file with 120 additions and 105 deletions.
225 changes: 120 additions & 105 deletions poseidon/src/permutation.rs
Original file line number Diff line number Diff line change
@@ -1,105 +1,120 @@
//! The permutation module contains the function implementing the permutation used in Poseidon

use crate::constants::SpongeConstants;
use crate::poseidon::{sbox, ArithmeticSpongeParams};
use ark_ff::Field;

fn apply_mds_matrix<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &[F],
) -> Vec<F> {
if SC::PERM_FULL_MDS {
params
.mds
.iter()
.map(|m| {
state
.iter()
.zip(m.iter())
.fold(F::zero(), |x, (s, &m)| m * s + x)
})
.collect()
} else {
vec![
state[0] + state[2],
state[0] + state[1],
state[1] + state[2],
]
}
}

pub fn full_round<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut Vec<F>,
r: usize,
) {
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
*state = apply_mds_matrix::<F, SC>(params, state);
for (i, x) in params.round_constants[r].iter().enumerate() {
state[i].add_assign(x);
}
}

pub fn half_rounds<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut [F],
) {
for r in 0..SC::PERM_HALF_ROUNDS_FULL {
for (i, x) in params.round_constants[r].iter().enumerate() {
state[i].add_assign(x);
}
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
apply_mds_matrix::<F, SC>(params, state);
}

for r in 0..SC::PERM_ROUNDS_PARTIAL {
for (i, x) in params.round_constants[SC::PERM_HALF_ROUNDS_FULL + r]
.iter()
.enumerate()
{
state[i].add_assign(x);
}
state[0] = sbox::<F, SC>(state[0]);
apply_mds_matrix::<F, SC>(params, state);
}

for r in 0..SC::PERM_HALF_ROUNDS_FULL {
for (i, x) in params.round_constants
[SC::PERM_HALF_ROUNDS_FULL + SC::PERM_ROUNDS_PARTIAL + r]
.iter()
.enumerate()
{
state[i].add_assign(x);
}
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
apply_mds_matrix::<F, SC>(params, state);
}
}

pub fn poseidon_block_cipher<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut Vec<F>,
) {
if SC::PERM_HALF_ROUNDS_FULL == 0 {
if SC::PERM_INITIAL_ARK {
for (i, x) in params.round_constants[0].iter().enumerate() {
state[i].add_assign(x);
}
for r in 0..SC::PERM_ROUNDS_FULL {
full_round::<F, SC>(params, state, r + 1);
}
} else {
for r in 0..SC::PERM_ROUNDS_FULL {
full_round::<F, SC>(params, state, r);
}
}
} else {
half_rounds::<F, SC>(params, state);
}
}
//! The permutation module contains the function implementing the permutation used in Poseidon

use crate::constants::SpongeConstants;
use crate::poseidon::{sbox, ArithmeticSpongeParams};
use ark_ff::Field;

fn apply_mds_matrix<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &[F],
) -> Vec<F> {
if SC::PERM_FULL_MDS {
if params.mds.len() == 3 {
vec![
// Manually unrolled loops for multiplying each row by the vector
params.mds[0][0] * state[0]
+ params.mds[0][1] * state[1]
+ params.mds[0][2] * state[2],
params.mds[1][0] * state[0]
+ params.mds[1][1] * state[1]
+ params.mds[1][2] * state[2],
params.mds[2][0] * state[0]
+ params.mds[2][1] * state[1]
+ params.mds[2][2] * state[2],
]
} else {
params
.mds
.iter()
.map(|m| {
state
.iter()
.zip(m.iter())
.fold(F::zero(), |x, (s, &m)| m * s + x)
})
.collect()
}
} else {
vec![
state[0] + state[2],
state[0] + state[1],
state[1] + state[2],
]
}
}

pub fn full_round<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut Vec<F>,
r: usize,
) {
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
*state = apply_mds_matrix::<F, SC>(params, state);
for (i, x) in params.round_constants[r].iter().enumerate() {
state[i].add_assign(x);
}
}

pub fn half_rounds<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut [F],
) {
for r in 0..SC::PERM_HALF_ROUNDS_FULL {
for (i, x) in params.round_constants[r].iter().enumerate() {
state[i].add_assign(x);
}
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
apply_mds_matrix::<F, SC>(params, state);
}

for r in 0..SC::PERM_ROUNDS_PARTIAL {
for (i, x) in params.round_constants[SC::PERM_HALF_ROUNDS_FULL + r]
.iter()
.enumerate()
{
state[i].add_assign(x);
}
state[0] = sbox::<F, SC>(state[0]);
apply_mds_matrix::<F, SC>(params, state);
}

for r in 0..SC::PERM_HALF_ROUNDS_FULL {
for (i, x) in params.round_constants
[SC::PERM_HALF_ROUNDS_FULL + SC::PERM_ROUNDS_PARTIAL + r]
.iter()
.enumerate()
{
state[i].add_assign(x);
}
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
apply_mds_matrix::<F, SC>(params, state);
}
}

pub fn poseidon_block_cipher<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut Vec<F>,
) {
if SC::PERM_HALF_ROUNDS_FULL == 0 {
if SC::PERM_INITIAL_ARK {
for (i, x) in params.round_constants[0].iter().enumerate() {
state[i].add_assign(x);
}
for r in 0..SC::PERM_ROUNDS_FULL {
full_round::<F, SC>(params, state, r + 1);
}
} else {
for r in 0..SC::PERM_ROUNDS_FULL {
full_round::<F, SC>(params, state, r);
}
}
} else {
half_rounds::<F, SC>(params, state);
}
}

0 comments on commit 1c43ab5

Please sign in to comment.