diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs
index a0fe455c6d..e0218c34ec 100644
--- a/candle-core/src/quantized/k_quants.rs
+++ b/candle-core/src/quantized/k_quants.rs
@@ -606,6 +606,9 @@ impl GgmlType for BlockQ8_0 {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q8_0_q8_0(n, xs, ys);
 
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys);
+
         let qk = QK8_0;
         if n % QK8_0 != 0 {
             crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index f627f0f63b..61fabc63e8 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -7,6 +7,8 @@ pub mod gguf_file;
 pub mod k_quants;
 #[cfg(target_feature = "neon")]
 pub mod neon;
+#[cfg(target_feature = "simd128")]
+pub mod simd128;
 pub mod utils;
 
 pub use k_quants::GgmlType;
diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs
new file mode 100644
index 0000000000..9cb7119ff7
--- /dev/null
+++ b/candle-core/src/quantized/simd128.rs
@@ -0,0 +1,49 @@
+use super::k_quants::{BlockQ8_0, QK8_0};
+use crate::Result;
+use half::f16;
+
+use core::arch::wasm32::*;
+
+#[inline(always)]
+pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
+    let qk = QK8_0;
+    if n % QK8_0 != 0 {
+        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
+    }
+    let nb = n / QK8_0;
+    if nb % 2 != 0 {
+        crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
+    }
+    unsafe {
+        let mut acc = f32x4_splat(0.0f32);
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr());
+            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
+            let sum_xy = i32x4_dot_i16x8(x1, y1);
+
+            let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8));
+            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
+
+            let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16));
+            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
+
+            let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24));
+            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
+
+            let sum_xy = f32x4_convert_i32x4(sum_xy);
+
+            // f32x4_relaxed_madd is nightly only.
+            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
+            let scaled = f32x4_mul(sum_xy, d);
+            acc = f32x4_add(acc, scaled)
+        }
+        let res = f32x4_extract_lane::<0>(acc)
+            + f32x4_extract_lane::<1>(acc)
+            + f32x4_extract_lane::<2>(acc)
+            + f32x4_extract_lane::<3>(acc);
+        Ok(res)
+    }
+}
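
For reference, the WASM SIMD path in the patch computes the same per-block result as the scalar fallback in k_quants.rs (the code starting at `let qk = QK8_0;` above): each BlockQ8_0 holds 32 signed i8 quants in `qs` plus an f16 scale `d`, and each block pair contributes `to_f32(x.d) * to_f32(y.d) * sum(x.qs[i] * y.qs[i])`. The sketch below is not part of the patch; the function name `vec_dot_q8_0_q8_0_scalar` is made up here purely to illustrate what the i32x4 lanes are accumulating.

// Scalar sketch of the Q8_0 dot product that the simd128 code vectorizes,
// assuming the BlockQ8_0 layout used above (qs: [i8; 32], d: f16).
fn vec_dot_q8_0_q8_0_scalar(xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> f32 {
    xs.iter()
        .zip(ys.iter())
        .map(|(x, y)| {
            // Integer dot product over the 32 quantized values of one block pair.
            let sum_xy: i32 = x
                .qs
                .iter()
                .zip(y.qs.iter())
                .map(|(&a, &b)| a as i32 * b as i32)
                .sum();
            // Scale by both blocks' f16 scales, as the SIMD path does per iteration.
            f16::to_f32(x.d) * f16::to_f32(y.d) * sum_xy as f32
        })
        .sum()
}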