improve e4m3 decoding. (#43)
Co-authored-by: LeiWang199 <leiwang199>
LeiWang1999 authored May 21, 2024
1 parent 6ba204f commit c570a76
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions python/bitblas/quantization/quantization.py
@@ -142,9 +142,9 @@ def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype
 def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
     assert nbit == 8
     assert dtype == "float16"
-    s_f16 = (val >> tir.const(7, "int16")) << tir.const(15, "int16")
-    offset = tir.Select(s_f16 == 0, tir.const(8192, "int16"), tir.const(-8192, "int16"))
-    e_f16 = ((val << tir.const(7, "int16")) + offset)
+    s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
+    prefix = tir.Select(s_f16 == 0, tir.const(0x2000, "uint16"), tir.const(0xc000, "uint16"))
+    e_f16 = (((val & tir.const(127, "uint16")) << tir.const(7, "uint16"))) | prefix
     return tir.reinterpret("float16", s_f16 | e_f16)
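
For reference, the new path decodes FP8 E4M3 (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits) into FP16 (1 sign, 5 exponent, 10 mantissa bits, bias 15) by shifting the 7 low payload bits into place and rebiasing the exponent by +8, which is exactly the 0x2000 constant above. A minimal standalone sketch in plain Python/NumPy, not the TIR kernel; the helper name is just for illustration, and zero, subnormals, and NaN are ignored:

import numpy as np

def e4m3_to_f16_bits(byte: int) -> int:
    # FP8 E4M3: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
    # FP16:     1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
    sign = (byte >> 7) & 0x1
    exp_mant = byte & 0x7F  # exponent + mantissa, 7 bits
    # Shift the payload so the exponent lands at FP16 bits 10-13 and the
    # mantissa at bits 7-9, then add 8 << 10 (= 0x2000) to rebias 7 -> 15.
    return (sign << 15) | ((exp_mant << 7) + 0x2000)

# Spot check: E4M3 byte 0x3C encodes 1.5 (sign 0, exponent 0b0111, mantissa 0b100).
bits = np.array([e4m3_to_f16_bits(0x3C)], dtype=np.uint16)
print(bits.view(np.float16)[0])  # 1.5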


2 changes: 1 addition & 1 deletion testing/python/operators/test_general_matmul_fp8.py
@@ -166,7 +166,7 @@ def map_torch_type(intype):
print("torch_ref_out", ref_out)
print("bitblas_out", bitblas_out)

torch.testing.assert_allclose(ref_out, bitblas_out, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(ref_out, bitblas_out, rtol=1e-1, atol=1e-1)


# fmt: on
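
The test switches from the deprecated torch.testing.assert_allclose to torch.testing.assert_close and loosens the bounds from 1e-2 to 1e-1, since E4M3 keeps only 3 mantissa bits and that rounding error propagates through the matmul. A rough illustration of the single round-trip error, assuming a PyTorch build that ships the float8_e4m3fn dtype (2.1 or newer):

import torch

x = torch.randn(16, dtype=torch.float16)
x_fp8 = x.to(torch.float8_e4m3fn)  # quantize to FP8 E4M3
x_back = x_fp8.to(torch.float16)   # dequantize back to FP16
# With 3 mantissa bits the relative error of one round trip is up to ~6%,
# so 1e-2 tolerances are too tight while 1e-1 passes comfortably.
torch.testing.assert_close(x_back, x, rtol=1e-1, atol=1e-1)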
