Skip to content

Commit

Permalink
Fix rounding in the linear quantize codec + bump to v0.2.0
Browse files — browse the repository at this point in the history
  • Loading branch information
juntyr committed Oct 2, 2024
1 parent 03a0077 commit 677973b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ numcodecs-wasm-guest = { version = "0.1", path = "crates/numcodecs-wasm-guest",
numcodecs-bit-round = { version = "0.1", path = "codecs/bit-round", default-features = false }
numcodecs-fixed-offset-scale = { version = "0.1", path = "codecs/fixed-offset-scale", default-features = false }
numcodecs-identity = { version = "0.1", path = "codecs/identity", default-features = false }
- numcodecs-linear-quantize = { version = "0.1", path = "codecs/linear-quantize", default-features = false }
+ numcodecs-linear-quantize = { version = "0.2", path = "codecs/linear-quantize", default-features = false }
numcodecs-log = { version = "0.2", path = "codecs/log", default-features = false }
numcodecs-reinterpret = { version = "0.1", path = "codecs/reinterpret", default-features = false }
numcodecs-round = { version = "0.1", path = "codecs/round", default-features = false }
Expand Down
2 changes: 1 addition & 1 deletion codecs/linear-quantize/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "numcodecs-linear-quantize"
- version = "0.1.0"
+ version = "0.2.0"
edition = { workspace = true }
authors = { workspace = true }
repository = { workspace = true }
Expand Down
44 changes: 22 additions & 22 deletions codecs/linear-quantize/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl Codec for LinearQuantizeCodec {
bits @ ..=8 => AnyArray::U8(
Array1::from_vec(quantize(data, |x| {
let max = f32::from(u8::MAX >> (8 - bits));
- let x = (x * scale_for_bits::<f32>(bits)).clamp(0.0, max);
+ let x = (x * scale_for_bits::<f32>(bits) + 0.5).clamp(0.0, max);
#[allow(unsafe_code)]
// Safety: x is clamped beforehand
unsafe {
Expand All @@ -108,7 +108,7 @@ impl Codec for LinearQuantizeCodec {
bits @ 9..=16 => AnyArray::U16(
Array1::from_vec(quantize(data, |x| {
let max = f32::from(u16::MAX >> (16 - bits));
- let x = (x * scale_for_bits::<f32>(bits)).clamp(0.0, max);
+ let x = (x * scale_for_bits::<f32>(bits) + 0.5).clamp(0.0, max);
#[allow(unsafe_code)]
// Safety: x is clamped beforehand
unsafe {
Expand All @@ -121,7 +121,7 @@ impl Codec for LinearQuantizeCodec {
Array1::from_vec(quantize(data, |x| {
// we need to use f64 here to have sufficient precision
let max = f64::from(u32::MAX >> (32 - bits));
- let x = (f64::from(x) * scale_for_bits::<f64>(bits)).clamp(0.0, max);
+ let x = (f64::from(x) * scale_for_bits::<f64>(bits) + 0.5).clamp(0.0, max);
#[allow(unsafe_code)]
// Safety: x is clamped beforehand
unsafe {
Expand All @@ -134,9 +134,10 @@ impl Codec for LinearQuantizeCodec {
Array1::from_vec(quantize(data, |x| {
// we need to use TwoFloat here to have sufficient precision
let max = TwoFloat::from(u64::MAX >> (64 - bits));
- let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits))
- .max(TwoFloat::from(0.0))
- .min(max);
+ let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
+ + TwoFloat::from(0.5))
+ .max(TwoFloat::from(0.0))
+ .min(max);
#[allow(unsafe_code)]
// Safety: x is clamped beforehand
unsafe {
Expand All @@ -150,7 +151,7 @@ impl Codec for LinearQuantizeCodec {
bits @ ..=8 => AnyArray::U8(
Array1::from_vec(quantize(data, |x| {
let max = f64::from(u8::MAX >> (8 - bits));
- let x = (x * scale_for_bits::<f64>(bits)).clamp(0.0, max);
+ let x = (x * scale_for_bits::<f64>(bits) + 0.5).clamp(0.0, max);
#[allow(unsafe_code)]
// Safety: x is clamped beforehand
unsafe {
Expand All @@ -162,7 +163,7 @@ impl Codec for LinearQuantizeCodec {
bits @ 9..=16 => AnyArray::U16(
Array1::from_vec(quantize(data, |x| {
let max = f64::from(u16::MAX >> (16 - bits));
- let x = (x * scale_for_bits::<f64>(bits)).clamp(0.0, max);
+ let x = (x * scale_for_bits::<f64>(bits) + 0.5).clamp(0.0, max);
#[allow(unsafe_code)]
// Safety: x is clamped beforehand
unsafe {
Expand All @@ -174,7 +175,7 @@ impl Codec for LinearQuantizeCodec {
bits @ 17..=32 => AnyArray::U32(
Array1::from_vec(quantize(data, |x| {
let max = f64::from(u32::MAX >> (32 - bits));
- let x = (x * scale_for_bits::<f64>(bits)).clamp(0.0, max);
+ let x = (x * scale_for_bits::<f64>(bits) + 0.5).clamp(0.0, max);
#[allow(unsafe_code)]
// Safety: x is clamped beforehand
unsafe {
Expand All @@ -187,9 +188,10 @@ impl Codec for LinearQuantizeCodec {
Array1::from_vec(quantize(data, |x| {
// we need to use TwoFloat here to have sufficient precision
let max = TwoFloat::from(u64::MAX >> (64 - bits));
- let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits))
- .max(TwoFloat::from(0.0))
- .min(max);
+ let x = (TwoFloat::from(x) * scale_for_bits::<f64>(bits)
+ + TwoFloat::from(0.5))
+ .max(TwoFloat::from(0.0))
+ .min(max);
#[allow(unsafe_code)]
// Safety: x is clamped beforehand
unsafe {
Expand Down Expand Up @@ -664,7 +666,7 @@ mod tests {
use super::*;

#[test]
- fn exact_roundtrip_f32() -> Result<(), LinearQuantizeCodecError> {
+ fn exact_roundtrip_f32_from() -> Result<(), LinearQuantizeCodecError> {
for bits in 1..=16 {
let codec = LinearQuantizeCodec {
dtype: LinearQuantizeDType::F32,
Expand Down Expand Up @@ -697,8 +699,8 @@ mod tests {
}

#[test]
- fn almost_roundtrip_f32() -> Result<(), LinearQuantizeCodecError> {
- for bits in 17..=64 {
+ fn exact_roundtrip_f32_as() -> Result<(), LinearQuantizeCodecError> {
+ for bits in 1..=64 {
let codec = LinearQuantizeCodec {
dtype: LinearQuantizeDType::F32,
#[allow(unsafe_code)]
Expand All @@ -724,16 +726,15 @@ mod tests {
};

for (o, d) in data.iter().zip(decoded.iter()) {
- // FIXME: there seem to be some rounding errors
- assert!((o - d).abs() <= 1.0);
+ assert_eq!(o.to_bits(), d.to_bits());
}
}

Ok(())
}

#[test]
- fn exact_roundtrip_f64() -> Result<(), LinearQuantizeCodecError> {
+ fn exact_roundtrip_f64_from() -> Result<(), LinearQuantizeCodecError> {
for bits in 1..=32 {
let codec = LinearQuantizeCodec {
dtype: LinearQuantizeDType::F64,
Expand Down Expand Up @@ -766,8 +767,8 @@ mod tests {
}

#[test]
- fn almost_roundtrip_f64() -> Result<(), LinearQuantizeCodecError> {
- for bits in 33..=64 {
+ fn exact_roundtrip_f64_as() -> Result<(), LinearQuantizeCodecError> {
+ for bits in 1..=64 {
let codec = LinearQuantizeCodec {
dtype: LinearQuantizeDType::F64,
#[allow(unsafe_code)]
Expand All @@ -793,8 +794,7 @@ mod tests {
};

for (o, d) in data.iter().zip(decoded.iter()) {
- // FIXME: there seem to be some rounding errors
- assert!((o - d).abs() < 2.0);
+ assert_eq!(o.to_bits(), d.to_bits());
}
}

Expand Down

0 comments on commit 677973b

Please sign in to comment.