Skip to content

Commit c570a76

Browse files
authored
improve e4m3 decoding. (#43)
Co-authored-by: LeiWang199 <leiwang199>
1 parent 6ba204f commit c570a76

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

python/bitblas/quantization/quantization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype
142142
def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
143143
assert nbit == 8
144144
assert dtype == "float16"
145-
s_f16 = (val >> tir.const(7, "int16")) << tir.const(15, "int16")
146-
offset = tir.Select(s_f16 == 0, tir.const(8192, "int16"), tir.const(-8192, "int16"))
147-
e_f16 = ((val << tir.const(7, "int16")) + offset)
145+
s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
146+
prefix = tir.Select(s_f16 == 0, tir.const(0x2000, "uint16"), tir.const(0xc000, "uint16"))
147+
e_f16 = (((val & tir.const(127, "uint16")) << tir.const(7, "uint16"))) | prefix
148148
return tir.reinterpret("float16", s_f16 | e_f16)
149149

150150

testing/python/operators/test_general_matmul_fp8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def map_torch_type(intype):
166166
print("torch_ref_out", ref_out)
167167
print("bitblas_out", bitblas_out)
168168

169-
torch.testing.assert_allclose(ref_out, bitblas_out, rtol=1e-2, atol=1e-2)
169+
torch.testing.assert_close(ref_out, bitblas_out, rtol=1e-1, atol=1e-1)
170170

171171

172172
# fmt: on

0 commit comments

Comments
 (0)