From 75d2f3d5c2713064944fea9f6f0372d966b0e1d1 Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Tue, 21 May 2024 11:51:02 +0000 Subject: [PATCH 01/17] improve e4m3 decoding. --- python/bitblas/quantization/quantization.py | 6 +++--- testing/python/operators/test_general_matmul_fp8.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/bitblas/quantization/quantization.py b/python/bitblas/quantization/quantization.py index d9f360947..d68d437d4 100644 --- a/python/bitblas/quantization/quantization.py +++ b/python/bitblas/quantization/quantization.py @@ -142,9 +142,9 @@ def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): assert nbit == 8 assert dtype == "float16" - s_f16 = (val >> tir.const(7, "int16")) << tir.const(15, "int16") - offset = tir.Select(s_f16 == 0, tir.const(8192, "int16"), tir.const(-8192, "int16")) - e_f16 = ((val << tir.const(7, "int16")) + offset) + s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") + prefix = tir.Select(s_f16 == 0, tir.const(0x2000, "uint16"), tir.const(0xc000, "uint16")) + e_f16 = (((val & tir.const(127, "uint16")) << tir.const(7, "uint16"))) | prefix return tir.reinterpret("float16", s_f16 | e_f16) diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py index 3d0a7be2f..5b7de9ab0 100644 --- a/testing/python/operators/test_general_matmul_fp8.py +++ b/testing/python/operators/test_general_matmul_fp8.py @@ -166,7 +166,7 @@ def map_torch_type(intype): print("torch_ref_out", ref_out) print("bitblas_out", bitblas_out) - torch.testing.assert_allclose(ref_out, bitblas_out, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(ref_out, bitblas_out, rtol=1e-1, atol=1e-1) # fmt: on From 00bfa319caa321c0e946d5f87d374982beea263e Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Sat, 25 May 2024 08:26:02 +0000 Subject: [PATCH 02/17] append fp16xint1 --- python/bitblas/gpu/intrin/lop3.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/python/bitblas/gpu/intrin/lop3.py b/python/bitblas/gpu/intrin/lop3.py index 7ea0f93f4..adff76ab9 100644 --- a/python/bitblas/gpu/intrin/lop3.py +++ b/python/bitblas/gpu/intrin/lop3.py @@ -1553,6 +1553,32 @@ def fast_decode_impl( ), ) +LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i1_to_int8_to_f16_l8_") +TensorIntrin.register( + LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=1, + storage_dtype="int8", + source_format="int", + target_dtype="float16", + loops_extent=8, + ), +) + +LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_i1_to_int8_to_f16_l8_scale_") +TensorIntrin.register( + LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=1, + storage_dtype="int8", + source_format="int", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + def get_lop3_intrin_group( out_dtype: Literal["float16", "int8"], From 8cd8b10c18e3f484a3273f6b6bda498bb022c783 Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Sat, 1 Jun 2024 14:24:57 +0000 Subject: [PATCH 03/17] Update submodule commit reference --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 0290a887d..618306ce3 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 0290a887df4a0f16284e413c26a533f2ee101fb5 +Subproject commit 
618306ce3baa2c606d43856afbe6655e4e67b2c8 From 9122ff7ca0e07e53b120940dae0fee434a6b0e08 Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Sat, 1 Jun 2024 14:25:33 +0000 Subject: [PATCH 04/17] chore: Update shared memory scope for float32 output dtype --- python/bitblas/base/roller/hint.py | 3 ++- python/bitblas/base/roller/policy/tensorcore.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/bitblas/base/roller/hint.py b/python/bitblas/base/roller/hint.py index c5fcda366..bcbf0b2af 100644 --- a/python/bitblas/base/roller/hint.py +++ b/python/bitblas/base/roller/hint.py @@ -228,7 +228,8 @@ def __repr__(self) -> str: def complete_config(self, node: PrimFuncNode): # analysis pass context, for int8 mma, we should merge static shared memory merge_static_smem = False - if self.use_tc and self.intrin_info.in_dtype == "int8": + # int32 and float32 accum may take too much shared memory + if self.use_tc and self.intrin_info.out_dtype in ["float32", "int32"]: merge_static_smem = True self.pass_context = {"tir.merge_static_smem": merge_static_smem} return self diff --git a/python/bitblas/base/roller/policy/tensorcore.py b/python/bitblas/base/roller/policy/tensorcore.py index f52a1b80b..653a8809c 100644 --- a/python/bitblas/base/roller/policy/tensorcore.py +++ b/python/bitblas/base/roller/policy/tensorcore.py @@ -297,6 +297,8 @@ def _score(node, thread): # small is better intrin_info = node.get_tag("intrin_info") if intrin_info: codegen_dict.intrin_info = IntrinInfo(**intrin_info) + if intrin_info["out_dtype"] in ["float32"]: + codegen_dict.shared_scope = "shared.dyn" # smem capacity if td.smem_cost > self.arch.smem_cap: codegen_dict.shared_scope = "shared.dyn" From b508acc60fa650bac5eb4be2a9277a2c9bd8b08f Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Sun, 2 Jun 2024 05:43:56 +0000 Subject: [PATCH 05/17] BUGFIX: UINT8/INT8 Decoding --- python/bitblas/base/roller/hint.py | 1 + python/bitblas/gpu/intrin/lop3.py | 13 +++-- python/bitblas/ops/general_matmul.py | 13 ++++- .../ops/impl/matmul_dequantize_impl.py | 52 ++++++++++++++----- 4 files changed, 60 insertions(+), 19 deletions(-) diff --git a/python/bitblas/base/roller/hint.py b/python/bitblas/base/roller/hint.py index bcbf0b2af..1d3270b4d 100644 --- a/python/bitblas/base/roller/hint.py +++ b/python/bitblas/base/roller/hint.py @@ -6,6 +6,7 @@ import numpy as np from .rasterization import * + class TensorCoreExtraConfig: """ This class is used to store extra information for tensorcore diff --git a/python/bitblas/gpu/intrin/lop3.py b/python/bitblas/gpu/intrin/lop3.py index adff76ab9..b5426cf59 100644 --- a/python/bitblas/gpu/intrin/lop3.py +++ b/python/bitblas/gpu/intrin/lop3.py @@ -1483,7 +1483,11 @@ def fast_decode_impl( TensorIntrin.register( LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN, *get_fast_decode_intrin( - source_bit=2, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16), + source_bit=2, + source_format="int", + storage_dtype="int8", + target_dtype="int8", + loops_extent=16), ) LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_") @@ -1497,10 +1501,13 @@ def fast_decode_impl( TensorIntrin.register( LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN, *get_fast_decode_intrin( - source_bit=1, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16), + source_bit=1, + source_format="int", + storage_dtype="int8", + target_dtype="int8", + loops_extent=16), ) - LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = 
("lop3_fast_decode_i4_to_int8_to_f16_l8_") TensorIntrin.register( LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN, diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py index 35eee1fb6..e5a23f7f0 100644 --- a/python/bitblas/ops/general_matmul.py +++ b/python/bitblas/ops/general_matmul.py @@ -148,9 +148,18 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]): object.__setattr__(self, "zeros_mode", "original") def __initialize_fast_decoding(self, fast_decoding: Optional[bool]): + + def is_not_fast_decoding_supported(): + conditions = [] + conditions.append("int" not in self.W_dtype) + conditions.append(self.W_dtype == self.A_dtype) + # int8,uint8 also do not implement and also do not require fast decoding + conditions.append(self.W_dtype in ["int8", "uint8"]) + return any(conditions) + if fast_decoding is not None: object.__setattr__(self, "fast_decoding", fast_decoding) - elif ("int" not in self.W_dtype or self.W_dtype == self.A_dtype): + elif is_not_fast_decoding_supported(): object.__setattr__(self, "fast_decoding", False) else: object.__setattr__(self, "fast_decoding", True) @@ -450,7 +459,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None): source_format, bit = self.source_format, self.bit # Process integer source format - if source_format == "int": + if source_format == "int" and bit < 8: assert not self.with_scaling, "scale should be False for int source format" assert not self.with_zeros, "zeros should be False for int source format" maxq = 2**(bit - 1) diff --git a/python/bitblas/ops/impl/matmul_dequantize_impl.py b/python/bitblas/ops/impl/matmul_dequantize_impl.py index 6e6b098c1..d4aa02c84 100644 --- a/python/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/python/bitblas/ops/impl/matmul_dequantize_impl.py @@ -33,6 +33,7 @@ def matmul_nt_dequantize_b( with_bias=False, zeros_mode="original", ): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) if not isinstance(M, int): M = tvm.te.var("m") @@ -78,13 +79,20 @@ def decode_func(n, k): dtype=in_dtype, ) elif source_format == "uint": - w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + if bit == 8: + # 8 bit does not need to be compressed + w = B[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif source_format == "int": if bit == 1: # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. 
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B[n, k].astype(in_dtype) else: w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) @@ -187,6 +195,7 @@ def matmul_nt_dequantize_b_propagate_b( zeros_mode="original", transform_kind: TransformKind = TransformKind.IntraWarpTransform, ): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) if not isinstance(M, int): M = tvm.te.var("m") @@ -241,17 +250,24 @@ def fcompute(i, j): def decode_func(n, k): if source_format == "uint": - w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B_reindex[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype=in_dtype, - ) + if bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) elif source_format == "int": if bit == 1: # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) else: w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( bit, @@ -360,6 +376,7 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b( transform_kind_input: TransformKind = TransformKind.IntraWarpTransform, transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform, ): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) if not isinstance(M, int): M = tvm.te.var("m") @@ -429,17 +446,24 @@ def fcompute(i, j): def decode_func(n, k): if source_format == "uint": - w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B_reindex[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype=in_dtype, - ) + if bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) elif source_format == "int": # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. if bit == 1: w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) else: w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( bit, From 58d55b7cb750bd0678985a30b532929196765d13 Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Wed, 5 Jun 2024 07:11:48 +0000 Subject: [PATCH 06/17] feat: Add rasterization options for roller module --- python/bitblas/base/roller/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/bitblas/base/roller/__init__.py b/python/bitblas/base/roller/__init__.py index 7ca6f15c2..9afd7cff0 100644 --- a/python/bitblas/base/roller/__init__.py +++ b/python/bitblas/base/roller/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
from .node import PrimFuncNode # noqa: F401 +from .rasterization import NoRasterization, Rasterization2DRow, Rasterization2DColumn # noqa: F401 from .hint import Hint # noqa: F401 from .policy import DefaultPolicy, TensorCorePolicy # noqa: F401 from .arch import TileDevice, CUDA # noqa: F401 From e7547ced6f222de7668a2963f190471791aa0be3 Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Wed, 5 Jun 2024 07:12:00 +0000 Subject: [PATCH 07/17] Refactor tensorcore_legalization method to optimize tensor core usage --- python/bitblas/base/roller/hint.py | 6 ++++++ python/bitblas/base/roller/policy/tensorcore.py | 1 + 2 files changed, 7 insertions(+) diff --git a/python/bitblas/base/roller/hint.py b/python/bitblas/base/roller/hint.py index 1d3270b4d..89f607cde 100644 --- a/python/bitblas/base/roller/hint.py +++ b/python/bitblas/base/roller/hint.py @@ -211,6 +211,12 @@ def from_dict(self, dic: Dict) -> "Hint": setattr(self, k, v) return self + def tensorcore_legalization(self): + # only keep the last 2 axes for tensorcore + self.warp = self.warp[-2:] + self.block = self.block[-2:] + return self + @property def raxis_order(self) -> List[int]: if self._raxis_order != []: diff --git a/python/bitblas/base/roller/policy/tensorcore.py b/python/bitblas/base/roller/policy/tensorcore.py index 653a8809c..97edb50fc 100644 --- a/python/bitblas/base/roller/policy/tensorcore.py +++ b/python/bitblas/base/roller/policy/tensorcore.py @@ -307,6 +307,7 @@ def _score(node, thread): # small is better codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size) codegen_dict.arch = self.arch codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes") + codegen_dict.tensorcore_legalization() return codegen_dict def plan_rasterization(self, td: TileDict): From 678a2e15fc10c2aeb0a2ae8e16e685cc5bb102f8 Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Wed, 5 Jun 2024 07:12:24 +0000 Subject: [PATCH 08/17] feat: Add function to collect variables from expression, improve for splitk --- python/bitblas/gpu/matmul_analysis.py | 34 ++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/python/bitblas/gpu/matmul_analysis.py b/python/bitblas/gpu/matmul_analysis.py index 2fa9c16a4..c53211e2b 100644 --- a/python/bitblas/gpu/matmul_analysis.py +++ b/python/bitblas/gpu/matmul_analysis.py @@ -17,11 +17,24 @@ get_reduction_blocks, ) from tvm.target.target import Target -from tvm.tir import IndexMap +from tvm.tir import IndexMap, Var +from tvm.tir.stmt_functor import pre_order_visit import logging logger = logging.getLogger(__name__) +def collect_vars_from_expr(prim_expr): + vars = [] + + def callback(node): + if isinstance(node, Var): + vars.append(node) + return True + + pre_order_visit(prim_expr, callback) + + return vars + def _is_one(x: PrimExpr) -> bool: return isinstance(x, tir.IntImm) and x.value == 1 @@ -337,9 +350,16 @@ def is_common_reduce(var: Var) -> bool: return True return False + def has_common_reduce(var: Var) -> bool: + vars = collect_vars_from_expr(var) + for v in vars: + if is_common_reduce(v): + return True + return False + def check_last_trait(region: List[Range]): axes = get_ordered_axes(region) - return is_common_reduce(axes[-1]) + return has_common_reduce(axes[-1]) def infer_layout(layout: str, region: List[Range], kind: str = "A"): """ @@ -583,10 +603,18 @@ def is_common_reduce(var: Var) -> bool: return True return False + def has_common_reduce(var: Var) -> bool: + vars = collect_vars_from_expr(var) + for v in vars: + if is_common_reduce(v): + return True 
+ return False + def check_last_trait(region: List[Range]): axes = get_ordered_axes(region) - return is_common_reduce(axes[-1]) + return has_common_reduce(axes[-1]) + intrin_info: dict = {} in_dtype, out_dtype = get_in_out_dtypes(block_stmt) intrin_info["in_dtype"] = in_dtype From 3088b35b90042b00fba9a5300f89405368606174 Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Wed, 5 Jun 2024 07:12:41 +0000 Subject: [PATCH 09/17] chore: Update typing import in __init__.py --- python/bitblas/module/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/bitblas/module/__init__.py b/python/bitblas/module/__init__.py index e6e393e6c..eaf15bc1d 100644 --- a/python/bitblas/module/__init__.py +++ b/python/bitblas/module/__init__.py @@ -11,7 +11,7 @@ logger = getLogger(__name__) -from typing import List, Union +from typing import List, Union, Optional from bitblas.cache import global_operator_cache, get_database_path from bitblas import Matmul, MatmulConfig @@ -67,7 +67,7 @@ def __init__( opt_M: Union[int, List[int]] = opt_M, # performance related configs enable_tuning: bool = True, - fast_decoding: bool = True, + fast_decoding: Optional[bool] = None, propagate_b: bool = False, ): """ From 5d206b3dae4e240571b8bf161d410b52754f83e7 Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Wed, 5 Jun 2024 07:13:54 +0000 Subject: [PATCH 10/17] chore: Refactor CPU execution of operators --- python/bitblas/ops/general_matmul.py | 38 ++++------------------------ python/bitblas/ops/operator.py | 31 +++++++++++++++++++++++ 2 files changed, 36 insertions(+), 33 deletions(-) diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py index e5a23f7f0..c79a95e90 100644 --- a/python/bitblas/ops/general_matmul.py +++ b/python/bitblas/ops/general_matmul.py @@ -5,14 +5,13 @@ import operator from functools import reduce from bitblas.base.roller.arch.cuda import CUDA -from typing import Any, List, Literal, Optional, Tuple, Union -from .operator import Operator, TransformKind +from typing import Any, Literal, Optional, Tuple, Union +from .operator import Operator, TransformKind, OPExecutorCPU from .impl.matmul_dequantize_impl import ( select_implementation as weight_dequantize_implementation,) from .impl.matmul_impl import select_implementation as consistent_implementation from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4 from bitblas.utils.target_detector import auto_detect_nvidia_target -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch from dataclasses import dataclass from .ladder_permutate import LadderPermutate, LadderPermutateConfig from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig @@ -41,35 +40,6 @@ def is_native_compute(A_dtype, W_dtype) -> bool: return (A_dtype, W_dtype) in NATIVE_COMPUTE_PATTERNS - -class OPExecutorCPU: - - def __init__(self, operators: Optional[List[Operator]] = None): - if operators is None: - operators = [] - self.operators = operators - - def append(self, op): - self.operators.append(op) - - def is_none(self): - return len(self.operators) == 0 - - def forward(self, weight): - inputs = [weight] - for op in self.operators: - inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) - inputs = [op.forward(*inputs)] - return inputs[-1] - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.forward(*args, **kwds) - - @property - def size(self): - return len(self.operators) - - @dataclass(frozen=True) class MatmulConfig: M: Union[int, Tuple[int]] = None @@ -527,10 +497,12 @@ 
def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: if self.dynamic_range is not None: m = reduce(operator.mul, A.shape[:-1], 1) args.append(m) + + stream = torch.cuda.current_stream() if self.lib is None: self._forward_from_torch_func(*args) - self._forward_from_prebuild_lib(*args) + self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream) return output diff --git a/python/bitblas/ops/operator.py b/python/bitblas/ops/operator.py index 3bb45e217..cf8b9aef0 100644 --- a/python/bitblas/ops/operator.py +++ b/python/bitblas/ops/operator.py @@ -15,6 +15,7 @@ from ..base import fast_tune, fast_tune_with_dynamic_range from copy import deepcopy from bitblas.base.roller.arch import get_arch +from bitblas.utils.tensor_adapter import tvm_tensor_to_torch from bitblas.wrapper import CUDASourceWrapper, CUDASourceWrapperWithDynamic from dataclasses import dataclass from enum import IntEnum @@ -333,3 +334,33 @@ def _select_implementation(self) -> IRModule: @property def prim_func(self): return self.prim_func_mod["main"] + + +class OPExecutorCPU: + """ + A class to execute a sequence of operators on the CPU. + """ + def __init__(self, operators: Optional[List[Operator]] = None): + if operators is None: + operators = [] + self.operators = operators + + def append(self, op): + self.operators.append(op) + + def is_none(self): + return len(self.operators) == 0 + + def forward(self, weight): + inputs = [weight] + for op in self.operators: + inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) + inputs = [op.forward(*inputs)] + return inputs[-1] + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.forward(*args, **kwds) + + @property + def size(self): + return len(self.operators) From e06ce10fe19d03321b80ccae6e6ee5bbc9b03dab Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Wed, 5 Jun 2024 07:14:06 +0000 Subject: [PATCH 11/17] Refactor matmul implementation for splitk layout --- python/bitblas/ops/general_matmul_splitk.py | 191 +++++++++ .../ops/impl/batch_matmul_dequantize_impl.py | 389 ++++++++++++++++++ python/bitblas/ops/impl/batch_matmul_impl.py | 93 +++++ .../ops/impl/matmul_dequantize_splitk_impl.py | 191 +++++++++ python/bitblas/ops/impl/matmul_splitk_impl.py | 94 +++++ .../test_general_matmul_splitk_ops.py | 121 ++++++ 6 files changed, 1079 insertions(+) create mode 100644 python/bitblas/ops/general_matmul_splitk.py create mode 100644 python/bitblas/ops/impl/batch_matmul_dequantize_impl.py create mode 100644 python/bitblas/ops/impl/batch_matmul_impl.py create mode 100644 python/bitblas/ops/impl/matmul_dequantize_splitk_impl.py create mode 100644 python/bitblas/ops/impl/matmul_splitk_impl.py create mode 100644 testing/python/operators/test_general_matmul_splitk_ops.py diff --git a/python/bitblas/ops/general_matmul_splitk.py b/python/bitblas/ops/general_matmul_splitk.py new file mode 100644 index 000000000..e8af5845b --- /dev/null +++ b/python/bitblas/ops/general_matmul_splitk.py @@ -0,0 +1,191 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
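The new general_matmul_splitk.py wraps the split-K kernels behind the same operator interface as Matmul: the reduction dimension K is divided into k_split chunks, the kernel writes one partial (M, N) result per chunk into a (k_split, M, N) buffer, and forward() at the end of this file reduces that buffer with torch.sum(output, dim=0). A minimal usage sketch, modeled on the test file added later in this patch (it assumes a CUDA device and a built BitBLAS install; shapes are taken from that test):

import torch
from bitblas.ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK

config = MatmulConfigWithSplitK(
    k_split=4, M=1, N=4096, K=12800,
    A_dtype="float16", W_dtype="float16",
    accum_dtype="float16", out_dtype="float16", layout="nt",
)
matmul = MatmulWithSplitK(config=config, enable_tuning=False)

A = torch.rand(1, 12800, dtype=torch.float16).cuda() - 0.5
W = torch.rand(4096, 12800, dtype=torch.float16).cuda() - 0.5
out = matmul(A, W)   # partial (k_split, M, N) results are summed over dim 0 internally

Tuning can still be enabled afterwards via hardware_aware_finetune(), as the second test case in the new test file does.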
+from tvm.target import Target +import operator +from functools import reduce +from typing import Any, Optional, Union +from .operator import TransformKind +from .impl.matmul_splitk_impl import select_implementation as consistent_implementation +from .impl.matmul_dequantize_splitk_impl import select_implementation as weight_dequantize_implementation +from dataclasses import dataclass +import logging +import torch +from .general_matmul import MatmulConfig, Matmul +from .general_matmul import is_native_compute + +logger = logging.getLogger(__name__) + +WORKSPACE_SIZE = 1024 * 1024 * 256 + +@dataclass(frozen=True) +class MatmulConfigWithSplitK(MatmulConfig): + k_split: int = 1 # split K dimension + + +class MatmulWithSplitK(Matmul): + + def __init__( + self, + config: MatmulConfig, + name: str = "matmul", + target: Optional[Union[str, Target]] = None, + enable_tuning: bool = True, + from_database: bool = False, + ): + super().__init__(config, name, target, enable_tuning, from_database) + + def _select_implementation(self): + # the major implementation + if is_native_compute(self.A_dtype, self.W_dtype): + return consistent_implementation( + SplitK=self.k_split, + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.A_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + with_bias=self.with_bias, + layout=self.layout, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) + else: + return weight_dequantize_implementation( + SplitK=self.k_split, + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.A_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + bit=self.bit, + storage_dtype=self.storage_dtype, + source_format=self.source_format, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + group_size=self.group_size, + fast_decoding=self.fast_decoding, + with_bias=self.with_bias, + layout=self.layout, + zeros_mode=self.zeros_mode, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) + + def retrieve_weight_shape(self): + return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] + + def transform_weight(self, weight, scale=None, zeros=None, bias=None): + """ + Transforms the given weight tensor based on the specified quantization parameters and + returns the transformed weight along with optional scale, zeros, and bias. + + Parameters: + - weight: The input weight tensor to be transformed. + - scale: Optional scaling factor for the weight tensor. + - zeros: Optional zero-point adjustment for the weight tensor. + - bias: Optional bias to be added to the weight tensor. + + Returns: + A list containing the transformed weight tensor and optionally the scale, zeros, and bias. 
+ """ + weight = weight.contiguous() + if self.W_dtype == self.A_dtype: + if self.weight_transform is not None: + return self.weight_transform(weight.cpu()).cuda().contiguous() + return weight + + from bitblas.quantization import general_compress + import torch + import numpy as np + + source_format, bit = self.source_format, self.bit + + # Process integer source format + if source_format == "int" and bit < 8: + assert not self.with_scaling, "scale should be False for int source format" + assert not self.with_zeros, "zeros should be False for int source format" + maxq = 2**(bit - 1) + # Clamp weight values to be within the quantizable range and adjust + weight = torch.clamp(weight, -maxq, maxq).int() + maxq + elif source_format in ["fp_e5m2", "fp_e4m3"]: + weight = weight.view(torch.int8) + weight = weight.int() + else: + # For non-integer formats, simply convert weights to integers + weight = weight.int() + + np_storage_dtype = getattr(np, self.storage_dtype) + + weight = general_compress( + weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype) + + weight = torch.from_numpy(weight).cuda().contiguous() + + # Apply an optional weight transformation if specified + if self.weight_transform is not None: + weight = self.weight_transform(weight.cpu()).cuda().contiguous() + + # Prepare the return list with the transformed weight and optionally include scale, zeros, and bias + result = [weight] + if scale is not None: + result.append(scale) + if zeros is not None: + result.append(zeros) + if bias is not None: + result.append(bias) + + return next(iter(result), result) + + def transform_input(self, input_tensor): + if self.propagate_a is not TransformKind.NonTransform: + # check workspace size + if input_tensor.numel() > WORKSPACE_SIZE: + raise ValueError( + f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size." + ) + self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace) + return self.workspace + return input_tensor + + def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: + args = [] + args.append(self.transform_input(A)) + args.append(W) + + if self.lut is not None: + args.append(self.lut) + + if output is None: + output = torch.empty((self.k_split, ) + + A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device) + if scale is not None: + args.append(scale) + if zeros is not None: + args.append(zeros) + if bias is not None: + args.append(bias) + args.append(output) + + if self.dynamic_range is not None: + m = reduce(operator.mul, A.shape[:-1], 1) + args.append(m) + + stream = torch.cuda.current_stream() + + if self.lib is None: + self._forward_from_torch_func(*args) + self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream) + output = torch.sum(output, dim=0) + return output + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.forward(*args, **kwds) + + @property + def k_split(self): + return self.config.k_split + + +__all__ = ["MatmulConfigWithSplitK", "MatmulWithSplitK"] diff --git a/python/bitblas/ops/impl/batch_matmul_dequantize_impl.py b/python/bitblas/ops/impl/batch_matmul_dequantize_impl.py new file mode 100644 index 000000000..0660f9f5b --- /dev/null +++ b/python/bitblas/ops/impl/batch_matmul_dequantize_impl.py @@ -0,0 +1,389 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+# pre-transformed tir expression of matmul +import tvm +from tvm import te, DataType +from tvm.tir import IndexMap +from bitblas.ops.operator import TransformKind +from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.quantization import ( + _tir_packed_int_to_int_convert, + _tir_packed_to_signed_convert, + _tir_packed_to_unsigned_convert, + _tir_u32_to_f4_to_f16, + _tir_u8_to_f8_e4m3_to_f16, + _tir_packed_to_unsigned_convert_with_zeros, +) + +def matmul_nt_dequantize_b( + Batch, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) + B = te.placeholder((Batch, N, K // storage_nbit * bit), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((Batch, N, K // group_size), name="Scale", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + + def decode_func(b, n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B[b, n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "int": + if bit == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. 
+ w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B[b, n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[b, n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B[b, n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[b, n, k // group_size] + + return w + + B_decode = te.compute((Batch, N, K), decode_func, name="B_decode") + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (Batch, M, N), + lambda b, i, j: te.sum( + A[b, i, k].astype(accum_dtype) * B_decode[b, j, k].astype(accum_dtype), axis=k), + name="C", + ) + D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") + args = [A, B] + last_output = D + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_bias: + E = te.compute((Batch, M, N), lambda b, i, j: D[b, i, j] + Bias[j], name="E") + last_output = E + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_scaling": with_scaling, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "group_size": group_size, + } + }, + ) + return tvm.IRModule.from_expr(func) + + +def matmul_nt_dequantize_b_propagate_b( + Batch, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + transform_kind: TransformKind = TransformKind.IntraWarpTransform, +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + _, inverse_indexmap = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inverse_indexmap.initial_indices + scaling_final_indices = inverse_indexmap.map_indices(initial_indices[:-1] + + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inverse_indexmap = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = 
str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + qr = r * bit // storage_nbit + A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) + B = te.placeholder((Batch, N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((Batch, N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((Batch, N, K // group_size), name="Zeros", dtype=in_dtype) + Bias = te.placeholder((Batch, N,), name="Bias", dtype=in_dtype) + + def fcompute(b, i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inverse_indexmap.map_indices([warp_i, warp_j]) + new_index = (b, *spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (Batch, N, K // storage_nbit * bit), + fcompute, + name="B_reindex", + ) + + def decode_func(b, n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[b, n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[b, n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "int": + if bit == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B_reindex[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[b, n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, + B_reindex[b, n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, + B_reindex[b, n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[b, n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[b, n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[b, n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[b, n, k // group_size]) * Scale[b, n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[b, n, k // group_size] - Zeros[b, n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + B_decode = te.compute((Batch, N, K), decode_func, name="B_decode") + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (Batch, M, N), + lambda b, i, j: te.sum( + A[b, i, k].astype(accum_dtype) * B_decode[b, j, k].astype(accum_dtype), axis=k), + name="C", + ) + D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") + args = [A, B] + last_output = D + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_zeros: + args.append(Zeros) + if with_bias: + E = te.compute((Batch, M, N), lambda b, i, j: D[b, i, j] + Bias[j], name="E") 
+ last_output = E + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "with_scaling": with_scaling, + "group_size": group_size, + } + }, + ) + func = func.with_attr("weight_transform_kind", transform_kind.value) + return tvm.IRModule.from_expr(func) + + +def select_implementation( + Batch=1, + M=None, + N=1024, + K=1024, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + layout="nt", + zeros_mode="original", + propagate_a=False, + propagate_b=False, +): + if layout == "nn": + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn in Dequantize Implementation" + ) + elif layout == "nt": + if propagate_a and propagate_b: + raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") + elif propagate_a: + raise ValueError("Currently only support propagate_a=False for layout=nt") + elif propagate_b: + return matmul_nt_dequantize_b_propagate_b( + Batch, + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + transform_kind=propagate_b, + ) + else: + return matmul_nt_dequantize_b( + Batch, + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + ) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/python/bitblas/ops/impl/batch_matmul_impl.py b/python/bitblas/ops/impl/batch_matmul_impl.py new file mode 100644 index 000000000..bc3994ec4 --- /dev/null +++ b/python/bitblas/ops/impl/batch_matmul_impl.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
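batch_matmul_impl.py below builds the TE expression for a plain batched GEMM in the "nt" layout, where the second operand is stored as (Batch, N, K). For reference, the equivalent computation in PyTorch (illustrative only, not part of the patch):

import torch

Batch, M, N, K = 4, 16, 32, 64
A = torch.randn(Batch, M, K)
B = torch.randn(Batch, N, K)          # "nt" layout: second operand is (Batch, N, K)

# C[b, i, j] = sum_k A[b, i, k] * B[b, j, k], as in the te.compute below
C = torch.einsum("bmk,bnk->bmn", A, B)
torch.testing.assert_close(C, A @ B.transpose(-1, -2))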
+# pre-transformed tir expression of matmul +import tvm +from tvm import te +from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.ops.operator import TransformKind + + +def matmul_nt( + Batch, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + if not isinstance(M, int): + M = tvm.te.var("m") + A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) + B = te.placeholder((Batch, N, K), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (Batch, M, N), + lambda b, i, j: te.sum(A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((Batch, M, N), lambda b, i, j: last_output[b, i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def matmul( + Batch, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", +): + if layout == "nn": + raise ValueError("Currently only support layout=nt") + return matmul_nt(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) + + +def select_implementation( + Batch=1, + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform, +): + if layout == "nn": + if propagate_a or propagate_b: + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn") + return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + elif layout == "nt": + if propagate_a and propagate_b: + raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") + elif propagate_a: + raise ValueError("Currently only support propagate_a=False for layout=nt") + elif propagate_b: + raise ValueError("Currently only support propagate_b=False for layout=nt") + else: + return matmul(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/python/bitblas/ops/impl/matmul_dequantize_splitk_impl.py b/python/bitblas/ops/impl/matmul_dequantize_splitk_impl.py new file mode 100644 index 000000000..542ac8351 --- /dev/null +++ b/python/bitblas/ops/impl/matmul_dequantize_splitk_impl.py @@ -0,0 +1,191 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
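matmul_dequantize_splitk_impl.py mirrors the existing dequantize matmul expression but splits the reduction dimension: the output gains a leading SplitK axis, and slice sk accumulates only over RK = K // SplitK elements, reading A and the dequantized weight at index sk * RK + k. A small PyTorch sketch of that indexing (names are illustrative; the dequantization step is represented by a ready-made Wd tensor):

import torch

SplitK, M, N, K = 4, 2, 3, 8
RK = K // SplitK
A = torch.randn(M, K)
Wd = torch.randn(N, K)                 # stands in for the dequantized B_decode

C = torch.zeros(SplitK, M, N)
for sk in range(SplitK):
    for k in range(RK):
        C[sk] += torch.outer(A[:, sk * RK + k], Wd[:, sk * RK + k])

# Summing the partials over the split axis recovers the full GEMM, which is the
# torch.sum(output, dim=0) reduction performed by MatmulWithSplitK.forward().
torch.testing.assert_close(C.sum(dim=0), A @ Wd.t())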
+# pre-transformed tir expression of matmul +import tvm +from tvm import te, DataType +from tvm.tir import IndexMap +from bitblas.ops.operator import TransformKind +from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.quantization import ( + _tir_packed_int_to_int_convert, + _tir_packed_to_signed_convert, + _tir_packed_to_unsigned_convert, + _tir_u32_to_f4_to_f16, + _tir_u8_to_f8_e4m3_to_f16, + _tir_packed_to_unsigned_convert_with_zeros, +) + +def matmul_nt_dequantize_b( + SplitK, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N, K // storage_nbit * bit), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + + def decode_func(n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "int": + if bit == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. 
+ w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B[n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + # Describe the matrix multiplication in TE + RK = K // SplitK + k = te.reduce_axis((0, RK), name="k") + C = te.compute( + (SplitK, M, N), + lambda sk, i, j: te.sum( + A[i, sk * RK + k].astype(accum_dtype) * B_decode[j, sk * RK + k].astype(accum_dtype), axis=k), + name="C", + ) + D = te.compute((SplitK, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") + args = [A, B] + last_output = D + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_bias: + E = te.compute((SplitK, M, N), lambda b, i, j: D[b, i, j] + Bias[j], name="E") + last_output = E + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_scaling": with_scaling, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "group_size": group_size, + } + }, + ) + return tvm.IRModule.from_expr(func) + + +def select_implementation( + SplitK=1, + M=None, + N=1024, + K=1024, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + layout="nt", + zeros_mode="original", + propagate_a=False, + propagate_b=False, +): + if layout == "nn": + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn in Dequantize Implementation" + ) + elif layout == "nt": + if propagate_a and propagate_b: + raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") + elif propagate_a: + raise ValueError("Currently only support propagate_a=False for layout=nt") + elif propagate_b: + raise ValueError("Currently only support propagate_b=False for layout=nt") + else: + return matmul_nt_dequantize_b( + SplitK, + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + ) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/python/bitblas/ops/impl/matmul_splitk_impl.py b/python/bitblas/ops/impl/matmul_splitk_impl.py new file mode 100644 index 000000000..a373cd252 --- /dev/null +++ 
b/python/bitblas/ops/impl/matmul_splitk_impl.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +import tvm +from tvm import te +from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.ops.operator import TransformKind + + +def matmul_nt( + SplitK, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + if not isinstance(M, int): + M = tvm.te.var("m") + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N, K), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + # Describe the matrix multiplication in TE + RK = K // SplitK + k = te.reduce_axis((0, RK), name="k") + C = te.compute( + (SplitK, M, N), + lambda sk, i, j: te.sum(A[i, sk * RK + k].astype(accum_dtype) * B[j, sk * RK + k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((SplitK, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((SplitK, M, N), lambda b, i, j: last_output[b, i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def matmul( + SplitK, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", +): + if layout == "nn": + raise ValueError("Currently only support layout=nt") + return matmul_nt(SplitK, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) + + +def select_implementation( + SplitK=1, + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform, +): + if layout == "nn": + if propagate_a or propagate_b: + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn") + return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + elif layout == "nt": + if propagate_a and propagate_b: + raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") + elif propagate_a: + raise ValueError("Currently only support propagate_a=False for layout=nt") + elif propagate_b: + raise ValueError("Currently only support propagate_b=False for layout=nt") + else: + return matmul(SplitK, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/testing/python/operators/test_general_matmul_splitk_ops.py b/testing/python/operators/test_general_matmul_splitk_ops.py new file mode 100644 index 000000000..830cd59ee --- /dev/null +++ b/testing/python/operators/test_general_matmul_splitk_ops.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
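The new test below compares the split-K operator against a plain torch.matmul reference in fp16 with fairly loose tolerances (rtol=1e-2, atol=1e-1), presumably to leave headroom for accumulation-order differences between the split-K kernel and cuBLAS. A self-contained sketch of that kind of reference check, with an additional float32-accumulated variant that could be swapped in if the fp16 comparison ever became flaky (illustrative shapes; assumes a CUDA device):

import torch

M, N, K = 16, 4096, 12800
A = torch.rand(M, K, dtype=torch.float16).cuda() - 0.5
W = torch.rand(N, K, dtype=torch.float16).cuda() - 0.5

ref = torch.matmul(A, W.t())                      # "nt" layout reference, as in the test
ref_fp32 = (A.float() @ W.float().t()).half()     # tighter-accumulation alternative
torch.testing.assert_close(ref, ref_fp32, rtol=1e-2, atol=1e-1)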
+import pytest +import bitblas +from bitblas.ops.general_matmul_splitk import MatmulWithSplitK, MatmulConfigWithSplitK +import logging +from bitblas import set_log_level + + +def get_codegen_result(ops): + code = ops.get_source() + return code + + +# fmt: off +@pytest.mark.parametrize( + "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", + [ + (1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, + None), + (16, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, + None), + ], +) +def test_matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode): + + matmul_config = MatmulConfigWithSplitK( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + ) + matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) + assert get_codegen_result(matmul) + + +@pytest.mark.parametrize( + "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", + [ + (1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, + None), + (16, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, + None), + ], +) +def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + group_size, with_scaling, with_zeros, zeros_mode): + + matmul_config = MatmulConfigWithSplitK( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + ) + matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) + matmul.hardware_aware_finetune(topk=10) + assert get_codegen_result(matmul) + +@pytest.mark.parametrize( + "SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", + [ + (1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, + None), + (4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, + None), + ], +) +def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + group_size, with_scaling, with_zeros, zeros_mode): + import torch + torch.random.manual_seed(0) + matmul_config = MatmulConfigWithSplitK( + k_split=SplitK, + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + ) + matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) + + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda() - 0.5) + + output_bitblas = matmul.forward(*inputs) + output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) + torch.testing.assert_close(output_bitblas, 
output_torch, rtol=1e-2, atol=1e-1) + +# fmt: on +if __name__ == "__main__": + bitblas.testing.main() From d67cc6d58a592852f8f18e2f9fb062a52d43899b Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Wed, 5 Jun 2024 07:16:36 +0000 Subject: [PATCH 12/17] Refactor matmul implementation for splitk layout --- python/bitblas/gpu/matmul_analysis.py | 17 ++++-------- python/bitblas/ops/general_matmul.py | 3 ++- python/bitblas/ops/general_matmul_splitk.py | 13 +++++---- .../ops/impl/batch_matmul_dequantize_impl.py | 27 ++++++++++--------- python/bitblas/ops/impl/batch_matmul_impl.py | 4 +-- .../ops/impl/matmul_dequantize_splitk_impl.py | 21 +++++---------- python/bitblas/ops/impl/matmul_splitk_impl.py | 4 +-- python/bitblas/ops/operator.py | 3 ++- .../test_general_matmul_splitk_ops.py | 21 ++++++++------- 9 files changed, 54 insertions(+), 59 deletions(-) diff --git a/python/bitblas/gpu/matmul_analysis.py b/python/bitblas/gpu/matmul_analysis.py index c53211e2b..4638f2e72 100644 --- a/python/bitblas/gpu/matmul_analysis.py +++ b/python/bitblas/gpu/matmul_analysis.py @@ -8,7 +8,7 @@ from typing import List, Optional, Set, Union, Tuple, Dict from tvm import tir from tvm.ir import Range -from tvm.tir import IterVar, PrimExpr, Var, BufferRegion +from tvm.tir import IterVar, PrimExpr, Var, BufferRegion, IndexMap from tvm.tir.analysis import undefined_vars from tvm.tir.schedule.schedule import BlockRV from ..base.analysis import ( @@ -17,12 +17,12 @@ get_reduction_blocks, ) from tvm.target.target import Target -from tvm.tir import IndexMap, Var from tvm.tir.stmt_functor import pre_order_visit import logging logger = logging.getLogger(__name__) + def collect_vars_from_expr(prim_expr): vars = [] @@ -352,10 +352,7 @@ def is_common_reduce(var: Var) -> bool: def has_common_reduce(var: Var) -> bool: vars = collect_vars_from_expr(var) - for v in vars: - if is_common_reduce(v): - return True - return False + return any(is_common_reduce(v) for v in vars) def check_last_trait(region: List[Range]): axes = get_ordered_axes(region) @@ -605,16 +602,12 @@ def is_common_reduce(var: Var) -> bool: def has_common_reduce(var: Var) -> bool: vars = collect_vars_from_expr(var) - for v in vars: - if is_common_reduce(v): - return True - return False - + return any(is_common_reduce(v) for v in vars) + def check_last_trait(region: List[Range]): axes = get_ordered_axes(region) return has_common_reduce(axes[-1]) - intrin_info: dict = {} in_dtype, out_dtype = get_in_out_dtypes(block_stmt) intrin_info["in_dtype"] = in_dtype diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py index c79a95e90..9fe7d1345 100644 --- a/python/bitblas/ops/general_matmul.py +++ b/python/bitblas/ops/general_matmul.py @@ -40,6 +40,7 @@ def is_native_compute(A_dtype, W_dtype) -> bool: return (A_dtype, W_dtype) in NATIVE_COMPUTE_PATTERNS + @dataclass(frozen=True) class MatmulConfig: M: Union[int, Tuple[int]] = None @@ -497,7 +498,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: if self.dynamic_range is not None: m = reduce(operator.mul, A.shape[:-1], 1) args.append(m) - + stream = torch.cuda.current_stream() if self.lib is None: diff --git a/python/bitblas/ops/general_matmul_splitk.py b/python/bitblas/ops/general_matmul_splitk.py index e8af5845b..e951bf126 100644 --- a/python/bitblas/ops/general_matmul_splitk.py +++ b/python/bitblas/ops/general_matmul_splitk.py @@ -3,7 +3,7 @@ from tvm.target import Target import operator from functools import reduce -from typing import Any, Optional, Union 
+from typing import Any, Optional, Union from .operator import TransformKind from .impl.matmul_splitk_impl import select_implementation as consistent_implementation from .impl.matmul_dequantize_splitk_impl import select_implementation as weight_dequantize_implementation @@ -17,9 +17,10 @@ WORKSPACE_SIZE = 1024 * 1024 * 256 + @dataclass(frozen=True) class MatmulConfigWithSplitK(MatmulConfig): - k_split: int = 1 # split K dimension + k_split: int = 1 # split K dimension class MatmulWithSplitK(Matmul): @@ -158,8 +159,10 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: args.append(self.lut) if output is None: - output = torch.empty((self.k_split, ) + - A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device) + output = torch.empty( + (self.k_split,) + A.shape[:-1] + (self.N,), + dtype=self.torch_output_dtype, + device=A.device) if scale is not None: args.append(scale) if zeros is not None: @@ -171,7 +174,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: if self.dynamic_range is not None: m = reduce(operator.mul, A.shape[:-1], 1) args.append(m) - + stream = torch.cuda.current_stream() if self.lib is None: diff --git a/python/bitblas/ops/impl/batch_matmul_dequantize_impl.py b/python/bitblas/ops/impl/batch_matmul_dequantize_impl.py index 0660f9f5b..a3ab5ebef 100644 --- a/python/bitblas/ops/impl/batch_matmul_dequantize_impl.py +++ b/python/bitblas/ops/impl/batch_matmul_dequantize_impl.py @@ -6,14 +6,10 @@ from tvm.tir import IndexMap from bitblas.ops.operator import TransformKind from bitblas.gpu.matmul_analysis import get_propagate_map -from bitblas.quantization import ( - _tir_packed_int_to_int_convert, - _tir_packed_to_signed_convert, - _tir_packed_to_unsigned_convert, - _tir_u32_to_f4_to_f16, - _tir_u8_to_f8_e4m3_to_f16, - _tir_packed_to_unsigned_convert_with_zeros, -) +from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, + _tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16, + _tir_u8_to_f8_e4m3_to_f16) + def matmul_nt_dequantize_b( Batch, @@ -48,7 +44,6 @@ def matmul_nt_dequantize_b( Scale = te.placeholder((Batch, N, K // group_size), name="Scale", dtype=in_dtype) Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - def decode_func(b, n, k): if source_format == "uint": if bit == 8: @@ -187,11 +182,16 @@ def matmul_nt_dequantize_b_propagate_b( group_size = K qr = r * bit // storage_nbit A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) - B = te.placeholder((Batch, N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + B = te.placeholder((Batch, N // l, (K // scaling_factor) // qr, l, qr), + name="B", + dtype=storage_dtype) LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) Scale = te.placeholder((Batch, N, K // group_size), name="Scale", dtype=in_dtype) Zeros = te.placeholder((Batch, N, K // group_size), name="Zeros", dtype=in_dtype) - Bias = te.placeholder((Batch, N,), name="Bias", dtype=in_dtype) + Bias = te.placeholder(( + Batch, + N, + ), name="Bias", dtype=in_dtype) def fcompute(b, i, j): warp_i, warp_j = i % l, j % qr @@ -223,7 +223,10 @@ def decode_func(b, n, k): if bit == 1: # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. 
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( - bit, B_reindex[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + bit, + B_reindex[b, n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype) elif bit == 8: # 8 bit does not need to be compressed w = B_reindex[b, n, k].astype(in_dtype) diff --git a/python/bitblas/ops/impl/batch_matmul_impl.py b/python/bitblas/ops/impl/batch_matmul_impl.py index bc3994ec4..1828ed15d 100644 --- a/python/bitblas/ops/impl/batch_matmul_impl.py +++ b/python/bitblas/ops/impl/batch_matmul_impl.py @@ -3,7 +3,6 @@ # pre-transformed tir expression of matmul import tvm from tvm import te -from bitblas.gpu.matmul_analysis import get_propagate_map from bitblas.ops.operator import TransformKind @@ -27,7 +26,8 @@ def matmul_nt( k = te.reduce_axis((0, K), name="k") C = te.compute( (Batch, M, N), - lambda b, i, j: te.sum(A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k), + lambda b, i, j: te.sum( + A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k), name="C", ) last_output = C diff --git a/python/bitblas/ops/impl/matmul_dequantize_splitk_impl.py b/python/bitblas/ops/impl/matmul_dequantize_splitk_impl.py index 542ac8351..afe241b65 100644 --- a/python/bitblas/ops/impl/matmul_dequantize_splitk_impl.py +++ b/python/bitblas/ops/impl/matmul_dequantize_splitk_impl.py @@ -2,18 +2,11 @@ # Licensed under the MIT License. # pre-transformed tir expression of matmul import tvm -from tvm import te, DataType -from tvm.tir import IndexMap -from bitblas.ops.operator import TransformKind -from bitblas.gpu.matmul_analysis import get_propagate_map -from bitblas.quantization import ( - _tir_packed_int_to_int_convert, - _tir_packed_to_signed_convert, - _tir_packed_to_unsigned_convert, - _tir_u32_to_f4_to_f16, - _tir_u8_to_f8_e4m3_to_f16, - _tir_packed_to_unsigned_convert_with_zeros, -) +from tvm import te +from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, + _tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16, + _tir_u8_to_f8_e4m3_to_f16) + def matmul_nt_dequantize_b( SplitK, @@ -48,7 +41,6 @@ def matmul_nt_dequantize_b( Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - def decode_func(n, k): if source_format == "uint": if bit == 8: @@ -98,7 +90,8 @@ def decode_func(n, k): C = te.compute( (SplitK, M, N), lambda sk, i, j: te.sum( - A[i, sk * RK + k].astype(accum_dtype) * B_decode[j, sk * RK + k].astype(accum_dtype), axis=k), + A[i, sk * RK + k].astype(accum_dtype) * B_decode[j, sk * RK + k].astype(accum_dtype), + axis=k), name="C", ) D = te.compute((SplitK, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") diff --git a/python/bitblas/ops/impl/matmul_splitk_impl.py b/python/bitblas/ops/impl/matmul_splitk_impl.py index a373cd252..c437f64cb 100644 --- a/python/bitblas/ops/impl/matmul_splitk_impl.py +++ b/python/bitblas/ops/impl/matmul_splitk_impl.py @@ -3,7 +3,6 @@ # pre-transformed tir expression of matmul import tvm from tvm import te -from bitblas.gpu.matmul_analysis import get_propagate_map from bitblas.ops.operator import TransformKind @@ -28,7 +27,8 @@ def matmul_nt( k = te.reduce_axis((0, RK), name="k") C = te.compute( (SplitK, M, N), - lambda sk, i, j: te.sum(A[i, sk * RK + k].astype(accum_dtype) * B[j, sk * RK + k].astype(accum_dtype), axis=k), + lambda sk, i, j: te.sum( + A[i, sk * RK + k].astype(accum_dtype) * B[j, sk * RK + k].astype(accum_dtype), 
axis=k), name="C", ) last_output = C diff --git a/python/bitblas/ops/operator.py b/python/bitblas/ops/operator.py index cf8b9aef0..90930d6d3 100644 --- a/python/bitblas/ops/operator.py +++ b/python/bitblas/ops/operator.py @@ -289,7 +289,7 @@ def _forward_from_prebuild_lib(self, *args, stream=0): ] ctypes_args.append(ctypes.c_void_p(stream)) self.lib.call(*ctypes_args) - + def call_lib(self, *args, stream=0): self.lib.call(*args, ctypes.c_void_p(stream)) @@ -340,6 +340,7 @@ class OPExecutorCPU: """ A class to execute a sequence of operators on the CPU. """ + def __init__(self, operators: Optional[List[Operator]] = None): if operators is None: operators = [] diff --git a/testing/python/operators/test_general_matmul_splitk_ops.py b/testing/python/operators/test_general_matmul_splitk_ops.py index 830cd59ee..dd9b29d51 100644 --- a/testing/python/operators/test_general_matmul_splitk_ops.py +++ b/testing/python/operators/test_general_matmul_splitk_ops.py @@ -3,8 +3,6 @@ import pytest import bitblas from bitblas.ops.general_matmul_splitk import MatmulWithSplitK, MatmulConfigWithSplitK -import logging -from bitblas import set_log_level def get_codegen_result(ops): @@ -75,17 +73,19 @@ def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo matmul.hardware_aware_finetune(topk=10) assert get_codegen_result(matmul) + @pytest.mark.parametrize( "SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", [ - (1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None), - (4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None), + (1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, + False, None), + (4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, + False, None), ], ) -def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, - group_size, with_scaling, with_zeros, zeros_mode): +def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, + layout, with_bias, group_size, with_scaling, with_zeros, + zeros_mode): import torch torch.random.manual_seed(0) matmul_config = MatmulConfigWithSplitK( @@ -111,11 +111,12 @@ def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accu inputs = [] inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda() - 0.5) - + output_bitblas = matmul.forward(*inputs) - output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) + output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1) + # fmt: on if __name__ == "__main__": bitblas.testing.main() From 9e36b6dd4a62442c09674b76b49cc69607389d45 Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Wed, 5 Jun 2024 11:57:55 +0000 Subject: [PATCH 13/17] Refactor matmul implementation for splitk layout --- integration/BitNet/utils_quant.py | 3 ++- python/bitblas/base/utils.py | 3 ++- python/bitblas/ops/general_matmul.py | 3 ++- python/bitblas/ops/matmul.py | 3 ++- python/bitblas/ops/matmul_dequantize.py | 3 ++- python/bitblas/utils/__init__.py | 2 +- python/bitblas/utils/post_process.py | 9 +++++++++ 7 files changed, 20 insertions(+), 6 deletions(-) diff --git 
a/integration/BitNet/utils_quant.py b/integration/BitNet/utils_quant.py index 06a8dc119..121649387 100644 --- a/integration/BitNet/utils_quant.py +++ b/integration/BitNet/utils_quant.py @@ -119,7 +119,6 @@ def native_forward(self, input): return out def forward_fp32_simulated(self, input): - print("input: ", input) quant_input = self.activation_quant(input, self.input_bits).detach() quant_weight = self.weight_quant(self.weight).detach() @@ -139,6 +138,8 @@ def forward_fp32_simulated(self, input): return out def forward(self, input): + # return self.forward_fp32_simulated(input) + quant_input = self.activation_quant(input, self.input_bits).detach() fp32_out = self.bitblas_matmul(quant_input, self.weight) sw = self.sw diff --git a/python/bitblas/base/utils.py b/python/bitblas/base/utils.py index 23a817f78..7da309dd5 100644 --- a/python/bitblas/base/utils.py +++ b/python/bitblas/base/utils.py @@ -19,7 +19,7 @@ import tempfile import itertools from tvm.ir.supply import GlobalVarSupply -from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4 +from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 import logging logger = logging.getLogger(__name__) @@ -205,6 +205,7 @@ def _build(context) -> str: def tvm_callback_cuda_postproc(code, _): code = tensor_replace_dp4a(code) code = tensor_remove_make_int4(code) + code = tensor_remove_make_int2(code) return code with tvm.transform.PassContext(config={"tir.use_async_copy": True, **config.pass_context}): diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py index 9fe7d1345..af2da3f02 100644 --- a/python/bitblas/ops/general_matmul.py +++ b/python/bitblas/ops/general_matmul.py @@ -10,7 +10,7 @@ from .impl.matmul_dequantize_impl import ( select_implementation as weight_dequantize_implementation,) from .impl.matmul_impl import select_implementation as consistent_implementation -from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4 +from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 from bitblas.utils.target_detector import auto_detect_nvidia_target from dataclasses import dataclass from .ladder_permutate import LadderPermutate, LadderPermutateConfig @@ -398,6 +398,7 @@ def _select_implementation(self): def post_process(self, code: str) -> str: code = tensor_replace_dp4a(code) code = tensor_remove_make_int4(code) + code = tensor_remove_make_int2(code) return code def retrieve_weight_shape(self): diff --git a/python/bitblas/ops/matmul.py b/python/bitblas/ops/matmul.py index 59729a426..7783c4972 100644 --- a/python/bitblas/ops/matmul.py +++ b/python/bitblas/ops/matmul.py @@ -7,7 +7,7 @@ from typing import List, Union, Optional, Any, Tuple from .operator import Operator, TransformKind from .impl.matmul_impl import select_implementation -from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4 +from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 from dataclasses import dataclass from .ladder_permutate import LadderPermutate, LadderPermutateConfig import logging @@ -189,6 +189,7 @@ def _select_implementation(self): def post_process(self, code: str) -> str: code = tensor_replace_dp4a(code) code = tensor_remove_make_int4(code) + code = tensor_remove_make_int2(code) return code def _profile_latency_with_dynamic_range(self) -> List: diff --git a/python/bitblas/ops/matmul_dequantize.py b/python/bitblas/ops/matmul_dequantize.py index d1dc35c94..25c68b121 100644 --- 
a/python/bitblas/ops/matmul_dequantize.py +++ b/python/bitblas/ops/matmul_dequantize.py @@ -6,7 +6,7 @@ from typing import Any, List, Literal, Optional, Tuple, Union from .operator import Operator, TransformKind from .impl.matmul_dequantize_impl import select_implementation -from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4 +from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 from bitblas.utils.tensor_adapter import tvm_tensor_to_torch from dataclasses import dataclass from .ladder_permutate import LadderPermutate, LadderPermutateConfig @@ -234,6 +234,7 @@ def _select_implementation(self): def post_process(self, code: str) -> str: code = tensor_replace_dp4a(code) code = tensor_remove_make_int4(code) + code = tensor_remove_make_int2(code) return code def retrieve_weight_shape(self): diff --git a/python/bitblas/utils/__init__.py b/python/bitblas/utils/__init__.py index f9587964c..7eb879bce 100644 --- a/python/bitblas/utils/__init__.py +++ b/python/bitblas/utils/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4 # noqa: F401 +from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 # noqa: F401 from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor # noqa: F401 from .target_detector import auto_detect_nvidia_target # noqa: F401 diff --git a/python/bitblas/utils/post_process.py b/python/bitblas/utils/post_process.py index e4fe5f95f..cabee6be1 100644 --- a/python/bitblas/utils/post_process.py +++ b/python/bitblas/utils/post_process.py @@ -27,3 +27,12 @@ def tensor_remove_make_int4(source: str) -> str: "make_int4(0, 0, 0, 0)", ) return source + +def tensor_remove_make_int2(source: str) -> str: + # remove make_int4 with 16 signed char arguments + # TODO(lei): this is a stuff that should be fixed in the tvm in the future + source = source.replace( + "make_int2((signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0)", + "make_int2(0, 0)", + ) + return source From e1a01496f6a59054d2904bcf51d70759b755b143 Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Wed, 5 Jun 2024 11:58:16 +0000 Subject: [PATCH 14/17] chore: Update version to 0.0.1.dev8 --- VERSION | 2 +- python/bitblas/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/VERSION b/VERSION index 407ab24ea..9eac5e019 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.1.dev7 \ No newline at end of file +0.0.1.dev8 \ No newline at end of file diff --git a/python/bitblas/__init__.py b/python/bitblas/__init__.py index 3f806bfde..3bd32875e 100644 --- a/python/bitblas/__init__.py +++ b/python/bitblas/__init__.py @@ -81,4 +81,4 @@ def _init_logger(): _init_logger() -__version__ = "0.0.1.dev7" +__version__ = "0.0.1.dev8" From df0ed7a34b76b717ed5366fa384568b15b6386ec Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Wed, 5 Jun 2024 16:28:39 +0000 Subject: [PATCH 15/17] chore: Enable debug output in bitblas.set_debug_level() --- docs/QuickStart.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/QuickStart.md b/docs/QuickStart.md index 2285a2313..5a57edbb2 100644 --- a/docs/QuickStart.md +++ b/docs/QuickStart.md @@ -12,6 +12,9 @@ Here is an example for a $W_{INT4}A_{FP16}$ mixed-precision matrix multiplicatio import bitblas import torch +# 
enabling debug output + +bitblas.set_debug_level("Debug") matmul_config = bitblas.MatmulConfig( M=1, # M dimension N=1024, # N dimension @@ -125,6 +128,9 @@ Here is an example to define a ```bitblas.Linear``` of $W_{INT4}A_{FP16}$: import bitblas import torch +# enabling debug output +bitblas.set_debug_level("Debug") + model = bitblas.Linear( in_features=1024, out_features=1024, @@ -178,6 +184,9 @@ from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import ( QuantLinear as CudaOldQuantLinear, ) +# enabling debug output +bitblas.set_debug_level("Debug") + in_features = 1024 out_features = 1024 group_size = 128 From a0f651a3ba2c8146f06dc21b6aba48443fe3b956 Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Wed, 5 Jun 2024 16:28:48 +0000 Subject: [PATCH 16/17] Refactor Linear module matmul implementation for splitk layout --- python/bitblas/module/__init__.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/bitblas/module/__init__.py b/python/bitblas/module/__init__.py index eaf15bc1d..e29c9de0f 100644 --- a/python/bitblas/module/__init__.py +++ b/python/bitblas/module/__init__.py @@ -232,15 +232,18 @@ def forward(self, A, output=None): A = A.half() # can be lifted to post init. self.init_params() - + if output is None: output = torch.empty( A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device) m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1)) A = self.bitblas_matmul.transform_input(A) + stream = torch.cuda.current_stream() + A_void = ctypes.c_void_p(A.data_ptr()) + stream_handle = ctypes.c_void_p(stream.cuda_stream) # m is the product of the last n - 1 dimensions of A - self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m) + self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m, stream_handle) return output From 88295a73deff9b5b61d8865627316c8ed33dd499 Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Wed, 5 Jun 2024 16:28:55 +0000 Subject: [PATCH 17/17] Refactor matmul implementation for splitk layout --- python/bitblas/ops/general_matmul_splitk.py | 11 ++- .../operators/test_general_matmul_fp8.py | 4 +- .../test_general_matmul_splitk_ops.py | 70 +++++++++++++------ 3 files changed, 61 insertions(+), 24 deletions(-) diff --git a/python/bitblas/ops/general_matmul_splitk.py b/python/bitblas/ops/general_matmul_splitk.py index e951bf126..28e3cbbf2 100644 --- a/python/bitblas/ops/general_matmul_splitk.py +++ b/python/bitblas/ops/general_matmul_splitk.py @@ -160,7 +160,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: if output is None: output = torch.empty( - (self.k_split,) + A.shape[:-1] + (self.N,), + A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device) if scale is not None: @@ -169,7 +169,12 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: args.append(zeros) if bias is not None: args.append(bias) - args.append(output) + + sk_output = torch.empty((self.k_split,) + + A.shape[:-1] + (self.N,), + dtype=self.torch_output_dtype, + device=A.device) + args.append(sk_output) if self.dynamic_range is not None: m = reduce(operator.mul, A.shape[:-1], 1) @@ -180,7 +185,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: if self.lib is None: self._forward_from_torch_func(*args) self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream) - output = torch.sum(output, dim=0) + torch.sum(sk_output, dim=0, out=output) return output def __call__(self, *args: Any, 
**kwds: Any) -> Any: diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py index 5b7de9ab0..603a57248 100644 --- a/testing/python/operators/test_general_matmul_fp8.py +++ b/testing/python/operators/test_general_matmul_fp8.py @@ -171,4 +171,6 @@ def map_torch_type(intype): # fmt: on if __name__ == "__main__": - bitblas.testing.main() + # bitblas.testing.main() + test_matmul_torch_forward_weight_dequantize(1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None, + None, None) diff --git a/testing/python/operators/test_general_matmul_splitk_ops.py b/testing/python/operators/test_general_matmul_splitk_ops.py index dd9b29d51..ac3a15a9c 100644 --- a/testing/python/operators/test_general_matmul_splitk_ops.py +++ b/testing/python/operators/test_general_matmul_splitk_ops.py @@ -41,20 +41,22 @@ def test_matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtyp matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) assert get_codegen_result(matmul) - @pytest.mark.parametrize( - "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", + "SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", [ - (1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None), - (16, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None), + (1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, + False, None), + (4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, + False, None), ], ) -def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, - group_size, with_scaling, with_zeros, zeros_mode): - +def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, + layout, with_bias, group_size, with_scaling, with_zeros, + zeros_mode): + import torch + torch.random.manual_seed(0) matmul_config = MatmulConfigWithSplitK( + k_split=SplitK, M=M, N=N, K=K, @@ -70,20 +72,27 @@ def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo zeros_mode=zeros_mode, ) matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) - matmul.hardware_aware_finetune(topk=10) - assert get_codegen_result(matmul) + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda() - 0.5) + + output_bitblas = matmul.forward(*inputs) + output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) + torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1) @pytest.mark.parametrize( "SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", [ - (1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, + (1, 16, 4096, 12800, "float16", "e4m3_float8", "float32", "float16", "nt", False, -1, False, False, None), - (4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, + (4, 16, 4096, 12800, "float16", "e4m3_float8", "float32", "float16", "nt", False, -1, False, False, None), ], ) -def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, 
W_dtype, accum_dtype, out_dtype, +def test_matmul_torch_forward_fp8e4m3(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode): import torch @@ -103,18 +112,39 @@ def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accu with_scaling=with_scaling, with_zeros=with_zeros, zeros_mode=zeros_mode, + propagate_a=False, + propagate_b=False, ) matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) - inputs = [] - inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) - inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda() - 0.5) + def map_torch_type(intype): - output_bitblas = matmul.forward(*inputs) - output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) - torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1) + typemap = { + 'e4m3_float8': torch.float8_e4m3fn, + 'e5m2_float8': torch.float8_e5m2, + } + if intype in typemap: + return typemap[intype] + else: + return getattr(torch, intype) + + numpytype_a = map_torch_type(A_dtype) + numpytype_b = map_torch_type(W_dtype) + + torch_a = torch.rand(M * K).uniform_(-1, 1).reshape(input_shape).type(numpytype_a).cuda() + torch_b = torch.rand(N * K).uniform_(-1, 1).reshape(weight_shape).type(numpytype_b).cuda() + ref_out = torch.matmul(torch_a.to(torch.float32), + torch_b.t().to(torch.float32)) if layout == "nt" else torch.matmul( + torch_a.to(torch.float32), torch_b.to(torch.float32)) + ref_out = ref_out.to(torch.float16) + bitblas_out = torch.empty_like(ref_out) + matmul.forward(torch_a, torch_b, output=bitblas_out) + print("torch_ref_out", ref_out) + print("bitblas_out", bitblas_out) + + torch.testing.assert_close(bitblas_out, ref_out, rtol=1e0, atol=1e-1) # fmt: on
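
The split-K refactor in PATCH 17 stages partial results in an sk_output buffer of shape (k_split, M, N) and then reduces them with torch.sum(sk_output, dim=0, out=output) before returning. As a reference for what that accumulate-then-reduce path computes, here is a minimal plain-PyTorch sketch; it is not BitBLAS code, and the function name, the "nt" layout assumption, and the CPU float32 example sizes are illustrative only.

import torch


def splitk_matmul_reference(a: torch.Tensor, w: torch.Tensor, k_split: int) -> torch.Tensor:
    # Assumes the "nt" layout used by the tests above: A is (M, K), W is (N, K),
    # and K is divisible by k_split.
    m, k = a.shape
    n = w.shape[0]
    assert k % k_split == 0
    rk = k // k_split  # reduction length handled by each split
    sk_output = torch.empty((k_split, m, n), dtype=a.dtype, device=a.device)
    for sk in range(k_split):
        a_part = a[:, sk * rk:(sk + 1) * rk]
        w_part = w[:, sk * rk:(sk + 1) * rk]
        sk_output[sk] = a_part @ w_part.t()  # partial GEMM for this K slice
    output = torch.empty((m, n), dtype=a.dtype, device=a.device)
    torch.sum(sk_output, dim=0, out=output)  # reduce over the split dimension
    return output


if __name__ == "__main__":
    a = torch.rand(16, 12800) - 0.5  # float32 on CPU for portability
    w = torch.rand(4096, 12800) - 0.5
    torch.testing.assert_close(
        splitk_matmul_reference(a, w, k_split=4), a @ w.t(), rtol=1e-3, atol=1e-3)

The staging buffer costs k_split copies of the output tile, which is the usual split-K trade-off: extra memory and a final reduction in exchange for more parallelism along the K dimension, and it is why the refactored forward no longer writes the user-visible output directly from the kernel.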