Simd128 vec-dot for q4_0. (huggingface#974)
* Simd128 vec-dot for q4_0.

* Bugfix.

* Add wasm tests.

* Bugfix for the q40 vecdot.

* More quantization tests.
LaurentMazare committed Sep 27, 2023
1 parent e59784e commit 667f01c
Showing 8 changed files with 253 additions and 2 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
@@ -12,8 +12,9 @@ members = [
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
"candle-wasm-examples/bert",
"candle-wasm-examples/t5",
"candle-wasm-examples/phi",
"candle-wasm-examples/t5",
"candle-wasm-tests",
]
exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"
3 changes: 3 additions & 0 deletions candle-core/src/quantized/k_quants.rs
@@ -225,6 +225,9 @@ impl GgmlType for BlockQ4_0 {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);
 
+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);
+
         let qk = QK8_0;
         let nb = n / qk;
         if n % QK8_0 != 0 {
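Note that the `simd128` branch added above is compile-time gated: it is only compiled in when the wasm `simd128` target feature is enabled. A typical way to enable it (a build-setup assumption, not something this commit configures) is:

```bash
# Enable WASM SIMD so the cfg(target_feature = "simd128") branch is taken;
# without the flag, the scalar fallback below the cfg blocks runs instead.
RUSTFLAGS="-C target-feature=+simd128" wasm-pack test --chrome --headless
```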
52 changes: 51 additions & 1 deletion candle-core/src/quantized/simd128.rs
@@ -1,9 +1,59 @@
-use super::k_quants::{BlockQ8_0, QK8_0};
+use super::k_quants::{BlockQ4_0, BlockQ8_0, QK8_0};
 use crate::Result;
 use half::f16;
 
 use core::arch::wasm32::*;
 
+#[inline(always)]
+pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
+    let qk = QK8_0;
+    if n % QK8_0 != 0 {
+        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
+    }
+    let nb = n / QK8_0;
+    if nb % 2 != 0 {
+        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
+    }
+    unsafe {
+        let mut acc = f32x4_splat(0.0f32);
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let x1234 = v128_load(x.qs.as_ptr() as *const v128);
+            let x12 = v128_and(x1234, u8x16_splat(0x0F));
+            let x12 = i8x16_sub(x12, i8x16_splat(8));
+            let x34 = u8x16_shr(x1234, 4);
+            let x34 = i8x16_sub(x34, i8x16_splat(8));
+
+            let x1 = i16x8_extend_low_i8x16(x12);
+            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
+            let sum_xy = i32x4_dot_i16x8(x1, y1);
+
+            let x2 = i16x8_extend_high_i8x16(x12);
+            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
+
+            let x3 = i16x8_extend_low_i8x16(x34);
+            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
+
+            let x4 = i16x8_extend_high_i8x16(x34);
+            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
+
+            let sum_xy = f32x4_convert_i32x4(sum_xy);
+
+            // f32x4_relaxed_madd is nightly only.
+            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
+            let scaled = f32x4_mul(sum_xy, d);
+            acc = f32x4_add(acc, scaled)
+        }
+        let res = f32x4_extract_lane::<0>(acc)
+            + f32x4_extract_lane::<1>(acc)
+            + f32x4_extract_lane::<2>(acc)
+            + f32x4_extract_lane::<3>(acc);
+        Ok(res)
+    }
+}
+
 #[inline(always)]
 pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
     let qk = QK8_0;
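For reference, a scalar sketch of what each block iteration of `vec_dot_q4_0_q8_0` above computes (illustrative only, not part of the commit; the `qs`/`d` fields are crate-private in candle, so the field access is shown as if public). Each byte of the q4_0 block packs two 4-bit weights: the low nibbles pair with the first 16 `i8` values of the q8_0 block, the high nibbles with the last 16, and every nibble is re-centered by subtracting 8 before the integer dot product is scaled by the two block scales:

```rust
use half::f16;

// Illustrative scalar equivalent of one SIMD block iteration.
fn block_dot_scalar(x: &BlockQ4_0, y: &BlockQ8_0) -> f32 {
    let mut sum = 0i32;
    for j in 0..16 {
        let lo = (x.qs[j] & 0x0F) as i32 - 8; // low nibble -> first half of y.qs
        let hi = (x.qs[j] >> 4) as i32 - 8; // high nibble -> second half of y.qs
        sum += lo * y.qs[j] as i32 + hi * y.qs[j + 16] as i32;
    }
    f16::to_f32(x.d) * f16::to_f32(y.d) * sum as f32
}
```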
15 changes: 15 additions & 0 deletions candle-wasm-tests/Cargo.toml
@@ -0,0 +1,15 @@
[package]
name = "candle-wasm-tests"
version.workspace = true
edition.workspace = true
description = "WASM tests for candle"
keywords.workspace = true
categories.workspace = true

[dependencies]
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
rand = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }

[dev-dependencies]
wasm-bindgen-test = "0.3.0"
12 changes: 12 additions & 0 deletions candle-wasm-tests/README.md
@@ -0,0 +1,12 @@
Run the tests with:
```bash
RUST_LOG=wasm_bindgen_test_runner wasm-pack test --chrome --headless
```
Or:
```bash
wasm-pack test --chrome
```

If you get an "invalid session id" failure in headless mode, check the logs; it
may well be that your ChromeDriver version does not match your browser version.
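One quick way to compare the two versions (assuming `chromedriver` and a Chrome binary are on your `PATH`; the binary name varies by platform):
```bash
chromedriver --version
google-chrome --version  # or e.g. `chromium --version`, depending on the install
```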
14 changes: 14 additions & 0 deletions candle-wasm-tests/src/lib.rs
@@ -0,0 +1,14 @@
pub fn add(left: usize, right: usize) -> usize {
    left + right
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        let result = add(2, 2);
        assert_eq!(result, 4);
    }
}
140 changes: 140 additions & 0 deletions candle-wasm-tests/tests/quantized_tests.rs
@@ -0,0 +1,140 @@
use candle::{
    quantized::{self, k_quants, GgmlDType, GgmlType},
    test_utils::to_vec2_round,
    Device, Result, Tensor,
};

use wasm_bindgen_test::*;
wasm_bindgen_test_configure!(run_in_browser);

#[wasm_bindgen_test]
fn quantized_matmul_neg() -> Result<()> {
    let cpu = &Device::Cpu;
    let (m, k, n) = (3, 64, 4);
    let lhs = (0..(m * k))
        .map(|v| v as f32 - (m * k) as f32 / 2.0)
        .collect::<Vec<_>>();
    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
    let mut dst = vec![42.; 3 * 4];
    let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
    let rhs = (0..k * n)
        .map(|v| v as f32 - (k * n) as f32 / 3.0)
        .collect::<Vec<_>>();
    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
    k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
    k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
    assert_eq!(
        dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
        &[
            243524.0, -19596.0, -285051.0, -549815.0, 23777.0, 21651.0, 19398.0, 18367.0,
            -196472.0, 63012.0, 324585.0, 587902.0
        ]
    );
    let mm = tensor_lhs.matmul(&tensor_rhs)?;
    assert_eq!(
        to_vec2_round(&mm, 0)?,
        &[
            [244064.0, -20128.0, -284320.0, -548512.0],
            [23563.0, 21515.0, 19467.0, 17419.0],
            [-196939.0, 63157.0, 323253.0, 583349.0]
        ]
    );

    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor);
    let res = matmul.forward(&tensor_lhs)?;
    assert_eq!(
        to_vec2_round(&res, 0)?,
        &[
            [243524.0, -19596.0, -285051.0, -549815.0],
            [23777.0, 21651.0, 19398.0, 18367.0],
            [-196472.0, 63012.0, 324585.0, 587902.0]
        ]
    );

    Ok(())
}

/// Creates a vector similar to the one used in the GGML unit tests:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
    const GGML_TEST_SIZE: usize = 32 * 128;
    (0..GGML_TEST_SIZE)
        .map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
        .collect()
}

/// Very simple dot product implementation.
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(a, b)| a * b).sum()
}

/// Returns the error achieved by the GGML matmul unit test.
fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
    let err = match dtype {
        GgmlDType::F16 => 0.000010,
        GgmlDType::Q2K => 0.004086,
        GgmlDType::Q3K => 0.016148,
        GgmlDType::Q4K => 0.002425,
        GgmlDType::Q5K => 0.000740,
        GgmlDType::Q6K => 0.000952,
        GgmlDType::Q4_0 => 0.001143,
        GgmlDType::Q4_1 => 0.007784,
        GgmlDType::Q5_0 => 0.001353,
        GgmlDType::Q5_1 => 0.001363,
        GgmlDType::Q8_0 => 0.000092,
        _ => candle::bail!("No GGML results for quantization type {dtype:?}",),
    };
    Ok(err)
}

/// Mirrors the GGML matmul unit test:
/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
    const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
    let a = create_ggml_like_vector(0.0);
    let b = create_ggml_like_vector(1.0);
    let length = a.len();

    let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
    let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
    T::from_float(&a, &mut a_quant)?;
    T::VecDotType::from_float(&b, &mut b_quant)?;

    let result = T::vec_dot(length, &a_quant, &b_quant)?;
    let reference_result = vec_dot_reference(&a, &b);

    let error = (result - reference_result).abs() / length as f32;

    let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;

    if error > GGML_MAX_DOT_PRODUCT_ERROR {
        candle::bail!(
            "Dot product error {} exceeds max error {}",
            error,
            GGML_MAX_DOT_PRODUCT_ERROR
        );
    }

    // We diverge slightly due to different rounding behavior / f16-to-f32 conversions
    // in GGML, so we use a slightly higher error threshold.
    const ERROR_LENIENCY: f32 = 0.00001;
    if error - ERROR_LENIENCY > ggml_error {
        candle::bail!(
            "Dot product error {} exceeds ggml reference error {}",
            error,
            ggml_error
        );
    }
    Ok(())
}

#[wasm_bindgen_test]
fn quantized_matmul_q40() -> Result<()> {
    ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ4_0>()?;
    Ok(())
}

#[wasm_bindgen_test]
fn quantized_matmul_q80() -> Result<()> {
    ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ8_0>()?;
    Ok(())
}
16 changes: 16 additions & 0 deletions candle-wasm-tests/webdriver.json
@@ -0,0 +1,16 @@
{
  "moz:firefoxOptions": {
    "prefs": {
      "media.navigator.streams.fake": true,
      "media.navigator.permission.disabled": true
    },
    "args": []
  },
  "goog:chromeOptions": {
    "args": [
      "--use-fake-device-for-media-stream",
      "--use-fake-ui-for-media-stream"
    ]
  }
}
