From d8884e6f6a294fc8f1a325665d86a07603d43864 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 08:54:26 +0000 Subject: [PATCH 01/22] Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability --- bitblas/ops/impl/base.py | 16 +++ bitblas/ops/impl/batch_matmul_impl.py | 166 ++++++++++++++++---------- 2 files changed, 119 insertions(+), 63 deletions(-) create mode 100644 bitblas/ops/impl/base.py diff --git a/bitblas/ops/impl/base.py b/bitblas/ops/impl/base.py new file mode 100644 index 00000000..6d510f7d --- /dev/null +++ b/bitblas/ops/impl/base.py @@ -0,0 +1,16 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from abc import ABC, abstractmethod + +# TODO: Refactor all the tir script implementations to use this base class +# Abstract base class for TIR script emitters +class TIRScriptEmitter(ABC): + @abstractmethod + def emit(self): + raise NotImplementedError + +# Abstract base class for TIR script selectors +class TIRScriptSelector(ABC): + @abstractmethod + def select(self): + raise NotImplementedError diff --git a/bitblas/ops/impl/batch_matmul_impl.py b/bitblas/ops/impl/batch_matmul_impl.py index 09b536af..75449ea4 100644 --- a/bitblas/ops/impl/batch_matmul_impl.py +++ b/bitblas/ops/impl/batch_matmul_impl.py @@ -4,63 +4,117 @@ from bitblas import tvm from tvm import te from bitblas.ops.operator import TransformKind +from .base import TIRScriptEmitter, TIRScriptSelector +from bitblas import tvm +from tvm import te +from bitblas.ops.operator import TransformKind +class BatchMatMulEmitter(TIRScriptEmitter): + def __init__( + self, + batch, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + ): + self.batch = batch + self.M = self._validate_dimension(M, "M") + self.N = self._validate_dimension(N, "N") + self.K = self._validate_dimension(K, "K") + self.in_dtype = in_dtype + self.out_dtype = out_dtype + self.accum_dtype = accum_dtype + self.with_bias = with_bias + self.layout = layout + self._validate_layout() + + @staticmethod + def _validate_dimension(dim, name): + if not isinstance(dim, int): + return tvm.te.var(name.lower()) + return dim -def matmul_nt( - Batch, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, -): - if not isinstance(M, int): - M = tvm.te.var("m") - A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) - B = te.placeholder((Batch, N, K), name="B", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + def _validate_layout(self): + if self.layout not in ["nn", "nt"]: + raise ValueError(f"Unsupported layout: {self.layout}") + if self.layout == "nn": + raise ValueError("Currently only support layout=nt") - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (Batch, M, N), - lambda b, i, j: te.sum( - A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k), - name="C", - ) - last_output = C - if accum_dtype != out_dtype: - D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") - last_output = D + def _create_placeholders(self): + A = te.placeholder((self.batch, self.M, self.K), name="A", dtype=self.in_dtype) + B = te.placeholder((self.batch, self.N, self.K), name="B", dtype=self.in_dtype) + Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None + return A, B, Bias - if with_bias: - E = te.compute((Batch, M, N), 
lambda b, i, j: last_output[b, i, j] + Bias[j], name="E") - last_output = E + def _compute_matmul(self, A, B): + k = te.reduce_axis((0, self.K), name="k") + C = te.compute( + (self.batch, self.M, self.N), + lambda b, i, j: te.sum( + A[b, i, k].astype(self.accum_dtype) * B[b, j, k].astype(self.accum_dtype), axis=k), + name="C", + ) + return C - args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + def _apply_bias(self, C, Bias): + if self.with_bias: + return te.compute((self.batch, self.M, self.N), lambda b, i, j: C[b, i, j] + Bias[j], name="E") + return C - func = te.create_prim_func(args) + def _convert_dtype(self, tensor): + if self.accum_dtype != self.out_dtype: + return te.compute((self.batch, self.M, self.N), lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), name="D") + return tensor - return tvm.IRModule.from_expr(func) + def emit(self): + A, B, Bias = self._create_placeholders() + C = self._compute_matmul(A, B) + last_output = self._convert_dtype(C) + if self.with_bias: + last_output = self._apply_bias(last_output, Bias) + args = [A, B, Bias, last_output] if self.with_bias else [A, B, last_output] + func = te.create_prim_func(args) + return tvm.IRModule.from_expr(func) -def matmul( - Batch, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, - layout="nt", -): - if layout == "nn": - raise ValueError("Currently only support layout=nt") - return matmul_nt(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) +class BatchMatMulSelector(TIRScriptSelector): + def __init__(self, propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform): + self.propagate_a = propagate_a + self.propagate_b = propagate_b + + def select( + self, + batch=1, + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + ): + if layout == "nn": + if self.propagate_a or self.propagate_b: + raise ValueError("Currently only support propagate_a=False and propagate_b=False for layout=nn") + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + elif layout == "nt": + if self.propagate_a and self.propagate_b: + raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") + elif self.propagate_a: + raise ValueError("Currently only support propagate_a=False for layout=nt") + elif self.propagate_b: + raise ValueError("Currently only support propagate_b=False for layout=nt") + else: + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + else: + raise ValueError(f"Unsupported layout: {layout}") def select_implementation( Batch=1, @@ -75,19 +129,5 @@ def select_implementation( propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform, ): - if layout == "nn": - if propagate_a or propagate_b: - raise ValueError( - "Currently only support propagate_a=False and propagate_b=False for layout=nn") - return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) - elif layout == "nt": - if propagate_a and propagate_b: - raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") - elif propagate_a: - raise ValueError("Currently only support propagate_a=False for layout=nt") - elif propagate_b: - raise ValueError("Currently only support propagate_b=False for layout=nt") - else: - return matmul(Batch, M, N, K, 
in_dtype, out_dtype, accum_dtype, with_bias, layout) - else: - raise ValueError(f"Unsupported layout: {layout}") + selector = BatchMatMulSelector(propagate_a, propagate_b) + return selector.select(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) From fc84173f22d2f4867a8e6413117b5cd8e830ab27 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 08:57:43 +0000 Subject: [PATCH 02/22] Refactor import statements for improved readability and maintainability --- bitblas/ops/impl/__init__.py | 2 +- bitblas/ops/impl/base.py | 4 ++++ bitblas/ops/impl/batch_matmul_impl.py | 33 ++++++++++++++++++--------- 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/bitblas/ops/impl/__init__.py b/bitblas/ops/impl/__init__.py index a254dc7f..8a9bbd2a 100644 --- a/bitblas/ops/impl/__init__.py +++ b/bitblas/ops/impl/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .lop3_permutate_impl import tir_interleave_weight +from .lop3_permutate_impl import tir_interleave_weight # noqa: F401 diff --git a/bitblas/ops/impl/base.py b/bitblas/ops/impl/base.py index 6d510f7d..4a67987b 100644 --- a/bitblas/ops/impl/base.py +++ b/bitblas/ops/impl/base.py @@ -2,15 +2,19 @@ # Licensed under the MIT License. from abc import ABC, abstractmethod + # TODO: Refactor all the tir script implementations to use this base class # Abstract base class for TIR script emitters class TIRScriptEmitter(ABC): + @abstractmethod def emit(self): raise NotImplementedError + # Abstract base class for TIR script selectors class TIRScriptSelector(ABC): + @abstractmethod def select(self): raise NotImplementedError diff --git a/bitblas/ops/impl/batch_matmul_impl.py b/bitblas/ops/impl/batch_matmul_impl.py index 75449ea4..3904f36e 100644 --- a/bitblas/ops/impl/batch_matmul_impl.py +++ b/bitblas/ops/impl/batch_matmul_impl.py @@ -5,11 +5,10 @@ from tvm import te from bitblas.ops.operator import TransformKind from .base import TIRScriptEmitter, TIRScriptSelector -from bitblas import tvm -from tvm import te -from bitblas.ops.operator import TransformKind + class BatchMatMulEmitter(TIRScriptEmitter): + def __init__( self, batch, @@ -32,7 +31,7 @@ def __init__( self.with_bias = with_bias self.layout = layout self._validate_layout() - + @staticmethod def _validate_dimension(dim, name): if not isinstance(dim, int): @@ -48,7 +47,8 @@ def _validate_layout(self): def _create_placeholders(self): A = te.placeholder((self.batch, self.M, self.K), name="A", dtype=self.in_dtype) B = te.placeholder((self.batch, self.N, self.K), name="B", dtype=self.in_dtype) - Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None + Bias = te.placeholder( + (self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None return A, B, Bias def _compute_matmul(self, A, B): @@ -63,12 +63,16 @@ def _compute_matmul(self, A, B): def _apply_bias(self, C, Bias): if self.with_bias: - return te.compute((self.batch, self.M, self.N), lambda b, i, j: C[b, i, j] + Bias[j], name="E") + return te.compute((self.batch, self.M, self.N), + lambda b, i, j: C[b, i, j] + Bias[j], + name="E") return C def _convert_dtype(self, tensor): if self.accum_dtype != self.out_dtype: - return te.compute((self.batch, self.M, self.N), lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), name="D") + return te.compute((self.batch, self.M, self.N), + lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), + name="D") return tensor def emit(self): @@ -84,7 +88,10 @@ def emit(self): class 
BatchMatMulSelector(TIRScriptSelector): - def __init__(self, propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform): + + def __init__(self, + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform): self.propagate_a = propagate_a self.propagate_b = propagate_b @@ -102,8 +109,10 @@ def select( ): if layout == "nn": if self.propagate_a or self.propagate_b: - raise ValueError("Currently only support propagate_a=False and propagate_b=False for layout=nn") - return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn") + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, + layout).emit() elif layout == "nt": if self.propagate_a and self.propagate_b: raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") @@ -112,10 +121,12 @@ def select( elif self.propagate_b: raise ValueError("Currently only support propagate_b=False for layout=nt") else: - return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, + with_bias, layout).emit() else: raise ValueError(f"Unsupported layout: {layout}") + def select_implementation( Batch=1, M=None, From 02f64de6cf2d338c092dcf29ec55b69804fda892 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 08:58:06 +0000 Subject: [PATCH 03/22] Refactor import statements for improved readability and maintainability --- bitblas/ops/impl/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitblas/ops/impl/__init__.py b/bitblas/ops/impl/__init__.py index 8a9bbd2a..67e49b2a 100644 --- a/bitblas/ops/impl/__init__.py +++ b/bitblas/ops/impl/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .lop3_permutate_impl import tir_interleave_weight # noqa: F401 +from .lop3_permutate_impl import tir_interleave_weight # noqa: F401 From 397eee6141599e84b509594bb99a0531e409c266 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 16:25:47 +0000 Subject: [PATCH 04/22] disable failure email for ci --- .github/workflows/ci.yml | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ceb69fcc..1fbdf19d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -64,4 +64,13 @@ jobs: run: | source bitblas_ci/bin/activate cd testing/python - python -m pytest \ No newline at end of file + python -m pytest + + # Control notifications + notify: + runs-on: self-hosted + needs: [format-check, build-test] + if: failure() + steps: + - name: Notification + run: echo "Jobs failed, but no email will be sent." From 20f6ad1e7ca4e6e1ca9e13ad7c1bbc8c430a8e51 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 03:23:50 +0000 Subject: [PATCH 05/22] remove email notifications. 
--- .github/workflows/ci.yml | 9 --------- 1 file changed, 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1fbdf19d..511b9583 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -65,12 +65,3 @@ jobs: source bitblas_ci/bin/activate cd testing/python python -m pytest - - # Control notifications - notify: - runs-on: self-hosted - needs: [format-check, build-test] - if: failure() - steps: - - name: Notification - run: echo "Jobs failed, but no email will be sent." From b93c39431c803e22b12f71b555939785da36b96a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 03:25:05 +0000 Subject: [PATCH 06/22] move relax pass from testing to mlc_llm --- .../mlc_llm}/test_weight_only_transform.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {testing/python/transform => integration/mlc_llm}/test_weight_only_transform.py (100%) diff --git a/testing/python/transform/test_weight_only_transform.py b/integration/mlc_llm/test_weight_only_transform.py similarity index 100% rename from testing/python/transform/test_weight_only_transform.py rename to integration/mlc_llm/test_weight_only_transform.py From 257693a7c3cb3083aac144182f58d38bfe3bcfdd Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 05:51:01 +0000 Subject: [PATCH 07/22] Refactor scripts with se check_eual_ref_scripts_with_emitter function --- bitblas/ops/impl/matmul_dequantize_impl.py | 224 ++++++++++++++---- .../operators/test_tir_script_emitter.py | 52 +++- 2 files changed, 216 insertions(+), 60 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index 1ed6b340..e69e8fcf 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -15,8 +15,10 @@ _tir_packed_to_unsigned_convert_with_zeros, ) + # TODO: The following code should be refactored. 
class MatMulNTDequantizeEmitter: + def __init__( self, M, @@ -52,8 +54,8 @@ def __init__( self.fast_decoding = fast_decoding self.with_bias = with_bias self.zeros_mode = zeros_mode - self.propagate_a = propagate_a - self.propagate_b = propagate_b + self.propagate_a = self._legalize_transform_kind(propagate_a) + self.propagate_b = self._legalize_transform_kind(propagate_b) self._validate_bit() self._validate_layout() @@ -69,62 +71,169 @@ def _validate_bit(self): raise ValueError(f"Unsupported bit: {self.bit}") def _validate_layout(self): - if self.layout not in ["nt"]: - raise ValueError(f"Unsupported layout: {self.layout}") + # TODO: extend the dequantize operators into General Layout + pass + + def _legalize_group_size(self): + if self.group_size == -1: + self.group_size = self.K + + def _legalize_transform_kind(self, propagate): + if propagate is None: + return TransformKind.NonTransform + if isinstance(propagate, bool): + return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform) + elif isinstance(propagate, int): + return TransformKind(propagate) def _create_placeholders(self): - storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) - n_float_per_elem = storage_nbit // self.bit - - A = te.placeholder((self.M, self.K), name="A", dtype=self.in_dtype) - B = te.placeholder((self.N, self.K // storage_nbit * self.bit), name="B", dtype=self.storage_dtype) - LUT = te.placeholder((1 << self.bit,), name="LUT", dtype=self.in_dtype) - Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=self.in_dtype) - Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=self.in_dtype) - QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * self.bit), + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + in_dtype = self.in_dtype + bit = self.bit + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype) + B = te.placeholder((self.N, self.K // storage_nbit * bit), + name="B", + dtype=storage_dtype) + if self.propagate_a: + A = te.placeholder((self.M // l, self.K // r, l, r), name="A", dtype=in_dtype) + if self.propagate_b: + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // target_dtype.bits) + qr = r * bit // storage_nbit + B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=in_dtype) + QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * bit), name="QZeros", dtype=self.storage_dtype) - Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) - return A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem + Bias = te.placeholder((self.N,), name="Bias", dtype=in_dtype) + return A, B, LUT, Scale, Zeros, QZeros, Bias + + def _propagate_input(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="A"): + if transform_kind == TransformKind.NonTransform: + return tensor + in_dtype = self.in_dtype + l = r = 16 # noqa: E741 + if in_dtype in 
["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map( + trans=False, dtype=in_dtype, matrix_name=matrix_name) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return tensor[new_index] + + return te.compute( + (self.M, self.K), + fcompute, + name=f"{matrix_name}_reindex", + ) + + def _propagage_weight(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="B"): + if transform_kind == TransformKind.NonTransform: + return tensor + in_dtype = self.in_dtype + bit = self.bit + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map( + trans=True, dtype=in_dtype, matrix_name=matrix_name) + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inversed_index_map.initial_indices + scaling_final_indices = inversed_index_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inversed_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + qr = r * bit // storage_nbit - def _decode_func(self, B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem): - w = None + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return tensor[new_index] + + return te.compute( + (self.N, self.K // storage_nbit * bit), + fcompute, + name=f"{matrix_name}_reindex", + ) + + def _decode_func(self, B, LUT, Scale, Zeros, QZeros): + bit = self.bit + in_dtype = self.in_dtype + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + + # TODO: Move the decode function into a more general place def decode(n, k): + w = None if self.with_zeros and self.zeros_mode == "quantized": - qzeros_dequantize = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, + qzeros_dequantize = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, QZeros[k, n // n_float_per_elem], n % n_float_per_elem, dtype=self.storage_dtype, ) - w = _tir_packed_to_unsigned_convert_with_zeros(self.storage_dtype, storage_nbit)( - self.bit, + w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, qzeros_dequantize, - dtype=self.in_dtype, + dtype=in_dtype, ) elif self.source_format == "uint": - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) + if bit == 8: + w = B[n, k].astype(in_dtype) + w = 
_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "int": - if self.bit == 1: - w = _tir_packed_int_to_int_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_signed_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) + if bit == 1: + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + if bit == 8: + w = B[n, k].astype(in_dtype) + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "fp": w = _tir_u32_to_f4_to_f16( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(self.bit, B[n, k], dtype=self.in_dtype) + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) elif self.source_format == "nf": - index = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, + index = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype="int32", @@ -132,7 +241,9 @@ def decode(n, k): w = LUT[index] else: raise ValueError(f"Unsupported source_format: {self.source_format}") - + + assert w is not None, "w is None" + group_size = self.group_size zeros_mode = self.zeros_mode @@ -167,7 +278,9 @@ def _compute_matmul(self, A, B_decode): def _convert_dtype(self, tensor): if self.accum_dtype != self.out_dtype: - return te.compute((self.M, self.N), lambda i, j: tensor[i, j].astype(self.out_dtype), name="D") + return te.compute((self.M, self.N), + lambda i, j: tensor[i, j].astype(self.out_dtype), + name="D") return tensor def _apply_bias(self, tensor, Bias): @@ -176,9 +289,12 @@ def _apply_bias(self, tensor, Bias): return tensor def emit(self): - A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem = self._create_placeholders() - B_decode = self._decode_func(B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem) - C = self._compute_matmul(A, B_decode) + A, B, LUT, Scale, Zeros, QZeros, Bias = self._create_placeholders() + A_reindex = self._propagate_input(A, self.propagate_a, "A") + B_reindex = self._propagage_weight(B, self.propagate_b, "B") + + B_decode = self._decode_func(B_reindex, LUT, Scale, Zeros, QZeros) + C = self._compute_matmul(A_reindex, B_decode) D = self._convert_dtype(C) last_output = self._apply_bias(D, Bias) @@ -212,8 +328,13 @@ def emit(self): } }, ) + if self.propagate_a: + func = func.with_attr("input_transform_kind", self.propagate_a.value) + if self.propagate_b: + func = func.with_attr("weight_transform_kind", self.propagate_b.value) return tvm.IRModule.from_expr(func) + def matmul_nt_dequantize_b( M, N, @@ -335,9 +456,12 @@ def decode_func(n, k): A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), name="C", ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D args = [A, B] - last_output = 
D if source_format == "nf": args.append(LUT) if with_scaling: @@ -517,9 +641,11 @@ def decode_func(n, k): A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), name="C", ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D args = [A, B] - last_output = D if source_format == "nf": args.append(LUT) if with_scaling: @@ -715,9 +841,11 @@ def decode_func(n, k): ), name="C", ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D args = [A, B] - last_output = D if source_format == "nf": args.append(LUT) if with_scaling: diff --git a/testing/python/operators/test_tir_script_emitter.py b/testing/python/operators/test_tir_script_emitter.py index cec56b47..fcfa7d9a 100644 --- a/testing/python/operators/test_tir_script_emitter.py +++ b/testing/python/operators/test_tir_script_emitter.py @@ -1,18 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from bitblas.ops.impl.matmul_dequantize_impl import ( - MatMulNTDequantizeEmitter, - matmul_nt_dequantize_b, - matmul_nt_dequantize_b_propagate_b, - matmul_nt_dequantize_b_propagate_a_propagate_b, -) from bitblas import tvm import logging from bitblas import set_log_level set_log_level(logging.DEBUG) -def compare_tir_scripts_and_emitter( + +def check_eual_ref_scripts_with_emitter( M, N, K, @@ -28,8 +23,26 @@ def compare_tir_scripts_and_emitter( fast_decoding, with_bias, zeros_mode, + propagate_a, + propagate_b, ): - tir_script_func = matmul_nt_dequantize_b( + from bitblas.ops.impl.matmul_dequantize_impl import ( + MatMulNTDequantizeEmitter, + matmul_nt_dequantize_b, + matmul_nt_dequantize_b_propagate_b, + matmul_nt_dequantize_b_propagate_a_propagate_b, + ) + func = None + if propagate_a and propagate_b: + func = matmul_nt_dequantize_b_propagate_a_propagate_b + elif propagate_b: + func = matmul_nt_dequantize_b_propagate_b + else: + func = matmul_nt_dequantize_b + + assert func is not None, "No function found for the given configuration" + + ref_func = func( M, N, K, @@ -46,8 +59,8 @@ def compare_tir_scripts_and_emitter( with_bias, zeros_mode, ) - - emitter_func = MatMulNTDequantizeEmitter( + + emit_func = MatMulNTDequantizeEmitter( M, N, K, @@ -63,6 +76,21 @@ def compare_tir_scripts_and_emitter( fast_decoding, with_bias, zeros_mode, + propagate_a=propagate_a, + propagate_b=propagate_b, ).emit() - - tvm.ir.assert_structural_equal(tir_script_func, emitter_func) + + tvm.ir.assert_structural_equal(ref_func, emit_func) + + +def test_check_eual_ref_scripts_with_emitter(): + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", 
"float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) + check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", True, True) + +if __name__ == "__main__": + test_check_eual_ref_scripts_with_emitter() From 9bb7f49a968d4c71dbbc12121b4b7cb8258b2136 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 05:51:15 +0000 Subject: [PATCH 08/22] Lint Fix --- bitblas/ops/impl/matmul_dequantize_impl.py | 13 +++++---- .../operators/test_tir_script_emitter.py | 29 ++++++++++++++----- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index e69e8fcf..7b91764c 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -73,7 +73,7 @@ def _validate_bit(self): def _validate_layout(self): # TODO: extend the dequantize operators into General Layout pass - + def _legalize_group_size(self): if self.group_size == -1: self.group_size = self.K @@ -96,18 +96,19 @@ def _create_placeholders(self): l, r = 16, 32 # noqa: E741 A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype) - B = te.placeholder((self.N, self.K // storage_nbit * bit), - name="B", - dtype=storage_dtype) + B = te.placeholder((self.N, self.K // storage_nbit * bit), name="B", dtype=storage_dtype) if self.propagate_a: A = te.placeholder((self.M // l, self.K // r, l, r), name="A", dtype=in_dtype) if self.propagate_b: target_dtype = DataType(in_dtype) scaling_factor = 1 if bit > 0 and bit < target_dtype.bits: - scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // target_dtype.bits) + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) qr = r * bit // storage_nbit - B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), + name="B", + dtype=storage_dtype) LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype) diff --git a/testing/python/operators/test_tir_script_emitter.py b/testing/python/operators/test_tir_script_emitter.py index fcfa7d9a..b2c7a8d4 100644 --- a/testing/python/operators/test_tir_script_emitter.py +++ b/testing/python/operators/test_tir_script_emitter.py @@ -84,13 +84,28 @@ def check_eual_ref_scripts_with_emitter( def test_check_eual_ref_scripts_with_emitter(): - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, 
"int8", "uint", True, False, -1, False, False, "original", False, True) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) - check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", True, True) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "nf", True, False, -1, False, False, "original", False, + False) + check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, + "int8", "nf", True, False, -1, False, False, "original", + False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + True) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + True) + check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, + "int8", "uint", True, False, -1, False, False, "original", + True, True) + if __name__ == "__main__": test_check_eual_ref_scripts_with_emitter() From 93eb5a5fe4e3eb6242675dd5706358c4121f1672 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 05:53:50 +0000 Subject: [PATCH 09/22] Refactor scripts with se check_eual_ref_scripts_with_emitter function --- bitblas/ops/impl/matmul_dequantize_impl.py | 198 --------------------- 1 file changed, 198 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index 1ef14100..7b91764c 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -15,204 +15,6 @@ _tir_packed_to_unsigned_convert_with_zeros, ) -# TODO: The following code should be refactored. 
-class MatMulNTDequantizeEmitter: - def __init__( - self, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - zeros_mode="original", - propagate_a: TransformKind = TransformKind.NonTransform, - propagate_b: TransformKind = TransformKind.NonTransform, - ): - self.M = self._validate_dimension(M, "M") - self.N = N - self.K = K - self.in_dtype = in_dtype - self.out_dtype = out_dtype - self.accum_dtype = accum_dtype - self.bit = bit - self.storage_dtype = storage_dtype - self.source_format = source_format - self.with_scaling = with_scaling - self.with_zeros = with_zeros - self.group_size = group_size if group_size != -1 else K - self.fast_decoding = fast_decoding - self.with_bias = with_bias - self.zeros_mode = zeros_mode - self.propagate_a = propagate_a - self.propagate_b = propagate_b - - self._validate_bit() - self._validate_layout() - - @staticmethod - def _validate_dimension(dim, name): - if not isinstance(dim, int): - return tvm.te.var(name.lower()) - return dim - - def _validate_bit(self): - if self.bit not in [1, 2, 4, 8]: - raise ValueError(f"Unsupported bit: {self.bit}") - - def _validate_layout(self): - if self.layout not in ["nt"]: - raise ValueError(f"Unsupported layout: {self.layout}") - - def _create_placeholders(self): - storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) - n_float_per_elem = storage_nbit // self.bit - - A = te.placeholder((self.M, self.K), name="A", dtype=self.in_dtype) - B = te.placeholder((self.N, self.K // storage_nbit * self.bit), name="B", dtype=self.storage_dtype) - LUT = te.placeholder((1 << self.bit,), name="LUT", dtype=self.in_dtype) - Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=self.in_dtype) - Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=self.in_dtype) - QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * self.bit), - name="QZeros", - dtype=self.storage_dtype) - Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) - return A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem - - def _decode_func(self, B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem): - w = None - def decode(n, k): - if self.with_zeros and self.zeros_mode == "quantized": - qzeros_dequantize = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, - QZeros[k, n // n_float_per_elem], - n % n_float_per_elem, - dtype=self.storage_dtype, - ) - w = _tir_packed_to_unsigned_convert_with_zeros(self.storage_dtype, storage_nbit)( - self.bit, - B[n, k // n_float_per_elem], - k % n_float_per_elem, - qzeros_dequantize, - dtype=self.in_dtype, - ) - elif self.source_format == "uint": - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - elif self.source_format == "int": - if self.bit == 1: - w = _tir_packed_int_to_int_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_signed_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - elif 
self.source_format == "fp": - w = _tir_u32_to_f4_to_f16( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - elif self.source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(self.bit, B[n, k], dtype=self.in_dtype) - elif self.source_format == "nf": - index = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, - B[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype="int32", - ) - w = LUT[index] - else: - raise ValueError(f"Unsupported source_format: {self.source_format}") - - group_size = self.group_size - zeros_mode = self.zeros_mode - - if not self.with_scaling: - return w - - if not self.with_zeros: - return w * Scale[n, k // group_size] - - if zeros_mode == "original": - w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] - elif zeros_mode == "rescale": - w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] - elif zeros_mode == "quantized": - w = w * Scale[n, k // group_size] - else: - raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) - - return w - - return te.compute((self.N, self.K), decode, name="B_decode") - - def _compute_matmul(self, A, B_decode): - k = te.reduce_axis((0, self.K), name="k") - C = te.compute( - (self.M, self.N), - lambda i, j: te.sum( - A[i, k].astype(self.accum_dtype) * B_decode[j, k].astype(self.accum_dtype), axis=k), - name="C", - ) - return C - - def _convert_dtype(self, tensor): - if self.accum_dtype != self.out_dtype: - return te.compute((self.M, self.N), lambda i, j: tensor[i, j].astype(self.out_dtype), name="D") - return tensor - - def _apply_bias(self, tensor, Bias): - if self.with_bias: - return te.compute((self.M, self.N), lambda i, j: tensor[i, j] + Bias[j], name="E") - return tensor - - def emit(self): - A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem = self._create_placeholders() - B_decode = self._decode_func(B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem) - C = self._compute_matmul(A, B_decode) - D = self._convert_dtype(C) - last_output = self._apply_bias(D, Bias) - - args = [A, B] - if self.source_format == "nf": - args.append(LUT) - if self.with_scaling: - args.append(Scale) - if self.with_zeros: - args.append(QZeros if self.zeros_mode == "quantized" else Zeros) - if self.with_bias: - args.append(Bias) - args.append(last_output) - - func = te.create_prim_func(args).with_attr( - "dequantize_info", - { - "B_decode": { - "decode_block": "B_decode", - "fast_decoding": self.fast_decoding, - "source_format": { - "bits": self.bit, - "format": self.source_format, - }, - "storage_dtype": self.storage_dtype, - "target_format": self.in_dtype, - "with_zeros": self.with_zeros, - "zeros_mode": self.zeros_mode, - "with_scaling": self.with_scaling, - "group_size": self.group_size, - } - }, - ) - return tvm.IRModule.from_expr(func) # TODO: The following code should be refactored. 
class MatMulNTDequantizeEmitter: From 99515cb540cdc11cac867e965c9c04a108d57bd2 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 29 Aug 2024 08:39:22 +0000 Subject: [PATCH 10/22] buf fix for matrix support --- README.md | 16 +- .../benchmark_matmul_fast_decoding.py | 447 ++++++++++++++++++ 2 files changed, 455 insertions(+), 8 deletions(-) create mode 100644 benchmark/operators/benchmark_matmul_fast_decoding.py diff --git a/README.md b/README.md index 43f1d92d..315fa307 100644 --- a/README.md +++ b/README.md @@ -61,14 +61,14 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and | **A_dtype** | **W_dtype** | **Accum_dtype** | **Out_dtype** | **BitBLAS Support** | **Tested Platform** | |:-----------:|:-----------:|:---------------:|:--------------------:|:-------------------:|:----------------------------------------------------:| -| BF16 | BF16 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | -| BF16 | FP4_E2M1 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | -| BF16 | FP8_E4M3 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | -| BF16 | INT8 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | -| BF16 | UINT4/INT4 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | -| BF16 | UINT2/INT2 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | -| BF16 | UINT1 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | -| BF16 | NF4 | FP32/FP16 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | +| BF16 | BF16 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | +| BF16 | FP4_E2M1 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | +| BF16 | FP8_E4M3 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | +| BF16 | INT8 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | +| BF16 | UINT4/INT4 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | +| BF16 | UINT2/INT2 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | +| BF16 | UINT1 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | +| BF16 | NF4 | FP32 | FP16 | **√** | A100(SM_80)/A6000(SM_86) | | FP16 | FP16 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | | FP16 | FP4_E2M1 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | | FP16 | FP8_E4M3 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | diff --git a/benchmark/operators/benchmark_matmul_fast_decoding.py b/benchmark/operators/benchmark_matmul_fast_decoding.py new file mode 100644 index 00000000..905bbc19 --- /dev/null +++ b/benchmark/operators/benchmark_matmul_fast_decoding.py @@ -0,0 +1,447 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from bitblas.benchmark import BitblasOperatorBenchmarkBase +from bitblas import Matmul, MatmulConfig +from bitblas.ops.general_matmul import OptimizeStrategy +from bitblas.utils import get_commit_id +from bitblas import set_log_level +from tabulate import tabulate +from os import path, makedirs +from typing import List +import argparse +from tqdm import tqdm + +set_log_level("DEBUG") + + +class BitblasMatmulOpsBenchmarkCompareStategies(BitblasOperatorBenchmarkBase): + + BENCHMARK_RESULTS_FILE = "benchmark_results.json" + BENCHMARK_SHAPES_FILE = "benchmark_shapes.json" + BENCHMARK_DEVICE_FILE = "benchmark_device.json" + + config_map = { + "FP16xFP16_GEMV": { + "A_dtype": "float16", + "W_dtype": "float16", + "accum_dtype": "float16", + }, + "FP16xUINT4_GEMV_DECODING_NAIVE": { + "A_dtype": "float16", + "W_dtype": "uint4", + "accum_dtype": "float16", + "fast_decoding": False, + }, + "FP16xUINT4_GEMV_DECODING_FAST": { + "A_dtype": "float16", + "W_dtype": "uint4", + "accum_dtype": "float16", + "fast_decoding": True, + }, + "FP16xUINT2_GEMV_DECODING_NAIVE": { + "A_dtype": "float16", + "W_dtype": "uint2", + "accum_dtype": "float16", + "fast_decoding": False, + }, + "FP16xUINT2_GEMV_DECODING_FAST": { + "A_dtype": "float16", + "W_dtype": "uint2", + "accum_dtype": "float16", + "fast_decoding": True, + }, + "INT8xUINT2_GEMV_DECODING_NAIVE": { + "A_dtype": "int8", + "W_dtype": "uint2", + "accum_dtype": "int32", + "out_dtype": "int32", + "fast_decoding": False, + }, + "INT8xUINT2_GEMV_DECODING_FAST": { + "A_dtype": "int8", + "W_dtype": "uint2", + "accum_dtype": "int32", + "out_dtype": "int32", + "fast_decoding": True, + }, + } + + OPT_SHAPES = 1 # our test focuses on GEMV only + + CURRENT_COMMIT_ID = get_commit_id() + + def __init__(self): + super().__init__() + + def prepare_set_group_4x(self, name: str, N, K) -> List: + assert name in self.config_map, f"Operator {name} not found in config map" + return [ + self.generate_op_unit( + self.generate_operator_config( + name, self.OPT_SHAPES, N, K)), + ] + + def prepare_benchmark_sets(self): + """Prepare benchmark sets.""" + + self.add_benchmark_set( + "FP16xUINT4_GEMV_DECODING_NAIVE", + [ + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 16384, 16384), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 3200, 3200), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 8640, 3200), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 3200, 8640), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 1024, 8192), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 8192, 8192), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 28672, 8192), + *self.prepare_set_group_4x("FP16xUINT4_GEMV_DECODING_NAIVE", 8192, 28672), + ], + ) + + self.add_benchmark_set( + "FP16xUINT4_GEMV_DECODING_FAST", + [ + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 16384, + 16384, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 3200, + 3200, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 8640, + 3200, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 3200, + 8640, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 1024, + 8192, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 8192, + 8192, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 28672, + 8192, + ), + *self.prepare_set_group_4x( + "FP16xUINT4_GEMV_DECODING_FAST", + 8192, + 28672, + ), + ], + ) + + 
self.add_benchmark_set( + "FP16xUINT2_GEMV_DECODING_NAIVE", + [ + *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 16384, 16384), + *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 3200, 3200), + *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 8640, 3200), + *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 3200, 8640), + *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 1024, 8192), + *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 8192, 8192), + *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 28672, 8192), + *self.prepare_set_group_4x("FP16xUINT2_GEMV_DECODING_NAIVE", 8192, 28672), + ], + ) + + self.add_benchmark_set( + "FP16xUINT2_GEMV_DECODING_FAST", + [ + *self.prepare_set_group_4x( + "FP16xUINT2_GEMV_DECODING_FAST", + 16384, + 16384, + ), + *self.prepare_set_group_4x( + "FP16xUINT2_GEMV_DECODING_FAST", + 3200, + 3200, + ), + *self.prepare_set_group_4x( + "FP16xUINT2_GEMV_DECODING_FAST", + 8640, + 3200, + ), + *self.prepare_set_group_4x( + "FP16xUINT2_GEMV_DECODING_FAST", + 3200, + 8640, + ), + *self.prepare_set_group_4x( + "FP16xUINT2_GEMV_DECODING_FAST", + 1024, + 8192, + ), + *self.prepare_set_group_4x( + "FP16xUINT2_GEMV_DECODING_FAST", + 8192, + 8192, + ), + *self.prepare_set_group_4x( + "FP16xUINT2_GEMV_DECODING_FAST", + 28672, + 8192, + ), + *self.prepare_set_group_4x( + "FP16xUINT2_GEMV_DECODING_FAST", + 8192, + 28672, + ), + ], + ) + + self.add_benchmark_set( + "INT8xUINT2_GEMV_DECODING_NAIVE", + [ + *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 16384, 16384), + *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 3200, 3200), + *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 8640, 3200), + *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 3200, 8640), + *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 1024, 8192), + *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 8192, 8192), + *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 28672, 8192), + *self.prepare_set_group_4x("INT8xUINT2_GEMV_DECODING_NAIVE", 8192, 28672), + ], + ) + + self.add_benchmark_set( + "INT8xUINT2_GEMV_DECODING_FAST", + [ + *self.prepare_set_group_4x( + "INT8xUINT2_GEMV_DECODING_FAST", + 16384, + 16384, + ), + *self.prepare_set_group_4x( + "INT8xUINT2_GEMV_DECODING_FAST", + 3200, + 3200, + ), + *self.prepare_set_group_4x( + "INT8xUINT2_GEMV_DECODING_FAST", + 8640, + 3200, + ), + *self.prepare_set_group_4x( + "INT8xUINT2_GEMV_DECODING_FAST", + 3200, + 8640, + ), + *self.prepare_set_group_4x( + "INT8xUINT2_GEMV_DECODING_FAST", + 1024, + 8192, + ), + *self.prepare_set_group_4x( + "INT8xUINT2_GEMV_DECODING_FAST", + 8192, + 8192, + ), + *self.prepare_set_group_4x( + "INT8xUINT2_GEMV_DECODING_FAST", + 28672, + 8192, + ), + *self.prepare_set_group_4x( + "INT8xUINT2_GEMV_DECODING_FAST", + 8192, + 28672, + ), + ], + ) + + def generate_operator_config(self, name: str, M, N, K) -> MatmulConfig: + """Generate configuration for the given operator.""" + if name not in self.config_map: + raise ValueError(f"Operator {name} not found in config map") + return self.get_operator_config()( + M=M, + N=N, + K=K, + **self.config_map[name], + ) + + def report(self): + """Generate and print a report of the benchmark results.""" + results4compare = {} + for name, results in self.benchmark_results.items(): + if "DECODING" not in name: + name = f"{name}" + strategy = "" + else: + name, strategy = name.split("DECODING") + results4compare.setdefault(name, {})[strategy] = 
results + + data = [] + for name, strategy in results4compare.items(): + table_data = [ + ["TAG:", name, "Device:", self.benchmark_target], + [ + "Shape (M-N-K / N-K_M)", + "Native Decoding Time (ms)", + "Shape (M-N-K / N-K_M)", + "Fast Decoding Time (ms)", + "Tune Time (s)", + ], + ] + + def legalize_shape(M, N, K, dyn_prof_shape): + """Generate a string representation of the operator shape. + + Args: + M: The M dimension (can be an int or a tuple). + N: The N dimension (must be an int). + K: The K dimension (must be an int). + dyn_prof_shape: The dynamic profiling shape (dict with "m" key if M is dynamic). + + Returns: + A string representing the shape in either 'M-N-K' or 'N-K_M' format. + """ + if isinstance(M, int): + return f"{M}-{N}-{K}" + elif dyn_prof_shape and "m" in dyn_prof_shape: + return f"{M}-{N}-{K}_{dyn_prof_shape['m']}" + else: + # Calculate the average of tuple M + str_m = "[" + "-".join(str(m) for m in M) + "]" + opt_m = sum(M) / len(M) + return f"{N}-{K}_{str_m}_{opt_m}" + + for strategy_name, results in strategy.items(): + tmp_data = [] + if strategy_name == "": + origin_name = name + else: + origin_name = f"{name}DECODING{strategy_name}" + for i, benchmark_set in enumerate(self.benchmark_sets[origin_name]): + op_config = benchmark_set[1] + if isinstance(self.OPT_SHAPES, int): + sub_results = results[i] + latency = sub_results[0] + dyn_prof_shape = {"m": self.OPT_SHAPES} + shape = legalize_shape("DYN", op_config.N, op_config.K, dyn_prof_shape) + latency_str = "N/A" if latency is None else f"{latency:.3f}" + tmp_data.append([shape, latency_str]) + else: + sub_results = results[i * len(self.OPT_SHAPES):(i + 1) * len(self.OPT_SHAPES)] + for i, result in enumerate(sub_results): + latency = result[0] + dyn_prof_shape = {"m": self.OPT_SHAPES[i]} + shape = legalize_shape("DYN", op_config.N, op_config.K, dyn_prof_shape) + latency_str = "N/A" if latency is None else f"{latency:.3f}" + tmp_data.append([shape, latency_str]) + if len(data) == 0: + data = tmp_data + else: + for i, item in enumerate(tmp_data): + data[i].extend(item) + + for i, item in enumerate(data): + base = item[1] + head = item[3] + + speedup = float(head) / float(base) - 1 + symbol = "+" if speedup > 0 else "-" + speedup = abs(speedup) + data[i][3] = f"{head} ({symbol}{speedup * 100 :.3f}%)" + table_data.append([*data[i], "N/A"]) + + print(tabulate(table_data, headers="firstrow", tablefmt="fancy_grid")) + + for data in table_data: + print(data) + + def get_operator(self): + """Return the Matmul operator.""" + return Matmul + + def get_operator_config(self): + """Return the Matmul operator configuration.""" + return MatmulConfig + + def make_operator(self, operator: Matmul, config: MatmulConfig) -> Matmul: + """Make an Matmul instance.""" + # Disable default tuning when do benchmark + return operator(config, target=self.benchmark_target, enable_tuning=False) + + def benchmark(self): + """Run benchmarks on all benchmark sets.""" + # Calculate the total number of benchmark runs for the progress bar + total_runs = sum( + ( + len(benchmark_set) * + (len(self.OPT_SHAPES) + if isinstance(self.OPT_SHAPES, list) + else self.OPT_SHAPES) + ) + for benchmark_set in self.benchmark_sets.values() + ) + + with tqdm(total=total_runs, desc="Total Progress", unit="benchmark") as pbar: + for name, benchmark_set in self.benchmark_sets.items(): + self.benchmark_results[name] = [] + for op, config, _ in benchmark_set: + if isinstance(self.OPT_SHAPES, int): + print(f"Running benchmark for {name} with shape {self.OPT_SHAPES}") + 
self.benchmark_results[name].extend( + [self.run_benchmark(op, config, {"m": self.OPT_SHAPES})]) + # Update the progress bar after each run + pbar.update(1) + else: + for opt in self.OPT_SHAPES: + print(f"Running benchmark for {name} with shape {opt}") + self.benchmark_results[name].extend( + [self.run_benchmark(op, config, {"m": opt})]) + # Update the progress bar after each run + pbar.update(1) + + def run_compare_strategy(self, report=True, serialize=True, enable_tuning: bool = False): + """Run the benchmark process.""" + + if not path.exists(self.log_path): + makedirs(self.log_path) + + if enable_tuning: + self.enable_tuning() + + self.prepare_benchmark_sets() + self.benchmark() + + if report: + self.report() + + self.cleanup() + + def serialize_results(self) -> None: + """Serialize the benchmark results.""" + pass + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Bitblas Matmul Operator Benchmark") + + parser.add_argument( + "--enable_tuning", + action="store_true", + help="Enable hardware-aware tuning", + ) + + args = parser.parse_args() + enable_tuning = args.enable_tuning + BitblasMatmulOpsBenchmarkCompareStategies().run_compare_strategy( + enable_tuning=args.enable_tuning) From 14406effb387732d47909582202e4007d5767f98 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 29 Aug 2024 08:39:50 +0000 Subject: [PATCH 11/22] lint fix --- .../benchmark_matmul_fast_decoding.py | 27 +++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/benchmark/operators/benchmark_matmul_fast_decoding.py b/benchmark/operators/benchmark_matmul_fast_decoding.py index 905bbc19..c3c65b3b 100644 --- a/benchmark/operators/benchmark_matmul_fast_decoding.py +++ b/benchmark/operators/benchmark_matmul_fast_decoding.py @@ -3,7 +3,6 @@ from bitblas.benchmark import BitblasOperatorBenchmarkBase from bitblas import Matmul, MatmulConfig -from bitblas.ops.general_matmul import OptimizeStrategy from bitblas.utils import get_commit_id from bitblas import set_log_level from tabulate import tabulate @@ -67,7 +66,7 @@ class BitblasMatmulOpsBenchmarkCompareStategies(BitblasOperatorBenchmarkBase): }, } - OPT_SHAPES = 1 # our test focuses on GEMV only + OPT_SHAPES = 1 # our test focuses on GEMV only CURRENT_COMMIT_ID = get_commit_id() @@ -77,14 +76,12 @@ def __init__(self): def prepare_set_group_4x(self, name: str, N, K) -> List: assert name in self.config_map, f"Operator {name} not found in config map" return [ - self.generate_op_unit( - self.generate_operator_config( - name, self.OPT_SHAPES, N, K)), + self.generate_op_unit(self.generate_operator_config(name, self.OPT_SHAPES, N, K)), ] def prepare_benchmark_sets(self): """Prepare benchmark sets.""" - + self.add_benchmark_set( "FP16xUINT4_GEMV_DECODING_NAIVE", [ @@ -144,7 +141,7 @@ def prepare_benchmark_sets(self): ), ], ) - + self.add_benchmark_set( "FP16xUINT2_GEMV_DECODING_NAIVE", [ @@ -204,7 +201,7 @@ def prepare_benchmark_sets(self): ), ], ) - + self.add_benchmark_set( "INT8xUINT2_GEMV_DECODING_NAIVE", [ @@ -338,7 +335,8 @@ def legalize_shape(M, N, K, dyn_prof_shape): latency_str = "N/A" if latency is None else f"{latency:.3f}" tmp_data.append([shape, latency_str]) else: - sub_results = results[i * len(self.OPT_SHAPES):(i + 1) * len(self.OPT_SHAPES)] + sub_results = results[i * len(self.OPT_SHAPES):(i + 1) * + len(self.OPT_SHAPES)] for i, result in enumerate(sub_results): latency = result[0] dyn_prof_shape = {"m": self.OPT_SHAPES[i]} @@ -383,14 +381,9 @@ def benchmark(self): """Run benchmarks on all benchmark 
sets.""" # Calculate the total number of benchmark runs for the progress bar total_runs = sum( - ( - len(benchmark_set) * - (len(self.OPT_SHAPES) - if isinstance(self.OPT_SHAPES, list) - else self.OPT_SHAPES) - ) - for benchmark_set in self.benchmark_sets.values() - ) + (len(benchmark_set) * + (len(self.OPT_SHAPES) if isinstance(self.OPT_SHAPES, list) else self.OPT_SHAPES)) + for benchmark_set in self.benchmark_sets.values()) with tqdm(total=total_runs, desc="Total Progress", unit="benchmark") as pbar: for name, benchmark_set in self.benchmark_sets.items(): From d30ec4fafb32aaaa31c6f138c078714cb0918848 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 29 Aug 2024 10:13:19 +0000 Subject: [PATCH 12/22] dispatch tensor core based on shapes --- bitblas/gpu/matmul_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py index 16f33664..4a0ef532 100644 --- a/bitblas/gpu/matmul_analysis.py +++ b/bitblas/gpu/matmul_analysis.py @@ -666,7 +666,7 @@ def check_last_trait(region: List[Range]): block_stmt = sch.get(main_block) - minimal_tensorize_threshold = 16 + minimal_tensorize_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32 # the batch dimension is not taken into consideration. extent = block_stmt.iter_vars[1].dom.extent if isinstance(extent, From fde4029a55dd278594f6e3035e567f5539188181 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 30 Aug 2024 03:24:55 +0000 Subject: [PATCH 13/22] update install commands --- README.md | 6 ++++++ docs/Installation.md | 8 ++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 315fa307..70a61357 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,12 @@ The easiest way to install BitBLAS is direcly from the PyPi using pip. 
To instal pip install bitblas ``` +Alternatively, to install the latest version of BitBLAS from the github repository, you can run the following command: + +```bash +pip install git+https://github.com/microsoft/BitBLAS.git +``` + After installing BitBLAS, you can verify the installation by running: ```bash diff --git a/docs/Installation.md b/docs/Installation.md index f30d2dfb..a50d478e 100644 --- a/docs/Installation.md +++ b/docs/Installation.md @@ -1,7 +1,5 @@ # Installation Guide - - ## Installing with pip **Prerequisites for installation via wheel or PyPI:** @@ -23,6 +21,12 @@ Alternatively, you may choose to install BitBLAS using prebuilt packages availab pip install bitblas-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl ``` +To install the latest version of BitBLAS from the github repository, you can run the following command: + +```bash +pip install git+https://github.com/microsoft/BitBLAS.git +``` + After installing BitBLAS, you can verify the installation by running: ```bash From 6a0474963e69d647ed80df96cd4630372ef88a28 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 31 Aug 2024 17:02:30 +0000 Subject: [PATCH 14/22] import scripts --- 3rdparty/tvm | 2 +- install.sh | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 3c6317a1..7edc860c 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 3c6317a1ea614b7277ffe0b4ede18b4652afad1c +Subproject commit 7edc860c2311789d2f149f16e1f0b5580b6d90d3 diff --git a/install.sh b/install.sh index db3b3682..c3bb0fe0 100755 --- a/install.sh +++ b/install.sh @@ -53,6 +53,9 @@ echo "LLVM config path: $LLVM_CONFIG_PATH" git submodule update --init --recursive cd 3rdparty/tvm +if [ -d build ]; then + rm -rf build +fi mkdir build cp cmake/config.cmake build cd build From 9ef14e9655a62ce0a303f8459b8259794c146b29 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 1 Sep 2024 08:51:26 +0000 Subject: [PATCH 15/22] remove shared mem hack --- bitblas/gpu/matmul_mma.py | 8 ++++---- bitblas/gpu/matmul_mma_dequantize.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/bitblas/gpu/matmul_mma.py b/bitblas/gpu/matmul_mma.py index 5d92f99b..3dafd395 100644 --- a/bitblas/gpu/matmul_mma.py +++ b/bitblas/gpu/matmul_mma.py @@ -571,9 +571,9 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, red # Apply Swizzling sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) # if not, apply padding to alleviate bank conflict - if not (can_swizzle or is_smooth): - pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + # if not (can_swizzle or is_smooth): + # pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + # sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) sch.annotate(f_2, "pragma_unroll_explicit", False) return block_read @@ -648,7 +648,7 @@ def inverse_permutation(i, j, ii, jj): auto_inline_consumer_chain(sch, accumulator_shared_to_global) sch.reverse_compute_at( accumulator_shared_to_global, - sch.get_loops(store)[-5], + sch.get_loops(store)[-6], preserve_unit_loops=True, ) vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index 6bc0e39b..f6f1e098 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -578,7 +578,7 @@ def get_idx(): auto_inline_consumer_chain(sch, 
accumulator_shared_to_global) sch.reverse_compute_at( accumulator_shared_to_global, - sch.get_loops(store)[-5], + sch.get_loops(store)[-6], preserve_unit_loops=True, ) vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) @@ -1075,7 +1075,7 @@ def get_idx(): auto_inline_consumer_chain(sch, accumulator_shared_to_global) sch.reverse_compute_at( accumulator_shared_to_global, - sch.get_loops(store)[-5], + sch.get_loops(store)[-6], preserve_unit_loops=True, ) vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) @@ -1675,7 +1675,7 @@ def get_idx(): auto_inline_consumer_chain(sch, accumulator_shared_to_global) sch.reverse_compute_at( accumulator_shared_to_global, - sch.get_loops(store)[-5], + sch.get_loops(store)[-6], preserve_unit_loops=True, ) vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) @@ -2194,7 +2194,7 @@ def get_idx(): auto_inline_consumer_chain(sch, accumulator_shared_to_global) sch.reverse_compute_at( accumulator_shared_to_global, - sch.get_loops(store)[-5], + sch.get_loops(store)[-6], preserve_unit_loops=True, ) vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) From 63f363e2a3506ffd7c8277a9a4e0170d8d1092c6 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 1 Sep 2024 08:55:14 +0000 Subject: [PATCH 16/22] revert change for swizzling --- bitblas/gpu/matmul_mma.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bitblas/gpu/matmul_mma.py b/bitblas/gpu/matmul_mma.py index 3dafd395..b8fa0b24 100644 --- a/bitblas/gpu/matmul_mma.py +++ b/bitblas/gpu/matmul_mma.py @@ -571,9 +571,9 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, red # Apply Swizzling sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) # if not, apply padding to alleviate bank conflict - # if not (can_swizzle or is_smooth): - # pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - # sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) sch.annotate(f_2, "pragma_unroll_explicit", False) return block_read From b29c66cf47d3d68a51a79fc0e7764eaf5dfb3f89 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 1 Sep 2024 12:01:10 +0000 Subject: [PATCH 17/22] bug fix --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 7edc860c..a29c8ad7 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 7edc860c2311789d2f149f16e1f0b5580b6d90d3 +Subproject commit a29c8ad7e78f61e0658946bd494f45cc9bebd36e From 28beb1317819de77479d9ddaa2fee8ec1e999d6f Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 2 Sep 2024 09:47:32 +0000 Subject: [PATCH 18/22] tl examples --- .../tilelang/test_tilelang_dequantize_gemm.py | 164 +++++++++++++++++ .../tilelang/test_tilelang_flash_atten.py | 173 ++++++++++++++++++ 2 files changed, 337 insertions(+) create mode 100644 testing/python/tilelang/test_tilelang_dequantize_gemm.py create mode 100644 testing/python/tilelang/test_tilelang_flash_atten.py diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py new file mode 100644 index 00000000..a0d2feae --- /dev/null +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -0,0 +1,164 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
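+
+# Test overview: the tile-language kernel below keeps the 4-bit weight matrix
+# packed two values per int8 byte in global and shared memory, unpacks the
+# nibbles into per-thread registers with _tir_packed_to_unsigned_convert,
+# stages the dequantized (block_N, block_K) tile back in shared memory, and
+# only then issues T.gemm against the activation tile. run_gemm checks the
+# result against an eager PyTorch dequantize-then-matmul reference.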
+import bitblas +from bitblas import tvm as tvm +from tvm import tl +from bitblas.quantization import _tir_packed_to_unsigned_convert + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + dtypeAB, + dtypeC, + accum_dtype, + num_stages, + threads, + num_bits=4, +): + num_elems_per_byte = 8 // num_bits + storage_dtype = "int8" + A_shape = (M, K) + B_shape = (N, K // num_elems_per_byte) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + + import tvm.tl.language as T + + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), dtypeC), + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads + ) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_fragment([8], storage_dtype, "local") + B_dequantize_local = T.alloc_fragment([16], dtypeAB, "local") + B_dequantize_shared = T.alloc_shared( + B_dequantize_shared_shape, dtypeAB + ) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + + for i in T.serial( + block_N * block_K // num_elems_per_byte // (threads * 16) + ): + for t in T.thread_binding(0, threads, thread="threadIdx.x"): + for v in T.vectorized(0, 16): + vi = (i * threads * 16 + t * 16 + v) // ( + block_K // num_elems_per_byte + ) + vj = (i * threads * 16 + t * 16 + v) % ( + block_K // num_elems_per_byte + ) + B_shared[vi, vj] = B[ + bx * block_N + vi, + k * block_K // num_elems_per_byte + vj, + ] + + for i in T.serial( + block_N * block_K // num_elems_per_byte // (threads * 4) + ): + for t in T.thread_binding(0, threads, thread="threadIdx.x"): + for v in T.vectorized(0, 4): + vi = (i * threads * 4 + t * 4 + v) // ( + block_K // num_elems_per_byte + ) + vj = (i * threads * 4 + t * 4 + v) % ( + block_K // num_elems_per_byte + ) + B_local[v] = B_shared[vi, vj] + for v in T.serial(0, 8): + B_dequantize_local[ + v + ] = _tir_packed_to_unsigned_convert("int", 8)( + num_bits, + B_local[v // 2], + v % 2, + dtype=dtypeAB, + ) + for v in T.vectorized(0, 8): + vi = (i * threads * 8 + t * 8 + v) // (block_K) + vj = (i * threads * 8 + t * 8 + v) % (block_K) + B_dequantize_shared[vi, vj] = B_dequantize_local[v] + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + dtypeAB, + dtypeC, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + dtypeAB, + dtypeC, + dtypeAccum, + num_stages, + num_threads, + ) + print(program) + + mod, params = tl.lower(program) + mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer) + + out = mod.run_once() + + print(f"output is {out}") + + with open("debug/kernel.cu", "w") as f: + f.write(mod.mod.imported_modules[0].get_source()) + + def ref_program(A, qB): + import torch + + B = ( + torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half) + .to(torch.half) + .to(A.device) + ) + for i in range(B.shape[0]): + for j in range(B.shape[1]): + B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to( + torch.half + ) + C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + 
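+    # Cross-check the compiled TL kernel against the eager PyTorch
+    # dequantize-then-matmul reference defined above.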
+ mod.assert_allclose(ref_program) + + +def test_run_dequantize_gemm(): + run_gemm(16, 16, 16, "int8", "int32", "int32", 16, 16, 16, num_threads=128) + + +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/tilelang/test_tilelang_flash_atten.py b/testing/python/tilelang/test_tilelang_flash_atten.py new file mode 100644 index 00000000..63a52bba --- /dev/null +++ b/testing/python/tilelang/test_tilelang_flash_atten.py @@ -0,0 +1,173 @@ +import argparse +import torch +from tvm import tl +import tvm.tl.language as T +from tvm.tl.autotuner import * +from functools import partial +import itertools + + +def get_configs(): + block_M = [32, 64, 128] + block_N = [32, 64, 128] + num_stages = [1, 2] + thread_num = [128, 256] + _configs = list(itertools.product(block_M, block_N, num_stages, thread_num)) + + configs = [ + { + "block_M": c[0], + "block_N": c[1], + "num_stages": c[2], + "thread_num": c[3], + } + for c in _configs + ] + return configs + + +def ref_program(Q, K, V, casual): + from flash_attn.flash_attn_interface import flash_attn_func + + return flash_attn_func(Q, K, V, causal=casual) + + +def flashattn(batch, heads, seq_len, dim, is_casual): + + @autotune( + configs=get_configs(), + keys=["block_M", "block_N", "num_stages", "thread_num"], + warmup=10, + rep=5, + ) + @jit( + out_idx=[3], + supply_type=tl.TensorSupplyType.Normal, + ref_prog=partial(ref_program, casual=is_casual), + rtol=0.01, + atol=0.01, + ) + def kernel(block_M=None, block_N=None, num_stages=None, thread_num=None): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def main( + Q: T.Buffer(shape, dtype), # type: ignore + K: T.Buffer(shape, dtype), # type: ignore + V: T.Buffer(shape, dtype), # type: ignore + Output: T.Buffer(shape, dtype), # type: ignore + ): + with T.Kernel( + T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num + ) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + Q_local = T.alloc_fragment([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.annotate_layout( + {Q_shared: tl.layout.make_swizzled_layout(Q_shared)} + ) + T.copy( + Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared + ) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.copy(Q_shared, Q_local) + for i, j in T.Parallel(block_M, dim): + Q_local[i, j] *= scale + loop_range = ( + T.ceildiv((bx + 1) * block_M, block_N) + if is_casual + else T.ceildiv(seq_len, block_N) + ) + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy( + K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared + ) + if is_casual: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + bx * block_M + i >= k * block_N + j, + 0, + -T.infinity(acc_s.dtype), + ) + else: + T.clear(acc_s) + T.gemm( + Q_local, + K_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy( + V[bz, k * block_N : 
(k + 1) * block_N, by, :], V_shared + ) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2( + scores_max_prev[i] - scores_max[i] + ) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i]) + T.copy(acc_s, acc_s_cast) + T.gemm( + acc_s_cast, + V_shared, + acc_o, + policy=T.GemmWarpPolicy.FullRow, + ) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy( + acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :] + ) + + return main + + return kernel() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=64, help="Batch size") + parser.add_argument("--h", type=int, default=12, help="Number of heads") + parser.add_argument("--n_ctx", type=int, default=2048, help="Context size") + parser.add_argument( + "--d_head", type=int, default=256, help="Head dimension" + ) + parser.add_argument("--casual", type=bool, default=True, help="Casual flag") + args = parser.parse_args() + BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head + casual = args.casual + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD + total_flops = 2 * flops_per_matmul + if casual: + total_flops *= 0.5 + + best_latency, best_config, ref_latency = flashattn( + BATCH, H, N_CTX, D_HEAD, casual + ) + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}") From c0b476f02efeaf7e6c94b74a8c7177ad6e494ef2 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 2 Sep 2024 17:09:15 +0000 Subject: [PATCH 19/22] Enhance Swizzle --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index a29c8ad7..c5902a0a 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit a29c8ad7e78f61e0658946bd494f45cc9bebd36e +Subproject commit c5902a0a2f6d9c21b56958d272781101b1165068 From 2bf14a83358680dc25a351ea78c2ada656377d64 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 2 Sep 2024 17:09:55 +0000 Subject: [PATCH 20/22] lint fix --- .../tilelang/test_tilelang_dequantize_gemm.py | 60 +++++------------ .../tilelang/test_tilelang_flash_atten.py | 67 ++++++------------- 2 files changed, 40 insertions(+), 87 deletions(-) diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index a0d2feae..9f915573 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -32,57 +32,37 @@ def matmul( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), - B: T.Buffer(B_shape, storage_dtype), - C: T.Buffer((M, N), dtypeC), + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), dtypeC), ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads - ) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, dtypeAB) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_local = T.alloc_fragment([8], 
storage_dtype, "local") B_dequantize_local = T.alloc_fragment([16], dtypeAB, "local") - B_dequantize_shared = T.alloc_shared( - B_dequantize_shared_shape, dtypeAB - ) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, dtypeAB) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): T.copy(A[by * block_M, k * block_K], A_shared) - for i in T.serial( - block_N * block_K // num_elems_per_byte // (threads * 16) - ): + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 16)): for t in T.thread_binding(0, threads, thread="threadIdx.x"): for v in T.vectorized(0, 16): - vi = (i * threads * 16 + t * 16 + v) // ( - block_K // num_elems_per_byte - ) - vj = (i * threads * 16 + t * 16 + v) % ( - block_K // num_elems_per_byte - ) - B_shared[vi, vj] = B[ - bx * block_N + vi, - k * block_K // num_elems_per_byte + vj, - ] - - for i in T.serial( - block_N * block_K // num_elems_per_byte // (threads * 4) - ): + vi = (i * threads * 16 + t * 16 + v) // (block_K // num_elems_per_byte) + vj = (i * threads * 16 + t * 16 + v) % (block_K // num_elems_per_byte) + B_shared[vi, vj] = B[bx * block_N + vi, + k * block_K // num_elems_per_byte + vj,] + + for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)): for t in T.thread_binding(0, threads, thread="threadIdx.x"): for v in T.vectorized(0, 4): - vi = (i * threads * 4 + t * 4 + v) // ( - block_K // num_elems_per_byte - ) - vj = (i * threads * 4 + t * 4 + v) % ( - block_K // num_elems_per_byte - ) + vi = (i * threads * 4 + t * 4 + v) // (block_K // num_elems_per_byte) + vj = (i * threads * 4 + t * 4 + v) % (block_K // num_elems_per_byte) B_local[v] = B_shared[vi, vj] for v in T.serial(0, 8): - B_dequantize_local[ - v - ] = _tir_packed_to_unsigned_convert("int", 8)( + B_dequantize_local[v] = _tir_packed_to_unsigned_convert("int", 8)( num_bits, B_local[v // 2], v % 2, @@ -140,15 +120,11 @@ def ref_program(A, qB): import torch B = ( - torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half) - .to(torch.half) - .to(A.device) - ) + torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, + dtype=torch.half).to(torch.half).to(A.device)) for i in range(B.shape[0]): for j in range(B.shape[1]): - B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to( - torch.half - ) + B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = C.to(torch.__getattribute__(dtypeC)) return C diff --git a/testing/python/tilelang/test_tilelang_flash_atten.py b/testing/python/tilelang/test_tilelang_flash_atten.py index 63a52bba..a8b8c498 100644 --- a/testing/python/tilelang/test_tilelang_flash_atten.py +++ b/testing/python/tilelang/test_tilelang_flash_atten.py @@ -1,5 +1,4 @@ import argparse -import torch from tvm import tl import tvm.tl.language as T from tvm.tl.autotuner import * @@ -14,15 +13,12 @@ def get_configs(): thread_num = [128, 256] _configs = list(itertools.product(block_M, block_N, num_stages, thread_num)) - configs = [ - { - "block_M": c[0], - "block_N": c[1], - "num_stages": c[2], - "thread_num": c[3], - } - for c in _configs - ] + configs = [{ + "block_M": c[0], + "block_N": c[1], + "num_stages": c[2], + "thread_num": c[3], + } for c in _configs] return configs @@ -48,21 +44,20 @@ def flashattn(batch, heads, seq_len, dim, is_casual): atol=0.01, ) def kernel(block_M=None, block_N=None, num_stages=None, thread_num=None): - scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + scale = (1.0 / 
dim)**0.5 * 1.44269504 # log2(e) shape = [batch, seq_len, heads, dim] dtype = "float16" accum_dtype = "float" @T.prim_func def main( - Q: T.Buffer(shape, dtype), # type: ignore - K: T.Buffer(shape, dtype), # type: ignore - V: T.Buffer(shape, dtype), # type: ignore - Output: T.Buffer(shape, dtype), # type: ignore + Q: T.Buffer(shape, dtype), # type: ignore + K: T.Buffer(shape, dtype), # type: ignore + V: T.Buffer(shape, dtype), # type: ignore + Output: T.Buffer(shape, dtype), # type: ignore ): with T.Kernel( - T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num - ) as (bx, by, bz): + T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) Q_local = T.alloc_fragment([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) @@ -76,12 +71,8 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout( - {Q_shared: tl.layout.make_swizzled_layout(Q_shared)} - ) - T.copy( - Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared - ) + T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) @@ -89,14 +80,10 @@ def main( for i, j in T.Parallel(block_M, dim): Q_local[i, j] *= scale loop_range = ( - T.ceildiv((bx + 1) * block_M, block_N) - if is_casual - else T.ceildiv(seq_len, block_N) - ) + T.ceildiv( + (bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N)) for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy( - K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared - ) + T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) if is_casual: for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else( @@ -113,15 +100,11 @@ def main( transpose_B=True, policy=T.GemmWarpPolicy.FullRow, ) - T.copy( - V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared - ) + T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) T.copy(scores_max, scores_max_prev) T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_M): - scores_scale[i] = T.exp2( - scores_max_prev[i] - scores_max[i] - ) + scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i]) for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scores_scale[i] for i, j in T.Parallel(block_M, block_N): @@ -138,9 +121,7 @@ def main( logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] for i, j in T.Parallel(block_M, dim): acc_o[i, j] /= logsum[i] - T.copy( - acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :] - ) + T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) return main @@ -152,9 +133,7 @@ def main( parser.add_argument("--batch", type=int, default=64, help="Batch size") parser.add_argument("--h", type=int, default=12, help="Number of heads") parser.add_argument("--n_ctx", type=int, default=2048, help="Context size") - parser.add_argument( - "--d_head", type=int, default=256, help="Head dimension" - ) + parser.add_argument("--d_head", type=int, default=256, help="Head dimension") parser.add_argument("--casual", type=bool, default=True, help="Casual flag") args = parser.parse_args() BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head @@ -164,9 +143,7 @@ def main( if casual: total_flops *= 0.5 - best_latency, best_config, ref_latency = flashattn( - BATCH, H, N_CTX, D_HEAD, casual - ) + best_latency, 
best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual) print(f"Best latency: {best_latency}") print(f"Best TFlops: {total_flops / best_latency * 1e-9}") print(f"Best config: {best_config}") From 19aa9850252bc71d72043dd766bb704f7871e3c9 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 3 Sep 2024 03:24:21 +0000 Subject: [PATCH 21/22] test fix --- 3rdparty/tvm | 2 +- testing/python/tilelang/test_tilelang_dequantize_gemm.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index c5902a0a..a1d78ebc 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit c5902a0a2f6d9c21b56958d272781101b1165068 +Subproject commit a1d78ebc682dbaec70e792470c6842b9ec3342c6 diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index 9f915573..5d32d40b 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -133,7 +133,9 @@ def ref_program(A, qB): def test_run_dequantize_gemm(): - run_gemm(16, 16, 16, "int8", "int32", "int32", 16, 16, 16, num_threads=128) + run_gemm( + 256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128 + ) if __name__ == "__main__": From ef8f93c8bed95f0b0de375720fbc5e5b8a220154 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 3 Sep 2024 03:32:13 +0000 Subject: [PATCH 22/22] lint fix --- testing/python/tilelang/test_tilelang_dequantize_gemm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index 5d32d40b..9db978cd 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -133,9 +133,7 @@ def ref_program(A, qB): def test_run_dequantize_gemm(): - run_gemm( - 256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128 - ) + run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128) if __name__ == "__main__":
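
A note on the packed-weight layout exercised by the dequantize test above: `ref_program` in `test_tilelang_dequantize_gemm.py` recovers unsigned 4-bit values with `(qB[i][j // 2] >> (4 * (j % 2))) & 0xF`, so element `j` of a row sits in the low nibble of byte `j // 2` when `j` is even and in the high nibble when `j` is odd. The sketch below reproduces that convention in plain PyTorch as a minimal, standalone illustration; `pack_uint4` and `unpack_uint4` are names invented for this example and are not BitBLAS or TVM APIs.

```python
import torch


def pack_uint4(w: torch.Tensor) -> torch.Tensor:
    """Pack an (N, K) tensor of values in [0, 15] two per int8 byte.

    Element j of each row goes to the low nibble of byte j // 2 when j is
    even and to the high nibble when j is odd, matching ref_program above.
    """
    assert w.shape[-1] % 2 == 0
    lo, hi = w[..., 0::2], w[..., 1::2]
    return (lo | (hi << 4)).to(torch.int8)


def unpack_uint4(qw: torch.Tensor, k: int) -> torch.Tensor:
    """Invert pack_uint4 using the same nibble convention as ref_program."""
    j = torch.arange(k)
    # Mask to an unsigned byte first so the int8 sign bit cannot leak in.
    packed = qw[..., j // 2].to(torch.int16) & 0xFF
    return torch.where(j % 2 == 0, packed & 0xF, (packed >> 4) & 0xF)


if __name__ == "__main__":
    w = torch.randint(0, 16, (256, 256))
    restored = unpack_uint4(pack_uint4(w), 256)
    assert bool((restored == w).all()), "round-trip mismatch"
    print("uint4 pack/unpack round-trip OK")
```

On the device side, the kernel performs the matching extraction through `_tir_packed_to_unsigned_convert("int", 8)`, which is why this host-side reference and the TL kernel agree on the same byte layout.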