Skip to content

Commit

Permalink
Neon optimized version of the q4k vecdot product. (huggingface#632)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Aug 27, 2023
1 parent 24dda44 commit 1da71a5
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
4 changes: 4 additions & 0 deletions candle-core/src/quantized/k_quants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,11 @@ impl GgmlType for BlockQ4K {
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;

#[allow(unreachable_code)]
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q4k_q8k(n, xs, ys);

if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
Expand Down
96 changes: 95 additions & 1 deletion candle-core/src/quantized/neon.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::k_quants::{BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use super::k_quants::{BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};

#[allow(unused_imports)]
#[cfg(target_arch = "arm")]
Expand Down Expand Up @@ -279,3 +280,96 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
}
Ok(sum)
}

/// NEON implementation of the dot product between a row of Q4K blocks (`xs`)
/// and a row of Q8K blocks (`ys`).
///
/// Blocks are consumed pairwise with `zip`, so the number of blocks processed
/// is `min(xs.len(), ys.len())`; `n` is only used for the divisibility check.
///
/// For every block pair the result accumulates
/// `d * sum(scale_i * dot(q4_i, q8_i)) - dmin * sum(min_i * q8sum_i)`,
/// where `d = y.d * x.d` and `dmin = y.d * x.dmin`.
///
/// # Errors
///
/// Returns an error when `n` is not a multiple of `QK_K`.
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
    if n % QK_K != 0 {
        crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
    }
    let mut sumf = 0f32;
    // Scratch words used to unpack the packed scale/min bytes of each block.
    let mut utmp = [0u32; 4];
    let mut scales = [0u8; 16];
    // Per-byte masks: low 6 bits, low 4 bits and low 2 bits respectively.
    let kmask1: u32 = 0x3f3f3f3f;
    let kmask2: u32 = 0x0f0f0f0f;
    let kmask3: u32 = 0x03030303;

    // SAFETY: every raw load below reads from a pointer derived from a field of
    // `x` or `y`. The loads cover 16 `i16` values of `y.bsums`, `QK_K / 2` bytes
    // of `x.qs` and `QK_K` bytes of `y.qs` per block.
    // NOTE(review): assumes the block definitions give those fields at least
    // these lengths - confirm against `BlockQ4K` / `BlockQ8K`.
    unsafe {
        // Mask selecting the low nibble of every byte.
        let m4b = vdupq_n_u8(0xF);

        for (x, y) in xs.iter().zip(ys.iter()) {
            let d = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            // Pairwise-add the 16 `bsums` entries down to 8 sums.
            let q8sums = vpaddq_s16(
                vld1q_s16(y.bsums.as_ptr()),
                vld1q_s16(y.bsums.as_ptr().add(8)),
            );

            // Read the packed scale bytes as three little-endian u32 words.
            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            // Eight 6-bit mins: four stored directly in `utmp[1]`, four rebuilt
            // from the high nibbles of `utmp[2]` plus the top 2 bits of `utmp[1]`.
            let mins8 = vld1_u32(
                [
                    utmp[1] & kmask1,
                    ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4),
                ]
                .as_ptr(),
            );
            // Eight 6-bit scales, laid out in `utmp[0]` and `utmp[1]`.
            utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
            utmp[0] &= kmask1;

            // Widen the mins to 16 bits and fold `sum(min_i * q8sum_i)` into the result.
            let mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
            let prod = vaddq_s32(
                vmull_s16(vget_low_s16(q8sums), vget_low_s16(mins)),
                vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)),
            );
            sumf -= dmin * vaddvq_s32(prod) as f32;

            // Expose the unpacked scales as bytes; only the first 8 are indexed below.
            LittleEndian::write_u32_into(&utmp, &mut scales);

            let mut q4 = x.qs.as_ptr();
            let mut q8 = y.qs.as_ptr();

            let mut sumi1 = 0i32;
            let mut sumi2 = 0i32;

            // Each iteration handles 32 packed bytes of `x.qs` (64 nibbles)
            // against 64 bytes of `y.qs`.
            for j in 0..QK_K / 64 {
                let q4bits = vld1q_u8_x2(q4);
                q4 = q4.add(32);
                // TODO: dotprod
                // Low nibbles against the first 32 q8 values.
                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let q4bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
                    vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
                );
                let p0 = vaddq_s16(
                    vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
                    vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
                );
                let p1 = vaddq_s16(
                    vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
                    vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
                );
                // Widening horizontal add: the 32 products of a 4-bit value
                // (0..=15) and an i8 can sum to a magnitude of 15 * 128 * 32 =
                // 61440, which does not fit in i16. Each lane of `p0 + p1` only
                // holds 4 products (<= 7680), so the lane-wise add is safe.
                sumi1 += vaddlvq_s16(vaddq_s16(p0, p1)) * scales[2 * j] as i32;

                // High nibbles against the next 32 q8 values.
                let q8bytes = vld1q_s8_x2(q8);
                q8 = q8.add(32);
                let q4bytes = int8x16x2_t(
                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
                    vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
                );
                let p2 = vaddq_s16(
                    vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
                    vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
                );
                let p3 = vaddq_s16(
                    vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
                    vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
                );
                // Same widening reduction as above to avoid i16 wrap-around.
                sumi2 += vaddlvq_s16(vaddq_s16(p2, p3)) * scales[2 * j + 1] as i32;
            }
            sumf += d * (sumi1 + sumi2) as f32;
        }
    }
    Ok(sumf)
}

0 comments on commit 1da71a5

Please sign in to comment.