Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUGFix] Fix UINT/INT8 dequantize implementation and optimize the schedule template for float32 accum #46

Merged
merged 6 commits into from
Jun 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
4 changes: 3 additions & 1 deletion python/bitblas/base/roller/hint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from .rasterization import *


class TensorCoreExtraConfig:
"""
This class is used to store extra information for tensorcore
Expand Down Expand Up @@ -228,7 +229,8 @@ def __repr__(self) -> str:
def complete_config(self, node: PrimFuncNode):
# analysis pass context, for int8 mma, we should merge static shared memory
merge_static_smem = False
if self.use_tc and self.intrin_info.in_dtype == "int8":
# int32 and float32 accum may take too much shared memory
if self.use_tc and self.intrin_info.out_dtype in ["float32", "int32"]:
merge_static_smem = True
self.pass_context = {"tir.merge_static_smem": merge_static_smem}
return self
2 changes: 2 additions & 0 deletions python/bitblas/base/roller/policy/tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ def _score(node, thread): # small is better
intrin_info = node.get_tag("intrin_info")
if intrin_info:
codegen_dict.intrin_info = IntrinInfo(**intrin_info)
if intrin_info["out_dtype"] in ["float32"]:
codegen_dict.shared_scope = "shared.dyn"
# smem capacity
if td.smem_cost > self.arch.smem_cap:
codegen_dict.shared_scope = "shared.dyn"
Expand Down
39 changes: 36 additions & 3 deletions python/bitblas/gpu/intrin/lop3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,7 +1483,11 @@ def fast_decode_impl(
TensorIntrin.register(
LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN,
*get_fast_decode_intrin(
source_bit=2, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
source_bit=2,
source_format="int",
storage_dtype="int8",
target_dtype="int8",
loops_extent=16),
)

LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_")
Expand All @@ -1497,10 +1501,13 @@ def fast_decode_impl(
TensorIntrin.register(
LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN,
*get_fast_decode_intrin(
source_bit=1, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
source_bit=1,
source_format="int",
storage_dtype="int8",
target_dtype="int8",
loops_extent=16),
)


LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i4_to_int8_to_f16_l8_")
TensorIntrin.register(
LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN,
Expand Down Expand Up @@ -1553,6 +1560,32 @@ def fast_decode_impl(
),
)

LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i1_to_int8_to_f16_l8_")
TensorIntrin.register(
LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN,
*get_fast_decode_intrin(
source_bit=1,
storage_dtype="int8",
source_format="int",
target_dtype="float16",
loops_extent=8,
),
)

LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = (
"lop3_fast_decode_i1_to_int8_to_f16_l8_scale_")
TensorIntrin.register(
LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN,
*get_fast_decode_intrin(
source_bit=1,
storage_dtype="int8",
source_format="int",
target_dtype="float16",
loops_extent=8,
with_scale=True,
),
)


def get_lop3_intrin_group(
out_dtype: Literal["float16", "int8"],
Expand Down
13 changes: 11 additions & 2 deletions python/bitblas/ops/general_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,18 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
object.__setattr__(self, "zeros_mode", "original")

def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):

def is_not_fast_decoding_supported():
conditions = []
conditions.append("int" not in self.W_dtype)
conditions.append(self.W_dtype == self.A_dtype)
# int8/uint8 are not implemented here and also do not require fast decoding
conditions.append(self.W_dtype in ["int8", "uint8"])
return any(conditions)

if fast_decoding is not None:
object.__setattr__(self, "fast_decoding", fast_decoding)
elif ("int" not in self.W_dtype or self.W_dtype == self.A_dtype):
elif is_not_fast_decoding_supported():
object.__setattr__(self, "fast_decoding", False)
else:
object.__setattr__(self, "fast_decoding", True)
Expand Down Expand Up @@ -450,7 +459,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
source_format, bit = self.source_format, self.bit

# Process integer source format
if source_format == "int":
if source_format == "int" and bit < 8:
assert not self.with_scaling, "scale should be False for int source format"
assert not self.with_zeros, "zeros should be False for int source format"
maxq = 2**(bit - 1)
Expand Down
52 changes: 38 additions & 14 deletions python/bitblas/ops/impl/matmul_dequantize_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def matmul_nt_dequantize_b(
with_bias=False,
zeros_mode="original",
):
assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit)
if not isinstance(M, int):
M = tvm.te.var("m")

Expand Down Expand Up @@ -78,13 +79,20 @@ def decode_func(n, k):
dtype=in_dtype,
)
elif source_format == "uint":
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
if bit == 8:
# 8 bit does not need to be compressed
w = B[n, k].astype(in_dtype)
else:
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif source_format == "int":
if bit == 1:
# Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif bit == 8:
# 8 bit does not need to be compressed
w = B[n, k].astype(in_dtype)
else:
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
Expand Down Expand Up @@ -187,6 +195,7 @@ def matmul_nt_dequantize_b_propagate_b(
zeros_mode="original",
transform_kind: TransformKind = TransformKind.IntraWarpTransform,
):
assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit)
if not isinstance(M, int):
M = tvm.te.var("m")

Expand Down Expand Up @@ -241,17 +250,24 @@ def fcompute(i, j):

def decode_func(n, k):
if source_format == "uint":
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
if bit == 8:
# 8 bit does not need to be compressed
w = B_reindex[n, k].astype(in_dtype)
else:
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
elif source_format == "int":
if bit == 1:
# Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif bit == 8:
# 8 bit does not need to be compressed
w = B_reindex[n, k].astype(in_dtype)
else:
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit,
Expand Down Expand Up @@ -360,6 +376,7 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b(
transform_kind_input: TransformKind = TransformKind.IntraWarpTransform,
transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform,
):
assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit)
if not isinstance(M, int):
M = tvm.te.var("m")

Expand Down Expand Up @@ -429,17 +446,24 @@ def fcompute(i, j):

def decode_func(n, k):
if source_format == "uint":
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
if bit == 8:
# 8 bit does not need to be compressed
w = B_reindex[n, k].astype(in_dtype)
else:
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
elif source_format == "int":
# Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
if bit == 1:
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif bit == 8:
# 8 bit does not need to be compressed
w = B_reindex[n, k].astype(in_dtype)
else:
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit,
Expand Down
Loading