[BUGFix] Fix UINT/INT8 dequantize implementation and optimize the schedule template for float32 accum (#46)

* improve e4m3 decoding.

* append fp16xint1

* Update submodule commit reference

* chore: Update shared memory scope for float32 output dtype

* BUGFIX: UINT8/INT8 Decoding

---------

Co-authored-by: LeiWang199 <leiwang199>
LeiWang1999 committed Jun 2, 2024
1 parent efab450 commit c4400b3
Showing 6 changed files with 91 additions and 21 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
4 changes: 3 additions & 1 deletion python/bitblas/base/roller/hint.py
@@ -6,6 +6,7 @@
import numpy as np
from .rasterization import *


class TensorCoreExtraConfig:
"""
This class is used to store extra information for tensorcore
@@ -228,7 +229,8 @@ def __repr__(self) -> str:
def complete_config(self, node: PrimFuncNode):
# analysis pass context, for int8 mma, we should merge static shared memory
merge_static_smem = False
if self.use_tc and self.intrin_info.in_dtype == "int8":
# int32 and float32 accum may take too much shared memory
if self.use_tc and self.intrin_info.out_dtype in ["float32", "int32"]:
merge_static_smem = True
self.pass_context = {"tir.merge_static_smem": merge_static_smem}
return self
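Note: a back-of-the-envelope illustration of why the accumulator dtype drives this choice (the tile sizes below are assumptions for illustration, not values from this commit):

def output_tile_smem_bytes(tile_m: int, tile_n: int, accum_bytes: int) -> int:
    # Shared memory needed to stage one output tile of the accumulator.
    return tile_m * tile_n * accum_bytes

print(output_tile_smem_bytes(128, 128, 2))  # float16 accum: 32768 B (32 KB)
print(output_tile_smem_bytes(128, 128, 4))  # float32/int32 accum: 65536 B (64 KB)
# 64 KB alone exceeds the 48 KB static shared-memory cap on many NVIDIA GPUs,
# hence tir.merge_static_smem here and the shared.dyn scope in the next file.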
2 changes: 2 additions & 0 deletions python/bitblas/base/roller/policy/tensorcore.py
@@ -297,6 +297,8 @@ def _score(node, thread): # small is better
intrin_info = node.get_tag("intrin_info")
if intrin_info:
codegen_dict.intrin_info = IntrinInfo(**intrin_info)
if intrin_info["out_dtype"] in ["float32"]:
codegen_dict.shared_scope = "shared.dyn"
# smem capacity
if td.smem_cost > self.arch.smem_cap:
codegen_dict.shared_scope = "shared.dyn"
39 changes: 36 additions & 3 deletions python/bitblas/gpu/intrin/lop3.py
@@ -1483,7 +1483,11 @@ def fast_decode_impl(
TensorIntrin.register(
LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN,
*get_fast_decode_intrin(
source_bit=2, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
source_bit=2,
source_format="int",
storage_dtype="int8",
target_dtype="int8",
loops_extent=16),
)

LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_")
@@ -1497,10 +1501,13 @@ def fast_decode_impl(
TensorIntrin.register(
LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN,
*get_fast_decode_intrin(
source_bit=1, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
source_bit=1,
source_format="int",
storage_dtype="int8",
target_dtype="int8",
loops_extent=16),
)


LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i4_to_int8_to_f16_l8_")
TensorIntrin.register(
LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN,
@@ -1553,6 +1560,32 @@ def fast_decode_impl(
),
)

LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i1_to_int8_to_f16_l8_")
TensorIntrin.register(
LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN,
*get_fast_decode_intrin(
source_bit=1,
storage_dtype="int8",
source_format="int",
target_dtype="float16",
loops_extent=8,
),
)

LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = (
"lop3_fast_decode_i1_to_int8_to_f16_l8_scale_")
TensorIntrin.register(
LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN,
*get_fast_decode_intrin(
source_bit=1,
storage_dtype="int8",
source_format="int",
target_dtype="float16",
loops_extent=8,
with_scale=True,
),
)


def get_lop3_intrin_group(
out_dtype: Literal["float16", "int8"],
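Note: a registered name like the ones above is later consumed as a lookup key when a schedule tensorizes the decode loop. A minimal sketch, assuming TVM's standard TensorIntrin registry API (the decode_loop variable is hypothetical):

from tvm import tir

name = "lop3_fast_decode_i1_to_int8_to_f16_l8_"  # bound by TensorIntrin.register above
intrin = tir.TensorIntrin.get(name)
print(intrin.desc)  # TIR pattern the scheduler matches against
print(intrin.impl)  # LOP3-based body substituted into the kernel
# Inside a schedule, a matched decode loop is then replaced via:
# sch.tensorize(decode_loop, name)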
13 changes: 11 additions & 2 deletions python/bitblas/ops/general_matmul.py
@@ -148,9 +148,18 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
object.__setattr__(self, "zeros_mode", "original")

def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):

def is_not_fast_decoding_supported():
conditions = []
conditions.append("int" not in self.W_dtype)
conditions.append(self.W_dtype == self.A_dtype)
# int8/uint8 have no fast-decoding implementation and also do not require one
conditions.append(self.W_dtype in ["int8", "uint8"])
return any(conditions)

if fast_decoding is not None:
object.__setattr__(self, "fast_decoding", fast_decoding)
elif ("int" not in self.W_dtype or self.W_dtype == self.A_dtype):
elif is_not_fast_decoding_supported():
object.__setattr__(self, "fast_decoding", False)
else:
object.__setattr__(self, "fast_decoding", True)
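Note: the refactored predicate restated standalone, with the combination this commit fixes spelled out (dtype strings follow BitBLAS conventions; the assertions are illustrative):

def default_fast_decoding(W_dtype: str, A_dtype: str) -> bool:
    # Mirror of is_not_fast_decoding_supported() above, inverted.
    not_supported = (
        "int" not in W_dtype             # non-integer weight sources
        or W_dtype == A_dtype            # same dtype: nothing to decode
        or W_dtype in ("int8", "uint8")  # the new case: used directly
    )
    return not not_supported

assert default_fast_decoding("uint4", "float16")        # packed weights: decode fast
assert not default_fast_decoding("int8", "float16")     # fixed by this commit
assert not default_fast_decoding("float16", "float16")  # plain GEMM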
@@ -450,7 +459,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
source_format, bit = self.source_format, self.bit

# Process integer source format
if source_format == "int":
if source_format == "int" and bit < 8:
assert not self.with_scaling, "scale should be False for int source format"
assert not self.with_zeros, "zeros should be False for int source format"
maxq = 2**(bit - 1)
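Note: the body below the guard is elided above; as a worked illustration of the usual signed-to-offset packing that maxq enables (an assumption about the elided code, not a quote of it):

bit = 4
maxq = 2 ** (bit - 1)                              # 8
signed_values = range(-maxq, maxq)                 # int4 range: -8 .. 7
offset_values = [w + maxq for w in signed_values]  # 0 .. 15, fits an unsigned nibble
assert min(offset_values) == 0 and max(offset_values) == 2 ** bit - 1
# With bit == 8 a weight already fills a whole int8 storage element, which is
# why the guard is narrowed to bit < 8 and 8-bit weights skip this path.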
52 changes: 38 additions & 14 deletions python/bitblas/ops/impl/matmul_dequantize_impl.py
@@ -33,6 +33,7 @@ def matmul_nt_dequantize_b(
with_bias=False,
zeros_mode="original",
):
assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit)
if not isinstance(M, int):
M = tvm.te.var("m")

@@ -78,13 +79,20 @@ def decode_func(n, k):
dtype=in_dtype,
)
elif source_format == "uint":
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
if bit == 8:
# 8 bit does not need to be compressed
w = B[n, k].astype(in_dtype)
else:
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif source_format == "int":
if bit == 1:
# Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif bit == 8:
# 8 bit does not need to be compressed
w = B[n, k].astype(in_dtype)
else:
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
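Note: the bit == 8 special case falls out of the packing arithmetic. A quick check, assuming int8 storage words (storage_nbit == 8) and n_float_per_elem = storage_nbit // bit:

storage_nbit = 8
for bit in (1, 2, 4, 8):
    n_float_per_elem = storage_nbit // bit
    print(f"bit={bit}: {n_float_per_elem} weight(s) per storage element")
# bit=8 yields exactly one weight per element, so B[n, k // 1] with bit offset
# k % 1 == 0 degenerates to a plain cast of B[n, k]: no extraction needed.
# For source_format == "int" with bit == 1, the packed bit b in {0, 1} is
# instead mapped to {-1, +1}, so int1 does not collide with uint1.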
@@ -187,6 +195,7 @@ def matmul_nt_dequantize_b_propagate_b(
zeros_mode="original",
transform_kind: TransformKind = TransformKind.IntraWarpTransform,
):
assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit)
if not isinstance(M, int):
M = tvm.te.var("m")

@@ -241,17 +250,24 @@ def fcompute(i, j):

def decode_func(n, k):
if source_format == "uint":
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
if bit == 8:
# 8 bit does not need to be compressed
w = B_reindex[n, k].astype(in_dtype)
else:
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
elif source_format == "int":
if bit == 1:
# Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif bit == 8:
# 8 bit does not need to be compressed
w = B_reindex[n, k].astype(in_dtype)
else:
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit,
@@ -360,6 +376,7 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b(
transform_kind_input: TransformKind = TransformKind.IntraWarpTransform,
transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform,
):
assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit)
if not isinstance(M, int):
M = tvm.te.var("m")

@@ -429,17 +446,24 @@ def fcompute(i, j):

def decode_func(n, k):
if source_format == "uint":
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
if bit == 8:
# 8 bit does not need to be compressed
w = B_reindex[n, k].astype(in_dtype)
else:
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
elif source_format == "int":
# Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
if bit == 1:
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif bit == 8:
# 8 bit does not need to be compressed
w = B_reindex[n, k].astype(in_dtype)
else:
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit,
