forked from huggingface/candle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Simd128 vec-dot for q4_0. (huggingface#974)
* Simd128 vec-dot for q4_0. * Bugfix. * Add wasm tests. * Bugfix for the q40 vecdot. * More quantization tests.
- Loading branch information
1 parent
e59784e
commit 667f01c
Showing
8 changed files
with
253 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
[package] | ||
name = "candle-wasm-tests" | ||
version.workspace = true | ||
edition.workspace = true | ||
description = "WASM tests for candle" | ||
keywords.workspace = true | ||
categories.workspace = true | ||
|
||
[dependencies] | ||
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" } | ||
rand = { workspace = true } | ||
getrandom = { version = "0.2", features = ["js"] } | ||
|
||
[dev-dependencies] | ||
wasm-bindgen-test = "0.3.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
Run the tests with: | ||
```bash | ||
RUST_LOG=wasm_bindgen_test_runner wasm-pack test --chrome --headless | ||
``` | ||
Or: | ||
```bash | ||
wasm-pack test --chrome | ||
``` | ||
|
||
If you get an "invalid session id" failure in headless mode, check that logs and | ||
it may well be that your ChromeDriver is not at the same version as your | ||
browser. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
pub fn add(left: usize, right: usize) -> usize { | ||
left + right | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn it_works() { | ||
let result = add(2, 2); | ||
assert_eq!(result, 4); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
use candle::{ | ||
quantized::{self, k_quants, GgmlDType, GgmlType}, | ||
test_utils::to_vec2_round, | ||
Device, Result, Tensor, | ||
}; | ||
|
||
use wasm_bindgen_test::*; | ||
wasm_bindgen_test_configure!(run_in_browser); | ||
|
||
#[wasm_bindgen_test] | ||
fn quantized_matmul_neg() -> Result<()> { | ||
let cpu = &Device::Cpu; | ||
let (m, k, n) = (3, 64, 4); | ||
let lhs = (0..(m * k)) | ||
.map(|v| v as f32 - (m * k) as f32 / 2.0) | ||
.collect::<Vec<_>>(); | ||
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?; | ||
let mut dst = vec![42.; 3 * 4]; | ||
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; | ||
let rhs = (0..k * n) | ||
.map(|v| v as f32 - (k * n) as f32 / 3.0) | ||
.collect::<Vec<_>>(); | ||
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?; | ||
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; | ||
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?; | ||
assert_eq!( | ||
dst.iter().map(|x| x.round()).collect::<Vec<_>>(), | ||
&[ | ||
243524.0, -19596.0, -285051.0, -549815.0, 23777.0, 21651.0, 19398.0, 18367.0, | ||
-196472.0, 63012.0, 324585.0, 587902.0 | ||
] | ||
); | ||
let mm = tensor_lhs.matmul(&tensor_rhs)?; | ||
assert_eq!( | ||
to_vec2_round(&mm, 0)?, | ||
&[ | ||
[244064.0, -20128.0, -284320.0, -548512.0], | ||
[23563.0, 21515.0, 19467.0, 17419.0], | ||
[-196939.0, 63157.0, 323253.0, 583349.0] | ||
] | ||
); | ||
|
||
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?; | ||
let matmul = quantized::QMatMul::from_qtensor(qtensor); | ||
let res = matmul.forward(&tensor_lhs)?; | ||
assert_eq!( | ||
to_vec2_round(&res, 0)?, | ||
&[ | ||
[243524.0, -19596.0, -285051.0, -549815.0], | ||
[23777.0, 21651.0, 19398.0, 18367.0], | ||
[-196472.0, 63012.0, 324585.0, 587902.0] | ||
] | ||
); | ||
|
||
Ok(()) | ||
} | ||
|
||
/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30 | ||
fn create_ggml_like_vector(offset: f32) -> Vec<f32> { | ||
const GGML_TEST_SIZE: usize = 32 * 128; | ||
(0..GGML_TEST_SIZE) | ||
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos()) | ||
.collect() | ||
} | ||
|
||
/// Very simple dot product implementation | ||
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 { | ||
a.iter().zip(b).map(|(a, b)| a * b).sum() | ||
} | ||
|
||
/// Returns the error achieved by the GGML matmul unit test. | ||
fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> { | ||
let err = match dtype { | ||
GgmlDType::F16 => 0.000010, | ||
GgmlDType::Q2K => 0.004086, | ||
GgmlDType::Q3K => 0.016148, | ||
GgmlDType::Q4K => 0.002425, | ||
GgmlDType::Q5K => 0.000740, | ||
GgmlDType::Q6K => 0.000952, | ||
GgmlDType::Q4_0 => 0.001143, | ||
GgmlDType::Q4_1 => 0.007784, | ||
GgmlDType::Q5_0 => 0.001353, | ||
GgmlDType::Q5_1 => 0.001363, | ||
GgmlDType::Q8_0 => 0.000092, | ||
_ => candle::bail!("No GGML results for quantization type {dtype:?}",), | ||
}; | ||
Ok(err) | ||
} | ||
|
||
/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91 | ||
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> { | ||
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02; | ||
let a = create_ggml_like_vector(0.0); | ||
let b = create_ggml_like_vector(1.0); | ||
let length = a.len(); | ||
|
||
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE]; | ||
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE]; | ||
T::from_float(&a, &mut a_quant)?; | ||
T::VecDotType::from_float(&b, &mut b_quant)?; | ||
|
||
let result = T::vec_dot(length, &a_quant, &b_quant)?; | ||
let reference_result = vec_dot_reference(&a, &b); | ||
|
||
let error = (result - reference_result).abs() / length as f32; | ||
|
||
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?; | ||
|
||
if error > GGML_MAX_DOT_PRODUCT_ERROR { | ||
candle::bail!( | ||
"Dot product error {} exceeds max error {}", | ||
error, | ||
GGML_MAX_DOT_PRODUCT_ERROR | ||
); | ||
} | ||
|
||
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML | ||
// => we use a slightly higher error threshold | ||
const ERROR_LENIENCY: f32 = 0.00001; | ||
if error - ERROR_LENIENCY > ggml_error { | ||
candle::bail!( | ||
"Dot product error {} exceeds ggml reference error {}", | ||
error, | ||
ggml_error | ||
); | ||
} | ||
Ok(()) | ||
} | ||
|
||
#[wasm_bindgen_test] | ||
fn quantized_matmul_q40() -> Result<()> { | ||
ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ4_0>()?; | ||
Ok(()) | ||
} | ||
|
||
#[wasm_bindgen_test] | ||
fn quantized_matmul_q80() -> Result<()> { | ||
ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ8_0>()?; | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
{ | ||
"moz:firefoxOptions": { | ||
"prefs": { | ||
"media.navigator.streams.fake": true, | ||
"media.navigator.permission.disabled": true | ||
}, | ||
"args": [] | ||
}, | ||
"goog:chromeOptions": { | ||
"args": [ | ||
"--use-fake-device-for-media-stream", | ||
"--use-fake-ui-for-media-stream" | ||
] | ||
} | ||
} | ||
|