From c4400b33eb000d874978ba70aca46cb04fd1afd2 Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Sun, 2 Jun 2024 13:47:33 +0800
Subject: [PATCH] [BUGFix] Fix UINT/INT8 dequantize implementation and
 optimize the schedule template for float32 accum (#46)

* improve e4m3 decoding.

* append fp16xint1

* Update submodule commit reference

* chore: Update shared memory scope for float32 output dtype

* BUGFIX: UINT8/INT8 Decoding

---------

Co-authored-by: LeiWang199
---
 3rdparty/tvm                                  |  2 +-
 python/bitblas/base/roller/hint.py            |  4 +-
 .../bitblas/base/roller/policy/tensorcore.py  |  2 +
 python/bitblas/gpu/intrin/lop3.py             | 39 ++++++++++++--
 python/bitblas/ops/general_matmul.py          | 13 ++++-
 .../ops/impl/matmul_dequantize_impl.py        | 52 ++++++++++++++-----
 6 files changed, 91 insertions(+), 21 deletions(-)

diff --git a/3rdparty/tvm b/3rdparty/tvm
index 0290a887..618306ce 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 0290a887df4a0f16284e413c26a533f2ee101fb5
+Subproject commit 618306ce3baa2c606d43856afbe6655e4e67b2c8
diff --git a/python/bitblas/base/roller/hint.py b/python/bitblas/base/roller/hint.py
index c5fcda36..1d3270b4 100644
--- a/python/bitblas/base/roller/hint.py
+++ b/python/bitblas/base/roller/hint.py
@@ -6,6 +6,7 @@
 import numpy as np
 from .rasterization import *
 
+
 class TensorCoreExtraConfig:
     """
     This class is used to store extra information for tensorcore
@@ -228,7 +229,8 @@ def __repr__(self) -> str:
     def complete_config(self, node: PrimFuncNode):
         # analysis pass context, for int8 mma, we should merge static shared memory
         merge_static_smem = False
-        if self.use_tc and self.intrin_info.in_dtype == "int8":
+        # int32 and float32 accum may take too much shared memory
+        if self.use_tc and self.intrin_info.out_dtype in ["float32", "int32"]:
             merge_static_smem = True
         self.pass_context = {"tir.merge_static_smem": merge_static_smem}
         return self
diff --git a/python/bitblas/base/roller/policy/tensorcore.py b/python/bitblas/base/roller/policy/tensorcore.py
index f52a1b80..653a8809 100644
--- a/python/bitblas/base/roller/policy/tensorcore.py
+++ b/python/bitblas/base/roller/policy/tensorcore.py
@@ -297,6 +297,8 @@ def _score(node, thread):  # small is better
         intrin_info = node.get_tag("intrin_info")
         if intrin_info:
             codegen_dict.intrin_info = IntrinInfo(**intrin_info)
+            if intrin_info["out_dtype"] in ["float32"]:
+                codegen_dict.shared_scope = "shared.dyn"
         # smem capacity
         if td.smem_cost > self.arch.smem_cap:
             codegen_dict.shared_scope = "shared.dyn"
diff --git a/python/bitblas/gpu/intrin/lop3.py b/python/bitblas/gpu/intrin/lop3.py
index 7ea0f93f..b5426cf5 100644
--- a/python/bitblas/gpu/intrin/lop3.py
+++ b/python/bitblas/gpu/intrin/lop3.py
@@ -1483,7 +1483,11 @@ def fast_decode_impl(
 TensorIntrin.register(
     LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN,
     *get_fast_decode_intrin(
-        source_bit=2, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
+        source_bit=2,
+        source_format="int",
+        storage_dtype="int8",
+        target_dtype="int8",
+        loops_extent=16),
 )
 
 LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_")
@@ -1497,10 +1501,13 @@ def fast_decode_impl(
 TensorIntrin.register(
     LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN,
     *get_fast_decode_intrin(
-        source_bit=1, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
+        source_bit=1,
+        source_format="int",
+        storage_dtype="int8",
+        target_dtype="int8",
+        loops_extent=16),
 )
-
 
 LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i4_to_int8_to_f16_l8_")
 TensorIntrin.register(
     LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN,
@@ -1553,6 +1560,32 @@ def fast_decode_impl(
     ),
 )
 
+LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i1_to_int8_to_f16_l8_")
+TensorIntrin.register(
+    LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN,
+    *get_fast_decode_intrin(
+        source_bit=1,
+        storage_dtype="int8",
+        source_format="int",
+        target_dtype="float16",
+        loops_extent=8,
+    ),
+)
+
+LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = (
+    "lop3_fast_decode_i1_to_int8_to_f16_l8_scale_")
+TensorIntrin.register(
+    LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN,
+    *get_fast_decode_intrin(
+        source_bit=1,
+        storage_dtype="int8",
+        source_format="int",
+        target_dtype="float16",
+        loops_extent=8,
+        with_scale=True,
+    ),
+)
+
 
 def get_lop3_intrin_group(
     out_dtype: Literal["float16", "int8"],
diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py
index 35eee1fb..e5a23f7f 100644
--- a/python/bitblas/ops/general_matmul.py
+++ b/python/bitblas/ops/general_matmul.py
@@ -148,9 +148,18 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
             object.__setattr__(self, "zeros_mode", "original")
 
     def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):
+
+        def is_not_fast_decoding_supported():
+            conditions = []
+            conditions.append("int" not in self.W_dtype)
+            conditions.append(self.W_dtype == self.A_dtype)
+            # int8/uint8 neither implement nor require fast decoding
+            conditions.append(self.W_dtype in ["int8", "uint8"])
+            return any(conditions)
+
         if fast_decoding is not None:
             object.__setattr__(self, "fast_decoding", fast_decoding)
-        elif ("int" not in self.W_dtype or self.W_dtype == self.A_dtype):
+        elif is_not_fast_decoding_supported():
            object.__setattr__(self, "fast_decoding", False)
         else:
             object.__setattr__(self, "fast_decoding", True)
@@ -450,7 +459,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
         source_format, bit = self.source_format, self.bit
 
         # Process integer source format
-        if source_format == "int":
+        if source_format == "int" and bit < 8:
             assert not self.with_scaling, "scale should be False for int source format"
             assert not self.with_zeros, "zeros should be False for int source format"
             maxq = 2**(bit - 1)
diff --git a/python/bitblas/ops/impl/matmul_dequantize_impl.py b/python/bitblas/ops/impl/matmul_dequantize_impl.py
index 6e6b098c..d4aa02c8 100644
--- a/python/bitblas/ops/impl/matmul_dequantize_impl.py
+++ b/python/bitblas/ops/impl/matmul_dequantize_impl.py
@@ -33,6 +33,7 @@ def matmul_nt_dequantize_b(
     with_bias=False,
     zeros_mode="original",
 ):
+    assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit)
     if not isinstance(M, int):
         M = tvm.te.var("m")
 
@@ -78,13 +79,20 @@ def decode_func(n, k):
                 dtype=in_dtype,
             )
         elif source_format == "uint":
-            w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
-                bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
+            if bit == 8:
+                # 8 bit does not need to be compressed
+                w = B[n, k].astype(in_dtype)
+            else:
+                w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
+                    bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
         elif source_format == "int":
             if bit == 1:
                 # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
                 w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
                     bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
+            elif bit == 8:
+                # 8 bit does not need to be compressed
+                w = B[n, k].astype(in_dtype)
             else:
                 w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
                     bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
@@ -187,6 +195,7 @@ def matmul_nt_dequantize_b_propagate_b(
     zeros_mode="original",
     transform_kind: TransformKind = TransformKind.IntraWarpTransform,
 ):
+    assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit)
     if not isinstance(M, int):
         M = tvm.te.var("m")
 
@@ -241,17 +250,24 @@ def fcompute(i, j):
 
     def decode_func(n, k):
         if source_format == "uint":
-            w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
-                bit,
-                B_reindex[n, k // n_float_per_elem],
-                k % n_float_per_elem,
-                dtype=in_dtype,
-            )
+            if bit == 8:
+                # 8 bit does not need to be compressed
+                w = B_reindex[n, k].astype(in_dtype)
+            else:
+                w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
+                    bit,
+                    B_reindex[n, k // n_float_per_elem],
+                    k % n_float_per_elem,
+                    dtype=in_dtype,
+                )
         elif source_format == "int":
             if bit == 1:
                 # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
                 w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
                     bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
+            elif bit == 8:
+                # 8 bit does not need to be compressed
+                w = B_reindex[n, k].astype(in_dtype)
             else:
                 w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
                     bit,
@@ -360,6 +376,7 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b(
     transform_kind_input: TransformKind = TransformKind.IntraWarpTransform,
     transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform,
 ):
+    assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit)
     if not isinstance(M, int):
         M = tvm.te.var("m")
 
@@ -429,17 +446,24 @@ def fcompute(i, j):
 
     def decode_func(n, k):
         if source_format == "uint":
-            w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
-                bit,
-                B_reindex[n, k // n_float_per_elem],
-                k % n_float_per_elem,
-                dtype=in_dtype,
-            )
+            if bit == 8:
+                # 8 bit does not need to be compressed
+                w = B_reindex[n, k].astype(in_dtype)
+            else:
+                w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
+                    bit,
+                    B_reindex[n, k // n_float_per_elem],
+                    k % n_float_per_elem,
+                    dtype=in_dtype,
+                )
         elif source_format == "int":
             # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
             if bit == 1:
                 w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
                     bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
+            elif bit == 8:
+                # 8 bit does not need to be compressed
+                w = B_reindex[n, k].astype(in_dtype)
             else:
                 w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
                     bit,