diff --git a/Cargo.toml b/Cargo.toml index 63cff1d6..a3413c99 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,7 @@ numcodecs-wasm-guest = { version = "0.1", path = "crates/numcodecs-wasm-guest", numcodecs-bit-round = { version = "0.1", path = "codecs/bit-round", default-features = false } numcodecs-fixed-offset-scale = { version = "0.1", path = "codecs/fixed-offset-scale", default-features = false } numcodecs-identity = { version = "0.1", path = "codecs/identity", default-features = false } -numcodecs-linear-quantize = { version = "0.1", path = "codecs/linear-quantize", default-features = false } +numcodecs-linear-quantize = { version = "0.2", path = "codecs/linear-quantize", default-features = false } numcodecs-log = { version = "0.2", path = "codecs/log", default-features = false } numcodecs-reinterpret = { version = "0.1", path = "codecs/reinterpret", default-features = false } numcodecs-round = { version = "0.1", path = "codecs/round", default-features = false } diff --git a/codecs/linear-quantize/Cargo.toml b/codecs/linear-quantize/Cargo.toml index 6200061b..49cde8f5 100644 --- a/codecs/linear-quantize/Cargo.toml +++ b/codecs/linear-quantize/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "numcodecs-linear-quantize" -version = "0.1.0" +version = "0.2.0" edition = { workspace = true } authors = { workspace = true } repository = { workspace = true } diff --git a/codecs/linear-quantize/src/lib.rs b/codecs/linear-quantize/src/lib.rs index fc88170a..9953deec 100644 --- a/codecs/linear-quantize/src/lib.rs +++ b/codecs/linear-quantize/src/lib.rs @@ -96,7 +96,7 @@ impl Codec for LinearQuantizeCodec { bits @ ..=8 => AnyArray::U8( Array1::from_vec(quantize(data, |x| { let max = f32::from(u8::MAX >> (8 - bits)); - let x = (x * scale_for_bits::(bits)).clamp(0.0, max); + let x = (x * scale_for_bits::(bits) + 0.5).clamp(0.0, max); #[allow(unsafe_code)] // Safety: x is clamped beforehand unsafe { @@ -108,7 +108,7 @@ impl Codec for LinearQuantizeCodec { bits @ 9..=16 => AnyArray::U16( Array1::from_vec(quantize(data, |x| { let max = f32::from(u16::MAX >> (16 - bits)); - let x = (x * scale_for_bits::(bits)).clamp(0.0, max); + let x = (x * scale_for_bits::(bits) + 0.5).clamp(0.0, max); #[allow(unsafe_code)] // Safety: x is clamped beforehand unsafe { @@ -121,7 +121,7 @@ impl Codec for LinearQuantizeCodec { Array1::from_vec(quantize(data, |x| { // we need to use f64 here to have sufficient precision let max = f64::from(u32::MAX >> (32 - bits)); - let x = (f64::from(x) * scale_for_bits::(bits)).clamp(0.0, max); + let x = (f64::from(x) * scale_for_bits::(bits) + 0.5).clamp(0.0, max); #[allow(unsafe_code)] // Safety: x is clamped beforehand unsafe { @@ -134,9 +134,10 @@ impl Codec for LinearQuantizeCodec { Array1::from_vec(quantize(data, |x| { // we need to use TwoFloat here to have sufficient precision let max = TwoFloat::from(u64::MAX >> (64 - bits)); - let x = (TwoFloat::from(x) * scale_for_bits::(bits)) - .max(TwoFloat::from(0.0)) - .min(max); + let x = (TwoFloat::from(x) * scale_for_bits::(bits) + + TwoFloat::from(0.5)) + .max(TwoFloat::from(0.0)) + .min(max); #[allow(unsafe_code)] // Safety: x is clamped beforehand unsafe { @@ -150,7 +151,7 @@ impl Codec for LinearQuantizeCodec { bits @ ..=8 => AnyArray::U8( Array1::from_vec(quantize(data, |x| { let max = f64::from(u8::MAX >> (8 - bits)); - let x = (x * scale_for_bits::(bits)).clamp(0.0, max); + let x = (x * scale_for_bits::(bits) + 0.5).clamp(0.0, max); #[allow(unsafe_code)] // Safety: x is clamped beforehand unsafe { @@ -162,7 +163,7 @@ impl Codec for LinearQuantizeCodec { bits @ 9..=16 => AnyArray::U16( Array1::from_vec(quantize(data, |x| { let max = f64::from(u16::MAX >> (16 - bits)); - let x = (x * scale_for_bits::(bits)).clamp(0.0, max); + let x = (x * scale_for_bits::(bits) + 0.5).clamp(0.0, max); #[allow(unsafe_code)] // Safety: x is clamped beforehand unsafe { @@ -174,7 +175,7 @@ impl Codec for LinearQuantizeCodec { bits @ 17..=32 => AnyArray::U32( Array1::from_vec(quantize(data, |x| { let max = f64::from(u32::MAX >> (32 - bits)); - let x = (x * scale_for_bits::(bits)).clamp(0.0, max); + let x = (x * scale_for_bits::(bits) + 0.5).clamp(0.0, max); #[allow(unsafe_code)] // Safety: x is clamped beforehand unsafe { @@ -187,9 +188,10 @@ impl Codec for LinearQuantizeCodec { Array1::from_vec(quantize(data, |x| { // we need to use TwoFloat here to have sufficient precision let max = TwoFloat::from(u64::MAX >> (64 - bits)); - let x = (TwoFloat::from(x) * scale_for_bits::(bits)) - .max(TwoFloat::from(0.0)) - .min(max); + let x = (TwoFloat::from(x) * scale_for_bits::(bits) + + TwoFloat::from(0.5)) + .max(TwoFloat::from(0.0)) + .min(max); #[allow(unsafe_code)] // Safety: x is clamped beforehand unsafe { @@ -664,7 +666,7 @@ mod tests { use super::*; #[test] - fn exact_roundtrip_f32() -> Result<(), LinearQuantizeCodecError> { + fn exact_roundtrip_f32_from() -> Result<(), LinearQuantizeCodecError> { for bits in 1..=16 { let codec = LinearQuantizeCodec { dtype: LinearQuantizeDType::F32, @@ -697,8 +699,8 @@ mod tests { } #[test] - fn almost_roundtrip_f32() -> Result<(), LinearQuantizeCodecError> { - for bits in 17..=64 { + fn exact_roundtrip_f32_as() -> Result<(), LinearQuantizeCodecError> { + for bits in 1..=64 { let codec = LinearQuantizeCodec { dtype: LinearQuantizeDType::F32, #[allow(unsafe_code)] @@ -724,8 +726,7 @@ mod tests { }; for (o, d) in data.iter().zip(decoded.iter()) { - // FIXME: there seem to be some rounding errors - assert!((o - d).abs() <= 1.0); + assert_eq!(o.to_bits(), d.to_bits()); } } @@ -733,7 +734,7 @@ mod tests { } #[test] - fn exact_roundtrip_f64() -> Result<(), LinearQuantizeCodecError> { + fn exact_roundtrip_f64_from() -> Result<(), LinearQuantizeCodecError> { for bits in 1..=32 { let codec = LinearQuantizeCodec { dtype: LinearQuantizeDType::F64, @@ -766,8 +767,8 @@ mod tests { } #[test] - fn almost_roundtrip_f64() -> Result<(), LinearQuantizeCodecError> { - for bits in 33..=64 { + fn exact_roundtrip_f64_as() -> Result<(), LinearQuantizeCodecError> { + for bits in 1..=64 { let codec = LinearQuantizeCodec { dtype: LinearQuantizeDType::F64, #[allow(unsafe_code)] @@ -793,8 +794,7 @@ mod tests { }; for (o, d) in data.iter().zip(decoded.iter()) { - // FIXME: there seem to be some rounding errors - assert!((o - d).abs() < 2.0); + assert_eq!(o.to_bits(), d.to_bits()); } }