[Dev] Fix a bug within FP8 E4M3 Fast Decoding (#54)
* improve e4m3 decoding.

* append fp16xint1

* Update submodule commit reference

* chore: Update shared memory scope for float32 output dtype

* BUGFIX: UINT8/INT8 Decoding

* feat: Add rasterization options for roller module

* Refactor tensorcore_legalization method to optimize tensor core usage

* feat: Add function to collect variables from expression, improve for splitk

* chore: Update typing import in __init__.py

* chore: Refactor CPU execution of operators

* Refactor matmul implementation for splitk layout

* Refactor matmul implementation for splitk layout

* Refactor matmul implementation for splitk layout

* chore: Update version to 0.0.1.dev8

* chore: Enable debug output in bitblas.set_debug_level()

* Refactor Linear module matmul implementation for splitk layout

* Refactor matmul implementation for splitk layout

* Refactor CUDA kernel launch string for dynamic symbolic set

* Bump version to v0.0.1.dev9

* Refactor CUDA kernel launch string for dynamic symbolic set

* Bump version to v0.0.1.dev10

* Refactor CUDA kernel launch string for dynamic symbolic set

* Bump version to v0.0.1.dev12 and add MatmulConfigWithSplitK and MatmulWithSplitK

---------

Co-authored-by: LeiWang199 <leiwang199>
LeiWang1999 committed Jun 6, 2024
1 parent 857732b commit c090df6
Showing 4 changed files with 19 additions and 5 deletions.
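
Several of the commit-message entries above concern the split-K matmul path (the repeated "Refactor matmul implementation for splitk layout" items and the new MatmulConfigWithSplitK / MatmulWithSplitK classes). For orientation only, here is a minimal NumPy sketch of the general split-K idea; it is not BitBLAS code, and the names splitk_matmul and k_split are illustrative. The K (reduction) dimension is partitioned into chunks, each chunk produces an independent partial product (e.g., computed by a separate thread block), and the partials are reduced at the end.

import numpy as np

def splitk_matmul(A: np.ndarray, B: np.ndarray, k_split: int) -> np.ndarray:
    # Partition the reduction dimension K into `k_split` chunks, compute one
    # partial GEMM per chunk, then reduce the partials into the final result.
    M, K = A.shape
    K2, N = B.shape
    assert K == K2 and K % k_split == 0
    chunk = K // k_split
    partials = [
        A[:, i * chunk:(i + 1) * chunk] @ B[i * chunk:(i + 1) * chunk, :]
        for i in range(k_split)
    ]
    return np.sum(partials, axis=0)  # final reduction over the split-K partials

A = np.random.randn(16, 256).astype(np.float32)
B = np.random.randn(256, 64).astype(np.float32)
assert np.allclose(splitk_matmul(A, B, k_split=4), A @ B, atol=1e-3)
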
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
-0.0.1.dev10
+0.0.1.dev12
3 changes: 2 additions & 1 deletion python/bitblas/__init__.py
@@ -33,6 +33,7 @@
 from . import testing # noqa: F401
 from .utils import auto_detect_nvidia_target # noqa: F401
 from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401
+from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401
 from .ops.matmul_dequantize import MatmulWeightOnlyDequantizeConfig, MatmulWeightOnlyDequantize # noqa: F401
 from .module import Linear # noqa: F401
 
@@ -81,4 +82,4 @@ def _init_logger():
 
 _init_logger()
 
-__version__ = "0.0.1.dev10"
+__version__ = "0.0.1.dev12"
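
With the new exports, the split-K operator is reachable directly from the bitblas package. The sketch below is a hedged usage example, not taken from this commit: it assumes MatmulConfigWithSplitK accepts the same shape/dtype fields as MatmulConfig plus a split count, and the parameter name k_split is an assumption; only the forward(..., output=...) call pattern is confirmed by the test file further down.

import torch
from bitblas import MatmulConfigWithSplitK, MatmulWithSplitK

# Assumption: MatmulConfigWithSplitK takes MatmulConfig's usual fields plus a
# split count; `k_split` is an assumed parameter name, not confirmed by this diff.
config = MatmulConfigWithSplitK(
    k_split=4,               # number of partitions along K (assumed name)
    M=16,
    N=1024,
    K=4096,
    A_dtype="float16",
    W_dtype="float16",
    accum_dtype="float32",
    out_dtype="float16",
    layout="nt",
)
matmul = MatmulWithSplitK(config)

A = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
W = torch.randn(1024, 4096, dtype=torch.float16, device="cuda")
out = torch.empty(16, 1024, dtype=torch.float16, device="cuda")
matmul.forward(A, W, output=out)  # same call pattern as in the updated test
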
15 changes: 12 additions & 3 deletions python/bitblas/quantization/quantization.py
@@ -139,14 +139,23 @@ def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype
     return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
 
 
-def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
+def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str):
     assert nbit == 8
     assert dtype == "float16"
     s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
-    prefix = tir.Select(s_f16 == 0, tir.const(0x2000, "uint16"), tir.const(0xc000, "uint16"))
-    e_f16 = (((val & tir.const(127, "uint16")) << tir.const(7, "uint16"))) | prefix
+    e4 = val & tir.const(0x40, "uint16")
+    prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"), tir.const(0x4000, "uint16"))
+    e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | prefix
     return tir.reinterpret("float16", s_f16 | e_f16)
 
+def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
+    assert nbit == 8
+    assert dtype == "float16"
+    s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
+    e4 = val & tir.const(0x40, "uint16")
+    e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16"))
+    e_f16 = e_f16 ^ tir.const(0x2000, "uint16")
+    return tir.reinterpret("float16", s_f16 | e_f16)
 
 def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
     assert nbit == 8
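
The fix is visible in the diff above: the old kernel chose the exponent prefix from the sign word (s_f16 == 0) and masked with 127, so the top exponent bit of the E4M3 byte leaked into the shifted field; the corrected versions key the prefix off bit 6 (0x40), the MSB of the 4-bit exponent, which implements the +8 re-bias from E4M3's exponent bias of 7 to FP16's bias of 15. Below is a small standalone Python check, not part of the commit, that mirrors the fast-path bit trick on plain integers and compares it with a direct E4M3 decode over all normal encodings; the helper names are illustrative.

import struct

def f8_e4m3_to_f16_bits(val: int) -> int:
    # Mirror of the fast-path bit trick in _tir_u8_to_f8_e4m3_to_f16, on a uint8 value.
    s_f16 = (val >> 7) << 15                    # sign -> FP16 bit 15
    e4 = val & 0x40                             # MSB of the 4-bit E4M3 exponent
    e_f16 = ((val & 63) << 7) | (e4 << 8) | (e4 << 7)
    e_f16 ^= 0x2000                             # net effect: exponent re-biased by +8 (7 -> 15)
    return s_f16 | e_f16

def f8_e4m3_reference(val: int) -> float:
    # Direct E4M3 decode for normal values: (-1)^s * (1 + m/8) * 2^(e - 7).
    sign = -1.0 if val & 0x80 else 1.0
    exp = (val >> 3) & 0xF
    man = val & 0x7
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

def f16_bits_to_float(bits: int) -> float:
    return struct.unpack("<e", struct.pack("<H", bits))[0]

# Check every normal E4M3 encoding (skip exponent 0 and the NaN pattern 0x7F/0xFF).
for v in range(256):
    if (v >> 3) & 0xF == 0 or (v & 0x7F) == 0x7F:
        continue
    assert f16_bits_to_float(f8_e4m3_to_f16_bits(v)) == f8_e4m3_reference(v), hex(v)
print("fast E4M3 -> FP16 decode matches the reference on all normal values")
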
4 changes: 4 additions & 0 deletions testing/python/operators/test_general_matmul_splitk_ops.py
@@ -143,6 +143,10 @@ def map_torch_type(intype):
     matmul.forward(torch_a, torch_b, output=bitblas_out)
     print("torch_ref_out", ref_out)
     print("bitblas_out", bitblas_out)
 
+    matmul.forward(torch_a, torch_b, output=bitblas_out)
+    print("torch_ref_out", ref_out)
+    print("bitblas_out", bitblas_out)
+
     torch.testing.assert_close(bitblas_out, ref_out, rtol=1e0, atol=1e-1)
 
