From c570a76d2de3db38416b3c75f1d103e008d19a2c Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 21 May 2024 19:54:04 +0800 Subject: [PATCH] improve e4m3 decoding. (#43) Co-authored-by: LeiWang199 --- python/bitblas/quantization/quantization.py | 6 +++--- testing/python/operators/test_general_matmul_fp8.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/bitblas/quantization/quantization.py b/python/bitblas/quantization/quantization.py index d9f360947..d68d437d4 100644 --- a/python/bitblas/quantization/quantization.py +++ b/python/bitblas/quantization/quantization.py @@ -142,9 +142,9 @@ def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): assert nbit == 8 assert dtype == "float16" - s_f16 = (val >> tir.const(7, "int16")) << tir.const(15, "int16") - offset = tir.Select(s_f16 == 0, tir.const(8192, "int16"), tir.const(-8192, "int16")) - e_f16 = ((val << tir.const(7, "int16")) + offset) + s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") + prefix = tir.Select(s_f16 == 0, tir.const(0x2000, "uint16"), tir.const(0xc000, "uint16")) + e_f16 = (((val & tir.const(127, "uint16")) << tir.const(7, "uint16"))) | prefix return tir.reinterpret("float16", s_f16 | e_f16) diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py index 3d0a7be2f..5b7de9ab0 100644 --- a/testing/python/operators/test_general_matmul_fp8.py +++ b/testing/python/operators/test_general_matmul_fp8.py @@ -166,7 +166,7 @@ def map_torch_type(intype): print("torch_ref_out", ref_out) print("bitblas_out", bitblas_out) - torch.testing.assert_allclose(ref_out, bitblas_out, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(ref_out, bitblas_out, rtol=1e-1, atol=1e-1) # fmt: on