From 64fd01c6b56b4ed950cfc1b89e7dabcd270fbd8e Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Tue, 30 Apr 2024 23:58:14 -0400 Subject: [PATCH] Add blake2s implementation for SIMD backend --- .../prover/src/core/backend/simd/blake2s.rs | 408 ++++++++---------- 1 file changed, 190 insertions(+), 218 deletions(-) diff --git a/crates/prover/src/core/backend/simd/blake2s.rs b/crates/prover/src/core/backend/simd/blake2s.rs index bbca088f3..d4f8be449 100644 --- a/crates/prover/src/core/backend/simd/blake2s.rs +++ b/crates/prover/src/core/backend/simd/blake2s.rs @@ -1,11 +1,14 @@ -//! An AVX512 implementation of the BLAKE2s compression function. +//! A SIMD implementation of the BLAKE2s compression function. //! Based on . +use std::array; +use std::iter::repeat; +use std::mem::transmute; use std::simd::u32x16; use itertools::Itertools; -use super::m31::LOG_N_LANES; +use super::m31::{LOG_N_LANES, N_LANES}; use super::SimdBackend; use crate::core::backend::{Col, Column, ColumnOps}; use crate::core::fields::m31::BaseField; @@ -13,8 +16,6 @@ use crate::core::vcs::blake2_hash::Blake2sHash; use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; -const VECS_LOG_SIZE: usize = LOG_N_LANES as usize; - const IV: [u32; 8] = [ 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, ]; @@ -44,10 +45,9 @@ impl MerkleOps for SimdBackend { fn commit_on_layer( log_size: u32, prev_layer: Option<&Vec>, - columns: &[&Col], + columns: &[&Col], ) -> Vec { - // Pad prev_layer if too small. - if log_size < VECS_LOG_SIZE as u32 { + if log_size < N_LANES as u32 { return (0..(1 << log_size)) .map(|i| { Blake2sMerkleHasher::hash_node( @@ -62,31 +62,24 @@ impl MerkleOps for SimdBackend { assert_eq!(prev_layer.len(), 1 << (log_size + 1)); } + let zeros = u32x16::splat(0); + // Commit to columns. let mut res = Vec::with_capacity(1 << log_size); - for i in 0..(1 << (log_size - VECS_LOG_SIZE as u32)) { + for i in 0..(1 << (log_size - LOG_N_LANES)) { let mut state: [u32x16; 8] = unsafe { std::mem::zeroed() }; // Hash prev_layer, if exists. if let Some(prev_layer) = prev_layer { let ptr = prev_layer[(i << 5)..((i + 1) << 5)].as_ptr() as *const u32x16; - let msgs: [u32x16; 16] = std::array::from_fn(|j| unsafe { *ptr.add(j) }); - state = unsafe { - compress16( - state, - transpose_msgs(msgs), - set1(0), - set1(0), - set1(0), - set1(0), - ) - }; + let msgs: [u32x16; 16] = array::from_fn(|j| unsafe { *ptr.add(j) }); + state = compress16(state, transpose_msgs(msgs), zeros, zeros, zeros, zeros); } // Hash columns in chunks of 16. let mut col_chunk_iter = columns.array_chunks(); for col_chunk in &mut col_chunk_iter { let msgs = col_chunk.map(|column| column.data[i].into_simd()); - state = unsafe { compress16(state, msgs, set1(0), set1(0), set1(0), set1(0)) }; + state = compress16(state, msgs, zeros, zeros, zeros, zeros); } // Hash remaining columns. @@ -95,179 +88,148 @@ impl MerkleOps for SimdBackend { let msgs = remainder .iter() .map(|column| column.data[i].into_simd()) - .chain(std::iter::repeat(unsafe { set1(0) })) - .take(16) + .chain(repeat(zeros)) + .take(N_LANES) .collect_vec() .try_into() .unwrap(); - state = unsafe { compress16(state, msgs, set1(0), set1(0), set1(0), set1(0)) }; + state = compress16(state, msgs, zeros, zeros, zeros, zeros); } - let state: [Blake2sHash; 16] = - unsafe { std::mem::transmute(untranspose_states(state)) }; + let state: [Blake2sHash; 16] = unsafe { transmute(untranspose_states(state)) }; res.extend_from_slice(&state); } res } } -/// # Safety -#[inline(always)] -pub unsafe fn set1(iv: i32) -> u32x16 { - u32x16::splat(iv as u32) -} - -#[inline(always)] -unsafe fn add(a: u32x16, b: u32x16) -> u32x16 { - a + b -} - -#[inline(always)] -unsafe fn xor(a: u32x16, b: u32x16) -> u32x16 { - a ^ b -} - -#[inline(always)] -unsafe fn rot16(x: u32x16) -> u32x16 { - (x >> 16) | (x << (32 - 16)) -} - +/// Applies [`u32::rotate_right(N)`] to each element of the vector +/// +/// [`u32::rotate_right(N)`]: u32::rotate_right #[inline(always)] -unsafe fn rot12(x: u32x16) -> u32x16 { - (x >> 12) | (x << (32 - 12)) +fn rotate(x: u32x16) -> u32x16 { + (x >> N) | (x << (u32::BITS - N)) } -#[inline(always)] -unsafe fn rot8(x: u32x16) -> u32x16 { - (x >> 8) | (x << (32 - 8)) -} - -#[inline(always)] -unsafe fn rot7(x: u32x16) -> u32x16 { - (x >> 7) | (x << (32 - 7)) +#[inline] +fn round(v: &mut [u32x16; 16], m: [u32x16; 16], r: usize) { + v[0] += m[SIGMA[r][0] as usize]; + v[1] += m[SIGMA[r][2] as usize]; + v[2] += m[SIGMA[r][4] as usize]; + v[3] += m[SIGMA[r][6] as usize]; + v[0] += v[4]; + v[1] += v[5]; + v[2] += v[6]; + v[3] += v[7]; + v[12] ^= v[0]; + v[13] ^= v[1]; + v[14] ^= v[2]; + v[15] ^= v[3]; + v[12] = rotate::<16>(v[12]); + v[13] = rotate::<16>(v[13]); + v[14] = rotate::<16>(v[14]); + v[15] = rotate::<16>(v[15]); + v[8] += v[12]; + v[9] += v[13]; + v[10] += v[14]; + v[11] += v[15]; + v[4] ^= v[8]; + v[5] ^= v[9]; + v[6] ^= v[10]; + v[7] ^= v[11]; + v[4] = rotate::<12>(v[4]); + v[5] = rotate::<12>(v[5]); + v[6] = rotate::<12>(v[6]); + v[7] = rotate::<12>(v[7]); + v[0] += m[SIGMA[r][1] as usize]; + v[1] += m[SIGMA[r][3] as usize]; + v[2] += m[SIGMA[r][5] as usize]; + v[3] += m[SIGMA[r][7] as usize]; + v[0] += v[4]; + v[1] += v[5]; + v[2] += v[6]; + v[3] += v[7]; + v[12] ^= v[0]; + v[13] ^= v[1]; + v[14] ^= v[2]; + v[15] ^= v[3]; + v[12] = rotate::<8>(v[12]); + v[13] = rotate::<8>(v[13]); + v[14] = rotate::<8>(v[14]); + v[15] = rotate::<8>(v[15]); + v[8] += v[12]; + v[9] += v[13]; + v[10] += v[14]; + v[11] += v[15]; + v[4] ^= v[8]; + v[5] ^= v[9]; + v[6] ^= v[10]; + v[7] ^= v[11]; + v[4] = rotate::<7>(v[4]); + v[5] = rotate::<7>(v[5]); + v[6] = rotate::<7>(v[6]); + v[7] = rotate::<7>(v[7]); + + v[0] += m[SIGMA[r][8] as usize]; + v[1] += m[SIGMA[r][10] as usize]; + v[2] += m[SIGMA[r][12] as usize]; + v[3] += m[SIGMA[r][14] as usize]; + v[0] += v[5]; + v[1] += v[6]; + v[2] += v[7]; + v[3] += v[4]; + v[15] ^= v[0]; + v[12] ^= v[1]; + v[13] ^= v[2]; + v[14] ^= v[3]; + v[15] = rotate::<16>(v[15]); + v[12] = rotate::<16>(v[12]); + v[13] = rotate::<16>(v[13]); + v[14] = rotate::<16>(v[14]); + v[10] += v[15]; + v[11] += v[12]; + v[8] += v[13]; + v[9] += v[14]; + v[5] ^= v[10]; + v[6] ^= v[11]; + v[7] ^= v[8]; + v[4] ^= v[9]; + v[5] = rotate::<12>(v[5]); + v[6] = rotate::<12>(v[6]); + v[7] = rotate::<12>(v[7]); + v[4] = rotate::<12>(v[4]); + v[0] += m[SIGMA[r][9] as usize]; + v[1] += m[SIGMA[r][11] as usize]; + v[2] += m[SIGMA[r][13] as usize]; + v[3] += m[SIGMA[r][15] as usize]; + v[0] += v[5]; + v[1] += v[6]; + v[2] += v[7]; + v[3] += v[4]; + v[15] ^= v[0]; + v[12] ^= v[1]; + v[13] ^= v[2]; + v[14] ^= v[3]; + v[15] = rotate::<8>(v[15]); + v[12] = rotate::<8>(v[12]); + v[13] = rotate::<8>(v[13]); + v[14] = rotate::<8>(v[14]); + v[10] += v[15]; + v[11] += v[12]; + v[8] += v[13]; + v[9] += v[14]; + v[5] ^= v[10]; + v[6] ^= v[11]; + v[7] ^= v[8]; + v[4] ^= v[9]; + v[5] = rotate::<7>(v[5]); + v[6] = rotate::<7>(v[6]); + v[7] = rotate::<7>(v[7]); + v[4] = rotate::<7>(v[4]); } -#[inline(always)] -unsafe fn round(v: &mut [u32x16; 16], m: [u32x16; 16], r: usize) { - v[0] = add(v[0], m[SIGMA[r][0] as usize]); - v[1] = add(v[1], m[SIGMA[r][2] as usize]); - v[2] = add(v[2], m[SIGMA[r][4] as usize]); - v[3] = add(v[3], m[SIGMA[r][6] as usize]); - v[0] = add(v[0], v[4]); - v[1] = add(v[1], v[5]); - v[2] = add(v[2], v[6]); - v[3] = add(v[3], v[7]); - v[12] = xor(v[12], v[0]); - v[13] = xor(v[13], v[1]); - v[14] = xor(v[14], v[2]); - v[15] = xor(v[15], v[3]); - v[12] = rot16(v[12]); - v[13] = rot16(v[13]); - v[14] = rot16(v[14]); - v[15] = rot16(v[15]); - v[8] = add(v[8], v[12]); - v[9] = add(v[9], v[13]); - v[10] = add(v[10], v[14]); - v[11] = add(v[11], v[15]); - v[4] = xor(v[4], v[8]); - v[5] = xor(v[5], v[9]); - v[6] = xor(v[6], v[10]); - v[7] = xor(v[7], v[11]); - v[4] = rot12(v[4]); - v[5] = rot12(v[5]); - v[6] = rot12(v[6]); - v[7] = rot12(v[7]); - v[0] = add(v[0], m[SIGMA[r][1] as usize]); - v[1] = add(v[1], m[SIGMA[r][3] as usize]); - v[2] = add(v[2], m[SIGMA[r][5] as usize]); - v[3] = add(v[3], m[SIGMA[r][7] as usize]); - v[0] = add(v[0], v[4]); - v[1] = add(v[1], v[5]); - v[2] = add(v[2], v[6]); - v[3] = add(v[3], v[7]); - v[12] = xor(v[12], v[0]); - v[13] = xor(v[13], v[1]); - v[14] = xor(v[14], v[2]); - v[15] = xor(v[15], v[3]); - v[12] = rot8(v[12]); - v[13] = rot8(v[13]); - v[14] = rot8(v[14]); - v[15] = rot8(v[15]); - v[8] = add(v[8], v[12]); - v[9] = add(v[9], v[13]); - v[10] = add(v[10], v[14]); - v[11] = add(v[11], v[15]); - v[4] = xor(v[4], v[8]); - v[5] = xor(v[5], v[9]); - v[6] = xor(v[6], v[10]); - v[7] = xor(v[7], v[11]); - v[4] = rot7(v[4]); - v[5] = rot7(v[5]); - v[6] = rot7(v[6]); - v[7] = rot7(v[7]); - - v[0] = add(v[0], m[SIGMA[r][8] as usize]); - v[1] = add(v[1], m[SIGMA[r][10] as usize]); - v[2] = add(v[2], m[SIGMA[r][12] as usize]); - v[3] = add(v[3], m[SIGMA[r][14] as usize]); - v[0] = add(v[0], v[5]); - v[1] = add(v[1], v[6]); - v[2] = add(v[2], v[7]); - v[3] = add(v[3], v[4]); - v[15] = xor(v[15], v[0]); - v[12] = xor(v[12], v[1]); - v[13] = xor(v[13], v[2]); - v[14] = xor(v[14], v[3]); - v[15] = rot16(v[15]); - v[12] = rot16(v[12]); - v[13] = rot16(v[13]); - v[14] = rot16(v[14]); - v[10] = add(v[10], v[15]); - v[11] = add(v[11], v[12]); - v[8] = add(v[8], v[13]); - v[9] = add(v[9], v[14]); - v[5] = xor(v[5], v[10]); - v[6] = xor(v[6], v[11]); - v[7] = xor(v[7], v[8]); - v[4] = xor(v[4], v[9]); - v[5] = rot12(v[5]); - v[6] = rot12(v[6]); - v[7] = rot12(v[7]); - v[4] = rot12(v[4]); - v[0] = add(v[0], m[SIGMA[r][9] as usize]); - v[1] = add(v[1], m[SIGMA[r][11] as usize]); - v[2] = add(v[2], m[SIGMA[r][13] as usize]); - v[3] = add(v[3], m[SIGMA[r][15] as usize]); - v[0] = add(v[0], v[5]); - v[1] = add(v[1], v[6]); - v[2] = add(v[2], v[7]); - v[3] = add(v[3], v[4]); - v[15] = xor(v[15], v[0]); - v[12] = xor(v[12], v[1]); - v[13] = xor(v[13], v[2]); - v[14] = xor(v[14], v[3]); - v[15] = rot8(v[15]); - v[12] = rot8(v[12]); - v[13] = rot8(v[13]); - v[14] = rot8(v[14]); - v[10] = add(v[10], v[15]); - v[11] = add(v[11], v[12]); - v[8] = add(v[8], v[13]); - v[9] = add(v[9], v[14]); - v[5] = xor(v[5], v[10]); - v[6] = xor(v[6], v[11]); - v[7] = xor(v[7], v[8]); - v[4] = xor(v[4], v[9]); - v[5] = rot7(v[5]); - v[6] = rot7(v[6]); - v[7] = rot7(v[7]); - v[4] = rot7(v[4]); -} - -/// Transposes input chunks (16 chunks of 16 u32s each), to get 16 u32x16, each +/// Transposes input chunks (16 chunks of 16 `u32`s each), to get 16 `u32x16`, each /// representing 16 packed instances of a message word. -/// # Safety -pub unsafe fn transpose_msgs(mut data: [u32x16; 16]) -> [u32x16; 16] { - // Each _m512i chunk contains 16 u32 words. +fn transpose_msgs(mut data: [u32x16; 16]) -> [u32x16; 16] { // Index abcd:xyzw, refers to a specific word in data as follows: // abcd - chunk index (in base 2) // xyzw - word offset (in base 2) @@ -287,13 +249,11 @@ pub unsafe fn transpose_msgs(mut data: [u32x16; 16]) -> [u32x16; 16] { d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15, ]; } + data } -/// Transposes states, from 8 packed words, to get 16 results, each of size 32B. -/// # Safety -pub unsafe fn untranspose_states(mut states: [u32x16; 8]) -> [u32x16; 8] { - // Each _m512i chunk contains 16 u32 words. +fn untranspose_states(mut states: [u32x16; 8]) -> [u32x16; 8] { // Index abc:xyzw, refers to a specific word in data as follows: // abc - chunk index (in base 2) // xyzw - word offset (in base 2) @@ -310,9 +270,8 @@ pub unsafe fn untranspose_states(mut states: [u32x16; 8]) -> [u32x16; 8] { states } -/// Compress 16 blake2s instances. -/// # Safety -pub unsafe fn compress16( +/// Compresses 16 blake2s instances. +fn compress16( h_vecs: [u32x16; 8], msg_vecs: [u32x16; 16], count_low: u32x16, @@ -329,14 +288,14 @@ pub unsafe fn compress16( h_vecs[5], h_vecs[6], h_vecs[7], - set1(IV[0] as i32), - set1(IV[1] as i32), - set1(IV[2] as i32), - set1(IV[3] as i32), - xor(set1(IV[4] as i32), count_low), - xor(set1(IV[5] as i32), count_high), - xor(set1(IV[6] as i32), lastblock), - xor(set1(IV[7] as i32), lastnode), + u32x16::splat(IV[0]), + u32x16::splat(IV[1]), + u32x16::splat(IV[2]), + u32x16::splat(IV[3]), + u32x16::splat(IV[4]) ^ count_low, + u32x16::splat(IV[5]) ^ count_high, + u32x16::splat(IV[6]) ^ lastblock, + u32x16::splat(IV[7]) ^ lastnode, ]; round(&mut v, msg_vecs, 0); @@ -351,58 +310,70 @@ pub unsafe fn compress16( round(&mut v, msg_vecs, 9); [ - xor(xor(h_vecs[0], v[0]), v[8]), - xor(xor(h_vecs[1], v[1]), v[9]), - xor(xor(h_vecs[2], v[2]), v[10]), - xor(xor(h_vecs[3], v[3]), v[11]), - xor(xor(h_vecs[4], v[4]), v[12]), - xor(xor(h_vecs[5], v[5]), v[13]), - xor(xor(h_vecs[6], v[6]), v[14]), - xor(xor(h_vecs[7], v[7]), v[15]), + h_vecs[0] ^ v[0] ^ v[8], + h_vecs[1] ^ v[1] ^ v[9], + h_vecs[2] ^ v[2] ^ v[10], + h_vecs[3] ^ v[3] ^ v[11], + h_vecs[4] ^ v[4] ^ v[12], + h_vecs[5] ^ v[5] ^ v[13], + h_vecs[6] ^ v[6] ^ v[14], + h_vecs[7] ^ v[7] ^ v[15], ] } #[cfg(test)] mod tests { + use std::array; + use std::mem::transmute; use std::simd::u32x16; - use super::{compress16, set1, transpose_msgs, untranspose_states}; + use aligned::{Aligned, A64}; + + use super::{compress16, transpose_msgs, untranspose_states}; use crate::core::vcs::blake2s_ref::compress; #[test] - fn test_compress16() { - let states: [[u32; 8]; 16] = - std::array::from_fn(|i| std::array::from_fn(|j| (i + j) as u32)); - let msgs: [[u32; 16]; 16] = - std::array::from_fn(|i| std::array::from_fn(|j| (i + j + 20) as u32)); + fn compress16_works() { + let states: Aligned = + Aligned(array::from_fn(|i| array::from_fn(|j| (i + j) as u32))); + let msgs: Aligned = + Aligned(array::from_fn(|i| array::from_fn(|j| (i + j + 20) as u32))); let count_low = 1; let count_high = 2; let lastblock = 3; let lastnode = 4; - let res_unvectorized = std::array::from_fn(|i| { + let res_unvectorized = array::from_fn(|i| { compress( states[i], msgs[i], count_low, count_high, lastblock, lastnode, ) }); let res_vectorized: [[u32; 8]; 16] = unsafe { - std::mem::transmute(untranspose_states(compress16( - transpose_states(std::mem::transmute(states)), - transpose_msgs(std::mem::transmute(msgs)), - set1(count_low as i32), - set1(count_high as i32), - set1(lastblock as i32), - set1(lastnode as i32), + transmute(untranspose_states(compress16( + transpose_states(transmute(states)), + transpose_msgs(transmute(msgs)), + u32x16::splat(count_low), + u32x16::splat(count_high), + u32x16::splat(lastblock), + u32x16::splat(lastnode), ))) }; - assert_eq!(res_unvectorized, res_vectorized); + assert_eq!(res_vectorized, res_unvectorized); + } + + #[test] + fn untranspose_states_is_transpose_states_inverse() { + let states = array::from_fn(|i| u32x16::from(array::from_fn(|j| (i + j) as u32))); + let transposed_states = transpose_states(states); + + let untrasponsed_transposed_states = untranspose_states(transposed_states); + + assert_eq!(untrasponsed_transposed_states, states) } /// Transposes states, from 8 packed words, to get 16 results, each of size 32B. - /// # Safety - pub unsafe fn transpose_states(mut states: [u32x16; 8]) -> [u32x16; 8] { - // Each _m512i chunk contains 16 u32 words. + fn transpose_states(mut states: [u32x16; 8]) -> [u32x16; 8] { // Index abc:xyzw, refers to a specific word in data as follows: // abc - chunk index (in base 2) // xyzw - word offset (in base 2) @@ -416,6 +387,7 @@ mod tests { let (s3, s7) = states[6].deinterleave(states[7]); states = [s0, s1, s2, s3, s4, s5, s6, s7]; } + states } }