forked from huggingface/candle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
simd128 optimized q8_0 vecdot (huggingface#972)
* wasm/simd128 version of the quantized q8_0 vecdot. * Add the missing conversion.
- Loading branch information
1 parent
29bd6b2
commit e59784e
Showing
3 changed files
with
54 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
use super::k_quants::{BlockQ8_0, QK8_0}; | ||
use crate::Result; | ||
use half::f16; | ||
|
||
use core::arch::wasm32::*; | ||
|
||
#[inline(always)] | ||
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> { | ||
let qk = QK8_0; | ||
if n % QK8_0 != 0 { | ||
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") | ||
} | ||
let nb = n / QK8_0; | ||
if nb % 2 != 0 { | ||
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even") | ||
} | ||
unsafe { | ||
let mut acc = f32x4_splat(0.0f32); | ||
for (x, y) in xs.iter().zip(ys.iter()) { | ||
let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr()); | ||
let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr()); | ||
let sum_xy = i32x4_dot_i16x8(x1, y1); | ||
|
||
let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8)); | ||
let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8)); | ||
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2)); | ||
|
||
let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16)); | ||
let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16)); | ||
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3)); | ||
|
||
let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24)); | ||
let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24)); | ||
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4)); | ||
|
||
let sum_xy = f32x4_convert_i32x4(sum_xy); | ||
|
||
// f32x4_relaxed_madd is nightly only. | ||
let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d)); | ||
let scaled = f32x4_mul(sum_xy, d); | ||
acc = f32x4_add(acc, scaled) | ||
} | ||
let res = f32x4_extract_lane::<0>(acc) | ||
+ f32x4_extract_lane::<1>(acc) | ||
+ f32x4_extract_lane::<2>(acc) | ||
+ f32x4_extract_lane::<3>(acc); | ||
Ok(res) | ||
} | ||
} |