From 1da71a5da156f53db20ec4aec22bfea7043dfcb5 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 27 Aug 2023 21:30:47 +0100
Subject: [PATCH] Neon optimized version of the q4k vecdot product. (#632)

---
 candle-core/src/quantized/k_quants.rs |   4 ++
 candle-core/src/quantized/neon.rs     | 119 ++++++++++++++++++++++++++-
 2 files changed, 122 insertions(+), 1 deletion(-)

diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index fec240bb5..7b405ec93 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -1102,7 +1102,11 @@ impl GgmlType for BlockQ4K {
     const BLCK_SIZE: usize = QK_K;
     type VecDotType = BlockQ8K;
 
+    #[allow(unreachable_code)]
     fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
+        #[cfg(target_feature = "neon")]
+        return super::neon::vec_dot_q4k_q8k(n, xs, ys);
+
         if n % QK_K != 0 {
             crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
         }
diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs
index 32c93af4d..69d616f4b 100644
--- a/candle-core/src/quantized/neon.rs
+++ b/candle-core/src/quantized/neon.rs
@@ -1,5 +1,6 @@
-use super::k_quants::{BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
+use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
 use crate::Result;
+use byteorder::{ByteOrder, LittleEndian};
 
 #[allow(unused_imports)]
 #[cfg(target_arch = "arm")]
@@ -279,3 +280,119 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
     }
     Ok(sum)
 }
+
+/// Dot product of Q4K blocks (`xs`) with Q8K blocks (`ys`) using NEON intrinsics.
+///
+/// Blocks are paired up with `zip`. Each pair contributes
+/// `d * (sumi1 + sumi2) - dmin * sum(q8sums * mins)` to the result, where `d` and `dmin`
+/// are the Q4K `d`/`dmin` fields multiplied by the Q8K `d`.
+///
+/// # Errors
+///
+/// Fails when `n` is not a multiple of `QK_K`.
+#[inline(always)]
+pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
+    if n % QK_K != 0 {
+        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
+    }
+    let mut sumf = 0f32;
+    let mut utmp = [0u32; 4];
+    let mut scales = [0u8; 16];
+    let kmask1: u32 = 0x3f3f3f3f;
+    let kmask2: u32 = 0x0f0f0f0f;
+    let kmask3: u32 = 0x03030303;
+
+    // SAFETY: every raw read stays inside one block. Per block the loads cover 16 `bsums`
+    // values, `QK_K / 2` bytes of `x.qs` and `QK_K` bytes of `y.qs`.
+    // NOTE(review): assumes the BlockQ4K / BlockQ8K arrays are at least that large; confirm
+    // against their definitions in k_quants.
+    unsafe {
+        let m4b = vdupq_n_u8(0xF);
+
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let d = y.d * x.d.to_f32();
+            let dmin = y.d * x.dmin.to_f32();
+
+            // Pairwise add of the 16 `bsums` entries, leaving 8 sums.
+            let q8sums = vpaddq_s16(
+                vld1q_s16(y.bsums.as_ptr()),
+                vld1q_s16(y.bsums.as_ptr().add(8)),
+            );
+
+            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
+
+            // Eight min bytes: four are the low 6 bits of each byte of `utmp[1]`, four combine
+            // the high nibbles of `utmp[2]` with the top 2 bits of each byte of `utmp[1]`.
+            let mins8 = vld1_u32(
+                [
+                    utmp[1] & kmask1,
+                    ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4),
+                ]
+                .as_ptr(),
+            );
+            // Same unpacking for the eight scale bytes, stored back into `utmp[0..2]`.
+            utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+            utmp[0] &= kmask1;
+
+            let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+            let prod = vaddq_s32(
+                vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
+                vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
+            );
+            sumf -= dmin * vaddvq_s32(prod) as f32;
+
+            LittleEndian::write_u32_into(&utmp, &mut scales);
+
+            let mut q4 = x.qs.as_ptr();
+            let mut q8 = y.qs.as_ptr();
+
+            let mut sumi1 = 0i32;
+            let mut sumi2 = 0i32;
+
+            // Each iteration consumes 32 packed bytes of `x.qs` and 64 bytes of `y.qs`.
+            for j in 0..QK_K / 64 {
+                let q4bits = vld1q_u8_x2(q4);
+                q4 = q4.add(32);
+                // TODO: dotprod
+                let q8bytes = vld1q_s8_x2(q8);
+                q8 = q8.add(32);
+                // Low nibbles pair with the first 32 q8 values.
+                let q4bytes = int8x16x2_t(
+                    vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
+                    vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
+                );
+                let p0 = vaddq_s16(
+                    vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
+                    vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
+                );
+                let p1 = vaddq_s16(
+                    vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
+                    vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
+                );
+                // Widening reduction: 32 products of up to 15 * 128 in magnitude do not fit
+                // in the i16 that `vaddvq_s16` returns, so sum the lanes into an i32.
+                sumi1 += vaddlvq_s16(vaddq_s16(p0, p1)) * scales[2 * j] as i32;
+
+                let q8bytes = vld1q_s8_x2(q8);
+                q8 = q8.add(32);
+                // High nibbles pair with the next 32 q8 values.
+                let q4bytes = int8x16x2_t(
+                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
+                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
+                );
+                let p2 = vaddq_s16(
+                    vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
+                    vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
+                );
+                let p3 = vaddq_s16(
+                    vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
+                    vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
+                );
+                // Widening reduction, for the same reason as the low-nibble half.
+                sumi2 += vaddlvq_s16(vaddq_s16(p2, p3)) * scales[2 * j + 1] as i32;
+            }
+            sumf += d * (sumi1 + sumi2) as f32;
+        }
+    }
+    Ok(sumf)
+}