diff --git a/crates/prover/src/core/backend/simd/blake2s.rs b/crates/prover/src/core/backend/simd/blake2s.rs index b39b1fdd2..3d21f6dd6 100644 --- a/crates/prover/src/core/backend/simd/blake2s.rs +++ b/crates/prover/src/core/backend/simd/blake2s.rs @@ -6,6 +6,7 @@ use std::iter::repeat; use std::mem::transmute; use std::simd::u32x16; +use bytemuck::cast_slice; use itertools::Itertools; #[cfg(feature = "parallel")] use rayon::prelude::*; @@ -49,7 +50,7 @@ impl MerkleOps for SimdBackend { prev_layer: Option<&Vec>, columns: &[&Col], ) -> Vec { - if log_size < N_LANES as u32 { + if log_size < LOG_N_LANES { #[cfg(not(feature = "parallel"))] let iter = 0..1 << log_size; @@ -65,6 +66,7 @@ impl MerkleOps for SimdBackend { }) .collect(); } + println!("B"); if let Some(prev_layer) = prev_layer { assert_eq!(prev_layer.len(), 1 << (log_size + 1)); @@ -84,8 +86,11 @@ impl MerkleOps for SimdBackend { let mut state: [u32x16; 8] = unsafe { std::mem::zeroed() }; // Hash prev_layer, if exists. if let Some(prev_layer) = prev_layer { - let ptr = prev_layer[(i << 5)..((i + 1) << 5)].as_ptr() as *const u32x16; - let msgs: [u32x16; 16] = array::from_fn(|j| unsafe { *ptr.add(j) }); + let prev_chunk_u32s = cast_slice::<_, u32>(&prev_layer[(i << 5)..((i + 1) << 5)]); + // Note: prev_layer might be unaligned. + let msgs: [u32x16; 16] = array::from_fn(|j| { + u32x16::from_array(std::array::from_fn(|k| prev_chunk_u32s[16 * j + k])) + }); state = compress16(state, transpose_msgs(msgs), zeros, zeros, zeros, zeros); } diff --git a/crates/prover/src/core/vcs/blake2_hash.rs b/crates/prover/src/core/vcs/blake2_hash.rs index c25580a91..42af0ea53 100644 --- a/crates/prover/src/core/vcs/blake2_hash.rs +++ b/crates/prover/src/core/vcs/blake2_hash.rs @@ -1,10 +1,11 @@ use std::fmt; use blake2::{Blake2s256, Digest}; +use bytemuck::{Pod, Zeroable}; // Wrapper for the blake2s hash type. -#[repr(align(32))] -#[derive(Clone, Copy, PartialEq, Default, Eq)] +#[repr(C, align(32))] +#[derive(Clone, Copy, PartialEq, Default, Eq, Pod, Zeroable)] pub struct Blake2sHash([u8; 32]); impl From for Vec {