Skip to content

Commit

Permalink
fixup! Manually unroll MDS calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeee committed Sep 16, 2024
1 parent 1c43ab5 commit fd5d15d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 31 deletions.
69 changes: 38 additions & 31 deletions poseidon/src/permutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,16 @@ fn apply_mds_matrix<F: Field, SC: SpongeConstants>(
state: &[F],
) -> Vec<F> {
if SC::PERM_FULL_MDS {
if params.mds.len() == 3 {
vec![
// Manually unrolled loops for multiplying each row by the vector
params.mds[0][0] * state[0]
+ params.mds[0][1] * state[1]
+ params.mds[0][2] * state[2],
params.mds[1][0] * state[0]
+ params.mds[1][1] * state[1]
+ params.mds[1][2] * state[2],
params.mds[2][0] * state[0]
+ params.mds[2][1] * state[1]
+ params.mds[2][2] * state[2],
]
} else {
params
.mds
.iter()
.map(|m| {
state
.iter()
.zip(m.iter())
.fold(F::zero(), |x, (s, &m)| m * s + x)
})
.collect()
}
params
.mds
.iter()
.map(|m| {
state
.iter()
.zip(m.iter())
.fold(F::zero(), |x, (s, &m)| m * s + x)
})
.collect()
} else {
vec![
state[0] + state[2],
Expand All @@ -43,17 +28,39 @@ fn apply_mds_matrix<F: Field, SC: SpongeConstants>(
}
}

#[inline]
pub fn full_round<F: Field, SC: SpongeConstants>(
params: &ArithmeticSpongeParams<F>,
state: &mut Vec<F>,
r: usize,
) {
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
*state = apply_mds_matrix::<F, SC>(params, state);
for (i, x) in params.round_constants[r].iter().enumerate() {
state[i].add_assign(x);
if SC::PERM_FULL_MDS && state.len() == 3 {
state[0] = sbox::<F, SC>(state[0]);
state[1] = sbox::<F, SC>(state[1]);
state[2] = sbox::<F, SC>(state[2]);
*state = vec![
// Manually unrolled loops for multiplying each row by the vector
params.mds[0][0] * state[0]
+ params.mds[0][1] * state[1]
+ params.mds[0][2] * state[2]
+ params.round_constants[r][0],
params.mds[1][0] * state[0]
+ params.mds[1][1] * state[1]
+ params.mds[1][2] * state[2]
+ params.round_constants[r][1],
params.mds[2][0] * state[0]
+ params.mds[2][1] * state[1]
+ params.mds[2][2] * state[2]
+ params.round_constants[r][2],
];
} else {
for state_i in state.iter_mut() {
*state_i = sbox::<F, SC>(*state_i);
}
*state = apply_mds_matrix::<F, SC>(params, state);
for (i, x) in params.round_constants[r].iter().enumerate() {
state[i].add_assign(x);
}
}
}

Expand Down
1 change: 1 addition & 0 deletions poseidon/src/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub trait Sponge<Input: Field, Digest> {
fn reset(&mut self);
}

#[inline]
pub fn sbox<F: Field, SC: SpongeConstants>(mut x: F) -> F {
if SC::PERM_SBOX == 7 {
// This is much faster than using the generic `pow`. Hard-code to get the ~50% speed-up
Expand Down

0 comments on commit fd5d15d

Please sign in to comment.