Skip to content

Commit

Permalink
Alignment fix
Browse files Browse the repository at this point in the history
  • Loading branch information
spapinistarkware committed Jun 27, 2024
1 parent 7d2397f commit 0129f29
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
12 changes: 9 additions & 3 deletions crates/prover/src/core/backend/simd/blake2s.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::iter::repeat;
use std::mem::transmute;
use std::simd::u32x16;

use bytemuck::cast_slice;
use itertools::Itertools;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
Expand Down Expand Up @@ -49,7 +50,7 @@ impl MerkleOps<Blake2sMerkleHasher> for SimdBackend {
prev_layer: Option<&Vec<Blake2sHash>>,
columns: &[&Col<Self, BaseField>],
) -> Vec<Blake2sHash> {
if log_size < N_LANES as u32 {
if log_size < LOG_N_LANES {
#[cfg(not(feature = "parallel"))]
let iter = 0..1 << log_size;

Expand All @@ -65,6 +66,7 @@ impl MerkleOps<Blake2sMerkleHasher> for SimdBackend {
})
.collect();
}
println!("B");

if let Some(prev_layer) = prev_layer {
assert_eq!(prev_layer.len(), 1 << (log_size + 1));
Expand All @@ -84,8 +86,12 @@ impl MerkleOps<Blake2sMerkleHasher> for SimdBackend {
let mut state: [u32x16; 8] = unsafe { std::mem::zeroed() };
// Hash prev_layer, if exists.
if let Some(prev_layer) = prev_layer {
let ptr = prev_layer[(i << 5)..((i + 1) << 5)].as_ptr() as *const u32x16;
let msgs: [u32x16; 16] = array::from_fn(|j| unsafe { *ptr.add(j) });
let prev_chunk_u32s = cast_slice::<_, u32>(&prev_layer[(i << 5)..((i + 1) << 5)]);
// Note: prev_layer might be unaligned.
let msgs: [u32x16; 16] = array::from_fn(|j| {
u32x16::from_array(std::array::from_fn(|k| prev_chunk_u32s[16 * j + k]))
});
// let msgs: [u32x16; 16] = array::from_fn(|j| unsafe { *ptr.add(j) });
state = compress16(state, transpose_msgs(msgs), zeros, zeros, zeros, zeros);
}

Expand Down
5 changes: 3 additions & 2 deletions crates/prover/src/core/vcs/blake2_hash.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::fmt;

use blake2::{Blake2s256, Digest};
use bytemuck::{Pod, Zeroable};

// Wrapper for the blake2s hash type.
#[repr(align(32))]
#[derive(Clone, Copy, PartialEq, Default, Eq)]
#[repr(C, align(32))]
#[derive(Clone, Copy, PartialEq, Default, Eq, Pod, Zeroable)]
pub struct Blake2sHash([u8; 32]);

impl From<Blake2sHash> for Vec<u8> {
Expand Down

0 comments on commit 0129f29

Please sign in to comment.