From d8884e6f6a294fc8f1a325665d86a07603d43864 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 08:54:26 +0000 Subject: [PATCH 01/88] Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability --- bitblas/ops/impl/base.py | 16 +++ bitblas/ops/impl/batch_matmul_impl.py | 166 ++++++++++++++++---------- 2 files changed, 119 insertions(+), 63 deletions(-) create mode 100644 bitblas/ops/impl/base.py diff --git a/bitblas/ops/impl/base.py b/bitblas/ops/impl/base.py new file mode 100644 index 000000000..6d510f7da --- /dev/null +++ b/bitblas/ops/impl/base.py @@ -0,0 +1,16 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from abc import ABC, abstractmethod + +# TODO: Refactor all the tir script implementations to use this base class +# Abstract base class for TIR script emitters +class TIRScriptEmitter(ABC): + @abstractmethod + def emit(self): + raise NotImplementedError + +# Abstract base class for TIR script selectors +class TIRScriptSelector(ABC): + @abstractmethod + def select(self): + raise NotImplementedError diff --git a/bitblas/ops/impl/batch_matmul_impl.py b/bitblas/ops/impl/batch_matmul_impl.py index 09b536afa..75449ea4b 100644 --- a/bitblas/ops/impl/batch_matmul_impl.py +++ b/bitblas/ops/impl/batch_matmul_impl.py @@ -4,63 +4,117 @@ from bitblas import tvm from tvm import te from bitblas.ops.operator import TransformKind +from .base import TIRScriptEmitter, TIRScriptSelector +from bitblas import tvm +from tvm import te +from bitblas.ops.operator import TransformKind +class BatchMatMulEmitter(TIRScriptEmitter): + def __init__( + self, + batch, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + ): + self.batch = batch + self.M = self._validate_dimension(M, "M") + self.N = self._validate_dimension(N, "N") + self.K = self._validate_dimension(K, "K") + self.in_dtype = in_dtype + self.out_dtype = out_dtype + self.accum_dtype = accum_dtype + self.with_bias = with_bias + self.layout = layout + self._validate_layout() + + @staticmethod + def _validate_dimension(dim, name): + if not isinstance(dim, int): + return tvm.te.var(name.lower()) + return dim -def matmul_nt( - Batch, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, -): - if not isinstance(M, int): - M = tvm.te.var("m") - A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) - B = te.placeholder((Batch, N, K), name="B", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + def _validate_layout(self): + if self.layout not in ["nn", "nt"]: + raise ValueError(f"Unsupported layout: {self.layout}") + if self.layout == "nn": + raise ValueError("Currently only support layout=nt") - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (Batch, M, N), - lambda b, i, j: te.sum( - A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k), - name="C", - ) - last_output = C - if accum_dtype != out_dtype: - D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") - last_output = D + def _create_placeholders(self): + A = te.placeholder((self.batch, self.M, self.K), name="A", dtype=self.in_dtype) + B = te.placeholder((self.batch, self.N, self.K), name="B", dtype=self.in_dtype) + Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None + return A, B, Bias - if with_bias: - E = te.compute((Batch, M, N), lambda b, i, j: last_output[b, i, j] + Bias[j], name="E") - last_output = E + def _compute_matmul(self, A, B): + k = te.reduce_axis((0, self.K), name="k") + C = te.compute( + (self.batch, self.M, self.N), + lambda b, i, j: te.sum( + A[b, i, k].astype(self.accum_dtype) * B[b, j, k].astype(self.accum_dtype), axis=k), + name="C", + ) + return C - args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + def _apply_bias(self, C, Bias): + if self.with_bias: + return te.compute((self.batch, self.M, self.N), lambda b, i, j: C[b, i, j] + Bias[j], name="E") + return C - func = te.create_prim_func(args) + def _convert_dtype(self, tensor): + if self.accum_dtype != self.out_dtype: + return te.compute((self.batch, self.M, self.N), lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), name="D") + return tensor - return tvm.IRModule.from_expr(func) + def emit(self): + A, B, Bias = self._create_placeholders() + C = self._compute_matmul(A, B) + last_output = self._convert_dtype(C) + if self.with_bias: + last_output = self._apply_bias(last_output, Bias) + args = [A, B, Bias, last_output] if self.with_bias else [A, B, last_output] + func = te.create_prim_func(args) + return tvm.IRModule.from_expr(func) -def matmul( - Batch, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, - layout="nt", -): - if layout == "nn": - raise ValueError("Currently only support layout=nt") - return matmul_nt(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) +class BatchMatMulSelector(TIRScriptSelector): + def __init__(self, propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform): + self.propagate_a = propagate_a + self.propagate_b = propagate_b + + def select( + self, + batch=1, + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + ): + if layout == "nn": + if self.propagate_a or self.propagate_b: + raise ValueError("Currently only support propagate_a=False and propagate_b=False for layout=nn") + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + elif layout == "nt": + if self.propagate_a and self.propagate_b: + raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") + elif self.propagate_a: + raise ValueError("Currently only support propagate_a=False for layout=nt") + elif self.propagate_b: + raise ValueError("Currently only support propagate_b=False for layout=nt") + else: + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + else: + raise ValueError(f"Unsupported layout: {layout}") def select_implementation( Batch=1, @@ -75,19 +129,5 @@ def select_implementation( propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform, ): - if layout == "nn": - if propagate_a or propagate_b: - raise ValueError( - "Currently only support propagate_a=False and propagate_b=False for layout=nn") - return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) - elif layout == "nt": - if propagate_a and propagate_b: - raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") - elif propagate_a: - raise ValueError("Currently only support propagate_a=False for layout=nt") - elif propagate_b: - raise ValueError("Currently only support propagate_b=False for layout=nt") - else: - return matmul(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) - else: - raise ValueError(f"Unsupported layout: {layout}") + selector = BatchMatMulSelector(propagate_a, propagate_b) + return selector.select(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) From fc84173f22d2f4867a8e6413117b5cd8e830ab27 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 08:57:43 +0000 Subject: [PATCH 02/88] Refactor import statements for improved readability and maintainability --- bitblas/ops/impl/__init__.py | 2 +- bitblas/ops/impl/base.py | 4 ++++ bitblas/ops/impl/batch_matmul_impl.py | 33 ++++++++++++++++++--------- 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/bitblas/ops/impl/__init__.py b/bitblas/ops/impl/__init__.py index a254dc7fb..8a9bbd2a5 100644 --- a/bitblas/ops/impl/__init__.py +++ b/bitblas/ops/impl/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .lop3_permutate_impl import tir_interleave_weight +from .lop3_permutate_impl import tir_interleave_weight # noqa: F401 diff --git a/bitblas/ops/impl/base.py b/bitblas/ops/impl/base.py index 6d510f7da..4a67987be 100644 --- a/bitblas/ops/impl/base.py +++ b/bitblas/ops/impl/base.py @@ -2,15 +2,19 @@ # Licensed under the MIT License. from abc import ABC, abstractmethod + # TODO: Refactor all the tir script implementations to use this base class # Abstract base class for TIR script emitters class TIRScriptEmitter(ABC): + @abstractmethod def emit(self): raise NotImplementedError + # Abstract base class for TIR script selectors class TIRScriptSelector(ABC): + @abstractmethod def select(self): raise NotImplementedError diff --git a/bitblas/ops/impl/batch_matmul_impl.py b/bitblas/ops/impl/batch_matmul_impl.py index 75449ea4b..3904f36e6 100644 --- a/bitblas/ops/impl/batch_matmul_impl.py +++ b/bitblas/ops/impl/batch_matmul_impl.py @@ -5,11 +5,10 @@ from tvm import te from bitblas.ops.operator import TransformKind from .base import TIRScriptEmitter, TIRScriptSelector -from bitblas import tvm -from tvm import te -from bitblas.ops.operator import TransformKind + class BatchMatMulEmitter(TIRScriptEmitter): + def __init__( self, batch, @@ -32,7 +31,7 @@ def __init__( self.with_bias = with_bias self.layout = layout self._validate_layout() - + @staticmethod def _validate_dimension(dim, name): if not isinstance(dim, int): @@ -48,7 +47,8 @@ def _validate_layout(self): def _create_placeholders(self): A = te.placeholder((self.batch, self.M, self.K), name="A", dtype=self.in_dtype) B = te.placeholder((self.batch, self.N, self.K), name="B", dtype=self.in_dtype) - Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None + Bias = te.placeholder( + (self.N,), name="Bias", dtype=self.in_dtype) if self.with_bias else None return A, B, Bias def _compute_matmul(self, A, B): @@ -63,12 +63,16 @@ def _compute_matmul(self, A, B): def _apply_bias(self, C, Bias): if self.with_bias: - return te.compute((self.batch, self.M, self.N), lambda b, i, j: C[b, i, j] + Bias[j], name="E") + return te.compute((self.batch, self.M, self.N), + lambda b, i, j: C[b, i, j] + Bias[j], + name="E") return C def _convert_dtype(self, tensor): if self.accum_dtype != self.out_dtype: - return te.compute((self.batch, self.M, self.N), lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), name="D") + return te.compute((self.batch, self.M, self.N), + lambda b, i, j: tensor[b, i, j].astype(self.out_dtype), + name="D") return tensor def emit(self): @@ -84,7 +88,10 @@ def emit(self): class BatchMatMulSelector(TIRScriptSelector): - def __init__(self, propagate_a: TransformKind = TransformKind.NonTransform, propagate_b: TransformKind = TransformKind.NonTransform): + + def __init__(self, + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform): self.propagate_a = propagate_a self.propagate_b = propagate_b @@ -102,8 +109,10 @@ def select( ): if layout == "nn": if self.propagate_a or self.propagate_b: - raise ValueError("Currently only support propagate_a=False and propagate_b=False for layout=nn") - return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn") + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, + layout).emit() elif layout == "nt": if self.propagate_a and self.propagate_b: raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") @@ -112,10 +121,12 @@ def select( elif self.propagate_b: raise ValueError("Currently only support propagate_b=False for layout=nt") else: - return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout).emit() + return BatchMatMulEmitter(batch, M, N, K, in_dtype, out_dtype, accum_dtype, + with_bias, layout).emit() else: raise ValueError(f"Unsupported layout: {layout}") + def select_implementation( Batch=1, M=None, From 02f64de6cf2d338c092dcf29ec55b69804fda892 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 08:58:06 +0000 Subject: [PATCH 03/88] Refactor import statements for improved readability and maintainability --- bitblas/ops/impl/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitblas/ops/impl/__init__.py b/bitblas/ops/impl/__init__.py index 8a9bbd2a5..67e49b2ae 100644 --- a/bitblas/ops/impl/__init__.py +++ b/bitblas/ops/impl/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .lop3_permutate_impl import tir_interleave_weight # noqa: F401 +from .lop3_permutate_impl import tir_interleave_weight # noqa: F401 From 397eee6141599e84b509594bb99a0531e409c266 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 5 Jul 2024 16:25:47 +0000 Subject: [PATCH 04/88] disable failure email for ci --- .github/workflows/ci.yml | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ceb69fcc7..1fbdf19dd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -64,4 +64,13 @@ jobs: run: | source bitblas_ci/bin/activate cd testing/python - python -m pytest \ No newline at end of file + python -m pytest + + # Control notifications + notify: + runs-on: self-hosted + needs: [format-check, build-test] + if: failure() + steps: + - name: Notification + run: echo "Jobs failed, but no email will be sent." From 20f6ad1e7ca4e6e1ca9e13ad7c1bbc8c430a8e51 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 03:23:50 +0000 Subject: [PATCH 05/88] remove email notifications. --- .github/workflows/ci.yml | 9 --------- 1 file changed, 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1fbdf19dd..511b95833 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -65,12 +65,3 @@ jobs: source bitblas_ci/bin/activate cd testing/python python -m pytest - - # Control notifications - notify: - runs-on: self-hosted - needs: [format-check, build-test] - if: failure() - steps: - - name: Notification - run: echo "Jobs failed, but no email will be sent." From b93c39431c803e22b12f71b555939785da36b96a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 03:25:05 +0000 Subject: [PATCH 06/88] move relax pass from testing to mlc_llm --- .../mlc_llm}/test_weight_only_transform.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {testing/python/transform => integration/mlc_llm}/test_weight_only_transform.py (100%) diff --git a/testing/python/transform/test_weight_only_transform.py b/integration/mlc_llm/test_weight_only_transform.py similarity index 100% rename from testing/python/transform/test_weight_only_transform.py rename to integration/mlc_llm/test_weight_only_transform.py From 257693a7c3cb3083aac144182f58d38bfe3bcfdd Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 05:51:01 +0000 Subject: [PATCH 07/88] Refactor scripts with se check_eual_ref_scripts_with_emitter function --- bitblas/ops/impl/matmul_dequantize_impl.py | 224 ++++++++++++++---- .../operators/test_tir_script_emitter.py | 52 +++- 2 files changed, 216 insertions(+), 60 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index 1ed6b3404..e69e8fcfb 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -15,8 +15,10 @@ _tir_packed_to_unsigned_convert_with_zeros, ) + # TODO: The following code should be refactored. class MatMulNTDequantizeEmitter: + def __init__( self, M, @@ -52,8 +54,8 @@ def __init__( self.fast_decoding = fast_decoding self.with_bias = with_bias self.zeros_mode = zeros_mode - self.propagate_a = propagate_a - self.propagate_b = propagate_b + self.propagate_a = self._legalize_transform_kind(propagate_a) + self.propagate_b = self._legalize_transform_kind(propagate_b) self._validate_bit() self._validate_layout() @@ -69,62 +71,169 @@ def _validate_bit(self): raise ValueError(f"Unsupported bit: {self.bit}") def _validate_layout(self): - if self.layout not in ["nt"]: - raise ValueError(f"Unsupported layout: {self.layout}") + # TODO: extend the dequantize operators into General Layout + pass + + def _legalize_group_size(self): + if self.group_size == -1: + self.group_size = self.K + + def _legalize_transform_kind(self, propagate): + if propagate is None: + return TransformKind.NonTransform + if isinstance(propagate, bool): + return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform) + elif isinstance(propagate, int): + return TransformKind(propagate) def _create_placeholders(self): - storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) - n_float_per_elem = storage_nbit // self.bit - - A = te.placeholder((self.M, self.K), name="A", dtype=self.in_dtype) - B = te.placeholder((self.N, self.K // storage_nbit * self.bit), name="B", dtype=self.storage_dtype) - LUT = te.placeholder((1 << self.bit,), name="LUT", dtype=self.in_dtype) - Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=self.in_dtype) - Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=self.in_dtype) - QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * self.bit), + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + in_dtype = self.in_dtype + bit = self.bit + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype) + B = te.placeholder((self.N, self.K // storage_nbit * bit), + name="B", + dtype=storage_dtype) + if self.propagate_a: + A = te.placeholder((self.M // l, self.K // r, l, r), name="A", dtype=in_dtype) + if self.propagate_b: + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // target_dtype.bits) + qr = r * bit // storage_nbit + B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=in_dtype) + QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * bit), name="QZeros", dtype=self.storage_dtype) - Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) - return A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem + Bias = te.placeholder((self.N,), name="Bias", dtype=in_dtype) + return A, B, LUT, Scale, Zeros, QZeros, Bias + + def _propagate_input(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="A"): + if transform_kind == TransformKind.NonTransform: + return tensor + in_dtype = self.in_dtype + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map( + trans=False, dtype=in_dtype, matrix_name=matrix_name) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return tensor[new_index] + + return te.compute( + (self.M, self.K), + fcompute, + name=f"{matrix_name}_reindex", + ) + + def _propagage_weight(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="B"): + if transform_kind == TransformKind.NonTransform: + return tensor + in_dtype = self.in_dtype + bit = self.bit + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map( + trans=True, dtype=in_dtype, matrix_name=matrix_name) + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inversed_index_map.initial_indices + scaling_final_indices = inversed_index_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inversed_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + qr = r * bit // storage_nbit - def _decode_func(self, B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem): - w = None + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return tensor[new_index] + + return te.compute( + (self.N, self.K // storage_nbit * bit), + fcompute, + name=f"{matrix_name}_reindex", + ) + + def _decode_func(self, B, LUT, Scale, Zeros, QZeros): + bit = self.bit + in_dtype = self.in_dtype + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + + # TODO: Move the decode function into a more general place def decode(n, k): + w = None if self.with_zeros and self.zeros_mode == "quantized": - qzeros_dequantize = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, + qzeros_dequantize = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, QZeros[k, n // n_float_per_elem], n % n_float_per_elem, dtype=self.storage_dtype, ) - w = _tir_packed_to_unsigned_convert_with_zeros(self.storage_dtype, storage_nbit)( - self.bit, + w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, qzeros_dequantize, - dtype=self.in_dtype, + dtype=in_dtype, ) elif self.source_format == "uint": - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) + if bit == 8: + w = B[n, k].astype(in_dtype) + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "int": - if self.bit == 1: - w = _tir_packed_int_to_int_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_signed_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) + if bit == 1: + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + if bit == 8: + w = B[n, k].astype(in_dtype) + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "fp": w = _tir_u32_to_f4_to_f16( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif self.source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(self.bit, B[n, k], dtype=self.in_dtype) + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) elif self.source_format == "nf": - index = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, + index = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype="int32", @@ -132,7 +241,9 @@ def decode(n, k): w = LUT[index] else: raise ValueError(f"Unsupported source_format: {self.source_format}") - + + assert w is not None, "w is None" + group_size = self.group_size zeros_mode = self.zeros_mode @@ -167,7 +278,9 @@ def _compute_matmul(self, A, B_decode): def _convert_dtype(self, tensor): if self.accum_dtype != self.out_dtype: - return te.compute((self.M, self.N), lambda i, j: tensor[i, j].astype(self.out_dtype), name="D") + return te.compute((self.M, self.N), + lambda i, j: tensor[i, j].astype(self.out_dtype), + name="D") return tensor def _apply_bias(self, tensor, Bias): @@ -176,9 +289,12 @@ def _apply_bias(self, tensor, Bias): return tensor def emit(self): - A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem = self._create_placeholders() - B_decode = self._decode_func(B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem) - C = self._compute_matmul(A, B_decode) + A, B, LUT, Scale, Zeros, QZeros, Bias = self._create_placeholders() + A_reindex = self._propagate_input(A, self.propagate_a, "A") + B_reindex = self._propagage_weight(B, self.propagate_b, "B") + + B_decode = self._decode_func(B_reindex, LUT, Scale, Zeros, QZeros) + C = self._compute_matmul(A_reindex, B_decode) D = self._convert_dtype(C) last_output = self._apply_bias(D, Bias) @@ -212,8 +328,13 @@ def emit(self): } }, ) + if self.propagate_a: + func = func.with_attr("input_transform_kind", self.propagate_a.value) + if self.propagate_b: + func = func.with_attr("weight_transform_kind", self.propagate_b.value) return tvm.IRModule.from_expr(func) + def matmul_nt_dequantize_b( M, N, @@ -335,9 +456,12 @@ def decode_func(n, k): A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), name="C", ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D args = [A, B] - last_output = D if source_format == "nf": args.append(LUT) if with_scaling: @@ -517,9 +641,11 @@ def decode_func(n, k): A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), name="C", ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D args = [A, B] - last_output = D if source_format == "nf": args.append(LUT) if with_scaling: @@ -715,9 +841,11 @@ def decode_func(n, k): ), name="C", ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D args = [A, B] - last_output = D if source_format == "nf": args.append(LUT) if with_scaling: diff --git a/testing/python/operators/test_tir_script_emitter.py b/testing/python/operators/test_tir_script_emitter.py index cec56b473..fcfa7d9af 100644 --- a/testing/python/operators/test_tir_script_emitter.py +++ b/testing/python/operators/test_tir_script_emitter.py @@ -1,18 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from bitblas.ops.impl.matmul_dequantize_impl import ( - MatMulNTDequantizeEmitter, - matmul_nt_dequantize_b, - matmul_nt_dequantize_b_propagate_b, - matmul_nt_dequantize_b_propagate_a_propagate_b, -) from bitblas import tvm import logging from bitblas import set_log_level set_log_level(logging.DEBUG) -def compare_tir_scripts_and_emitter( + +def check_eual_ref_scripts_with_emitter( M, N, K, @@ -28,8 +23,26 @@ def compare_tir_scripts_and_emitter( fast_decoding, with_bias, zeros_mode, + propagate_a, + propagate_b, ): - tir_script_func = matmul_nt_dequantize_b( + from bitblas.ops.impl.matmul_dequantize_impl import ( + MatMulNTDequantizeEmitter, + matmul_nt_dequantize_b, + matmul_nt_dequantize_b_propagate_b, + matmul_nt_dequantize_b_propagate_a_propagate_b, + ) + func = None + if propagate_a and propagate_b: + func = matmul_nt_dequantize_b_propagate_a_propagate_b + elif propagate_b: + func = matmul_nt_dequantize_b_propagate_b + else: + func = matmul_nt_dequantize_b + + assert func is not None, "No function found for the given configuration" + + ref_func = func( M, N, K, @@ -46,8 +59,8 @@ def compare_tir_scripts_and_emitter( with_bias, zeros_mode, ) - - emitter_func = MatMulNTDequantizeEmitter( + + emit_func = MatMulNTDequantizeEmitter( M, N, K, @@ -63,6 +76,21 @@ def compare_tir_scripts_and_emitter( fast_decoding, with_bias, zeros_mode, + propagate_a=propagate_a, + propagate_b=propagate_b, ).emit() - - tvm.ir.assert_structural_equal(tir_script_func, emitter_func) + + tvm.ir.assert_structural_equal(ref_func, emit_func) + + +def test_check_eual_ref_scripts_with_emitter(): + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) + check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", True, True) + +if __name__ == "__main__": + test_check_eual_ref_scripts_with_emitter() From 9bb7f49a968d4c71dbbc12121b4b7cb8258b2136 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 05:51:15 +0000 Subject: [PATCH 08/88] Lint Fix --- bitblas/ops/impl/matmul_dequantize_impl.py | 13 +++++---- .../operators/test_tir_script_emitter.py | 29 ++++++++++++++----- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index e69e8fcfb..7b91764ca 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -73,7 +73,7 @@ def _validate_bit(self): def _validate_layout(self): # TODO: extend the dequantize operators into General Layout pass - + def _legalize_group_size(self): if self.group_size == -1: self.group_size = self.K @@ -96,18 +96,19 @@ def _create_placeholders(self): l, r = 16, 32 # noqa: E741 A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype) - B = te.placeholder((self.N, self.K // storage_nbit * bit), - name="B", - dtype=storage_dtype) + B = te.placeholder((self.N, self.K // storage_nbit * bit), name="B", dtype=storage_dtype) if self.propagate_a: A = te.placeholder((self.M // l, self.K // r, l, r), name="A", dtype=in_dtype) if self.propagate_b: target_dtype = DataType(in_dtype) scaling_factor = 1 if bit > 0 and bit < target_dtype.bits: - scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // target_dtype.bits) + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) qr = r * bit // storage_nbit - B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), + name="B", + dtype=storage_dtype) LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype) diff --git a/testing/python/operators/test_tir_script_emitter.py b/testing/python/operators/test_tir_script_emitter.py index fcfa7d9af..b2c7a8d4f 100644 --- a/testing/python/operators/test_tir_script_emitter.py +++ b/testing/python/operators/test_tir_script_emitter.py @@ -84,13 +84,28 @@ def check_eual_ref_scripts_with_emitter( def test_check_eual_ref_scripts_with_emitter(): - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, False) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) - check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", False, True) - check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1, False, False, "original", True, True) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "nf", True, False, -1, False, False, "original", False, + False) + check_eual_ref_scripts_with_emitter(16384, 16384, 16384, "float16", "float16", "float16", 4, + "int8", "nf", True, False, -1, False, False, "original", + False, False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + False) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + True) + check_eual_ref_scripts_with_emitter(1, 16384, 16384, "float16", "float16", "float16", 4, "int8", + "uint", True, False, -1, False, False, "original", False, + True) + check_eual_ref_scripts_with_emitter(1024, 1024, 1024, "float16", "float16", "float16", 4, + "int8", "uint", True, False, -1, False, False, "original", + True, True) + if __name__ == "__main__": test_check_eual_ref_scripts_with_emitter() From 93eb5a5fe4e3eb6242675dd5706358c4121f1672 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 05:53:50 +0000 Subject: [PATCH 09/88] Refactor scripts with se check_eual_ref_scripts_with_emitter function --- bitblas/ops/impl/matmul_dequantize_impl.py | 198 --------------------- 1 file changed, 198 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index 1ef14100d..7b91764ca 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -15,204 +15,6 @@ _tir_packed_to_unsigned_convert_with_zeros, ) -# TODO: The following code should be refactored. -class MatMulNTDequantizeEmitter: - def __init__( - self, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - zeros_mode="original", - propagate_a: TransformKind = TransformKind.NonTransform, - propagate_b: TransformKind = TransformKind.NonTransform, - ): - self.M = self._validate_dimension(M, "M") - self.N = N - self.K = K - self.in_dtype = in_dtype - self.out_dtype = out_dtype - self.accum_dtype = accum_dtype - self.bit = bit - self.storage_dtype = storage_dtype - self.source_format = source_format - self.with_scaling = with_scaling - self.with_zeros = with_zeros - self.group_size = group_size if group_size != -1 else K - self.fast_decoding = fast_decoding - self.with_bias = with_bias - self.zeros_mode = zeros_mode - self.propagate_a = propagate_a - self.propagate_b = propagate_b - - self._validate_bit() - self._validate_layout() - - @staticmethod - def _validate_dimension(dim, name): - if not isinstance(dim, int): - return tvm.te.var(name.lower()) - return dim - - def _validate_bit(self): - if self.bit not in [1, 2, 4, 8]: - raise ValueError(f"Unsupported bit: {self.bit}") - - def _validate_layout(self): - if self.layout not in ["nt"]: - raise ValueError(f"Unsupported layout: {self.layout}") - - def _create_placeholders(self): - storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) - n_float_per_elem = storage_nbit // self.bit - - A = te.placeholder((self.M, self.K), name="A", dtype=self.in_dtype) - B = te.placeholder((self.N, self.K // storage_nbit * self.bit), name="B", dtype=self.storage_dtype) - LUT = te.placeholder((1 << self.bit,), name="LUT", dtype=self.in_dtype) - Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=self.in_dtype) - Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=self.in_dtype) - QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * self.bit), - name="QZeros", - dtype=self.storage_dtype) - Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype) - return A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem - - def _decode_func(self, B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem): - w = None - def decode(n, k): - if self.with_zeros and self.zeros_mode == "quantized": - qzeros_dequantize = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, - QZeros[k, n // n_float_per_elem], - n % n_float_per_elem, - dtype=self.storage_dtype, - ) - w = _tir_packed_to_unsigned_convert_with_zeros(self.storage_dtype, storage_nbit)( - self.bit, - B[n, k // n_float_per_elem], - k % n_float_per_elem, - qzeros_dequantize, - dtype=self.in_dtype, - ) - elif self.source_format == "uint": - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - elif self.source_format == "int": - if self.bit == 1: - w = _tir_packed_int_to_int_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - if self.bit == 8: - w = B[n, k].astype(self.in_dtype) - w = _tir_packed_to_signed_convert(self.storage_dtype, storage_nbit)( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - elif self.source_format == "fp": - w = _tir_u32_to_f4_to_f16( - self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype) - elif self.source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(self.bit, B[n, k], dtype=self.in_dtype) - elif self.source_format == "nf": - index = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)( - self.bit, - B[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype="int32", - ) - w = LUT[index] - else: - raise ValueError(f"Unsupported source_format: {self.source_format}") - - group_size = self.group_size - zeros_mode = self.zeros_mode - - if not self.with_scaling: - return w - - if not self.with_zeros: - return w * Scale[n, k // group_size] - - if zeros_mode == "original": - w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] - elif zeros_mode == "rescale": - w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] - elif zeros_mode == "quantized": - w = w * Scale[n, k // group_size] - else: - raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) - - return w - - return te.compute((self.N, self.K), decode, name="B_decode") - - def _compute_matmul(self, A, B_decode): - k = te.reduce_axis((0, self.K), name="k") - C = te.compute( - (self.M, self.N), - lambda i, j: te.sum( - A[i, k].astype(self.accum_dtype) * B_decode[j, k].astype(self.accum_dtype), axis=k), - name="C", - ) - return C - - def _convert_dtype(self, tensor): - if self.accum_dtype != self.out_dtype: - return te.compute((self.M, self.N), lambda i, j: tensor[i, j].astype(self.out_dtype), name="D") - return tensor - - def _apply_bias(self, tensor, Bias): - if self.with_bias: - return te.compute((self.M, self.N), lambda i, j: tensor[i, j] + Bias[j], name="E") - return tensor - - def emit(self): - A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem = self._create_placeholders() - B_decode = self._decode_func(B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem) - C = self._compute_matmul(A, B_decode) - D = self._convert_dtype(C) - last_output = self._apply_bias(D, Bias) - - args = [A, B] - if self.source_format == "nf": - args.append(LUT) - if self.with_scaling: - args.append(Scale) - if self.with_zeros: - args.append(QZeros if self.zeros_mode == "quantized" else Zeros) - if self.with_bias: - args.append(Bias) - args.append(last_output) - - func = te.create_prim_func(args).with_attr( - "dequantize_info", - { - "B_decode": { - "decode_block": "B_decode", - "fast_decoding": self.fast_decoding, - "source_format": { - "bits": self.bit, - "format": self.source_format, - }, - "storage_dtype": self.storage_dtype, - "target_format": self.in_dtype, - "with_zeros": self.with_zeros, - "zeros_mode": self.zeros_mode, - "with_scaling": self.with_scaling, - "group_size": self.group_size, - } - }, - ) - return tvm.IRModule.from_expr(func) # TODO: The following code should be refactored. class MatMulNTDequantizeEmitter: From aa66a9080d41330ba63f38b76c539c6be0362906 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 08:03:46 +0000 Subject: [PATCH 10/88] bug fix in test --- bitblas/ops/impl/matmul_dequantize_impl.py | 9 ++-- testing/python/module/test_bitblas_linear.py | 41 ++++++++----------- .../operators/test_general_matmul_ops.py | 2 +- 3 files changed, 22 insertions(+), 30 deletions(-) diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py index 7b91764ca..55d672097 100644 --- a/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -473,8 +473,7 @@ def decode_func(n, k): else: args.append(Zeros) if with_bias: - E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") - last_output = E + last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") args.append(Bias) args.append(last_output) @@ -654,8 +653,7 @@ def decode_func(n, k): if with_zeros: args.append(Zeros) if with_bias: - E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") - last_output = E + last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") args.append(Bias) args.append(last_output) @@ -854,8 +852,7 @@ def decode_func(n, k): if with_zeros: args.append(Zeros) if with_bias: - E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") - last_output = E + last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") args.append(Bias) args.append(last_output) diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index eeaf90475..eee08c93c 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -11,16 +11,7 @@ torch.manual_seed(0) bitblas.set_log_level("DEBUG") -@pytest.mark.parametrize( - "m, in_features, out_features, bias", - [ - (1, 1024, 1024, False), - (1, 1024, 1024, True), - (1024, 1024, 1024, True), - ([1, 1024], 1024, 1024, True), - ], -) -def test_correctness_consistent(m, in_features, out_features, bias): +def correctness_consistent(m, in_features, out_features, bias): linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda()) linear_bitblas = BitBLASLinear( in_features, @@ -48,19 +39,13 @@ def test_correctness_consistent(m, in_features, out_features, bias): torch.testing.assert_close(output_torch, output_bitblas, rtol=1e-1, atol=1e-2) -@pytest.mark.parametrize( - "m, in_features, out_features, bias, W_dtype, group_size, with_scaling, with_zeros, zeros_mode", - [ - (1, 1024, 1024, False, "uint4", -1, False, False, None), - (1, 1024, 1024, False, "uint4", -1, False, False, None), - (1024, 1024, 1024, True, "uint4", -1, False, False, None), - (1, 1024, 1024, True, "uint2", -1, True, False, None), - (1, 1024, 1024, True, "uint2", 128, True, True, "original"), - (1024, 1024, 1024, True, "uint2", 128, True, True, "original"), - (1, 1024, 1024, True, "uint2", 128, True, True, "rescale"), - ], -) -def test_correctness_weight_only_dequantize( +def test_correctness_consistent(): + correctness_consistent(1, 1024, 1024, False) + correctness_consistent(1, 1024, 1024, True) + correctness_consistent(1024, 1024, 1024, True) + correctness_consistent([1, 1024], 1024, 1024, True) + +def correctness_weight_only_dequantize( m, in_features, out_features, @@ -169,6 +154,16 @@ def test_correctness_weight_only_dequantize( torch.testing.assert_close(output_bitblas, ref_result, rtol=1e0, atol=1e0) +def test_correctness_weight_only_dequantize(): + correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None) + correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None) + correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None) + correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", -1, True, False, None) + correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "original") + correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint2", 128, True, True, "original") + correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "rescale") + + def profile(model, input_data): model = model.cuda() model.eval() diff --git a/testing/python/operators/test_general_matmul_ops.py b/testing/python/operators/test_general_matmul_ops.py index 05e0a45f4..62808e2a7 100644 --- a/testing/python/operators/test_general_matmul_ops.py +++ b/testing/python/operators/test_general_matmul_ops.py @@ -195,7 +195,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo if with_bias: permuted_inputs.append(bias) permuted_inputs.append(inputs[2]) - matmul(*permuted_inputs[:2], output=permuted_inputs[-1]) + matmul(*permuted_inputs[:-1], output=permuted_inputs[-1]) if zeros_mode == "rescale": torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e0) else: From 79b08e415ffe79d7d4320e815cc2f5e603775e57 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 6 Jul 2024 08:21:08 +0000 Subject: [PATCH 11/88] lint fix. --- testing/python/module/test_bitblas_linear.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index eee08c93c..f329a146e 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -6,11 +6,11 @@ import time import numpy as np import torch.nn as nn -import pytest torch.manual_seed(0) bitblas.set_log_level("DEBUG") + def correctness_consistent(m, in_features, out_features, bias): linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda()) linear_bitblas = BitBLASLinear( @@ -45,6 +45,7 @@ def test_correctness_consistent(): correctness_consistent(1024, 1024, 1024, True) correctness_consistent([1, 1024], 1024, 1024, True) + def correctness_weight_only_dequantize( m, in_features, From 86fd0361bb74e21a87a26159022282ca25a4b282 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 7 Jul 2024 14:03:33 +0000 Subject: [PATCH 12/88] test cuda i4 kernel --- testing/cpp/CMakeLists.txt | 1 + .../cpp/efficient_i4_cuda_impl/CMakeLists.txt | 20 + .../efficient_i4_cuda_impl/efficient_i4.cu | 391 +++++++++ .../cpp/efficient_i4_cuda_impl/i4matmul.hpp | 822 ++++++++++++++++++ .../param_permutate.cpp | 89 ++ 5 files changed, 1323 insertions(+) create mode 100644 testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt create mode 100644 testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu create mode 100644 testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp create mode 100644 testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp diff --git a/testing/cpp/CMakeLists.txt b/testing/cpp/CMakeLists.txt index cf8eb0d3a..b92fa8da7 100644 --- a/testing/cpp/CMakeLists.txt +++ b/testing/cpp/CMakeLists.txt @@ -12,4 +12,5 @@ find_package(GTest REQUIRED) include_directories(${GTEST_INCLUDE_DIRS}) +add_subdirectory(efficient_i4_cuda_impl) add_subdirectory(lop3_type_conversion) diff --git a/testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt b/testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt new file mode 100644 index 000000000..36ffdf548 --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/CMakeLists.txt @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +function (ADD_CUDA_TEST_EXECUTABLE name) + add_executable(${name} ${name}.cu) + set_target_properties(${name} PROPERTIES CUDA_ARCHITECTURES 80) + # add flags + target_compile_options(${name} PRIVATE --expt-relaxed-constexpr) + set_target_properties(${name} PROPERTIES + CUDA_SEPARABLE_COMPILATION ON) + target_link_libraries(${name} gtest gtest_main) +endfunction(ADD_CUDA_TEST_EXECUTABLE) + +ADD_CUDA_TEST_EXECUTABLE(efficient_i4) + +function (ADD_CPP_TEST_EXECUTABLE name) + add_executable(${name} ${name}.cpp) + target_link_libraries(${name} gtest gtest_main pthread) +endfunction(ADD_CPP_TEST_EXECUTABLE) + +ADD_CPP_TEST_EXECUTABLE(param_permutate) diff --git a/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu b/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu new file mode 100644 index 000000000..257f49a31 --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/efficient_i4.cu @@ -0,0 +1,391 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include +#include +#include +#include +#include "i4matmul.hpp" + +#define cudaCheckLastError(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + } +inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) +{ + if (code != cudaSuccess) + { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) + exit(code); + } +} + +void general_compress(const int8_t *lowbit, int8_t *compressed, const int nbit, const int N, bool isSigned = false) +{ + int zero_point = isSigned ? ((1 << (nbit - 1)) - 1) : 0; + const int nbit_per_byte = 8 / nbit; + + for (int i = 0; i < N / nbit_per_byte; i++) + { + compressed[i] = 0; + for (int j = 0; j < nbit_per_byte; j++) + { + compressed[i] |= ((lowbit[nbit_per_byte * i + j] + zero_point) << (nbit * j)); + } + } +} + + +// Helper function to interleave the perm array +std::vector interleave_perms(const std::vector& perm) { + std::vector interleaved_perm; + std::array interleave = {0, 2, 4, 6, 1, 3, 5, 7}; + + int num_rows = perm.size() / 8; + for (int i = 0; i < num_rows; ++i) { + std::array row; + std::copy(perm.begin() + i * 8, perm.begin() + (i + 1) * 8, row.begin()); + for (int j : interleave) { + interleaved_perm.push_back(row[j]); + } + } + + return interleaved_perm; +} + + +std::tuple, std::vector, std::vector> get_perms() { + std::vector perm; + + for (int i = 0; i < 32; ++i) { + std::vector perm1; + int col = i / 4; + for (int block : {0, 1}) { + for (int row : { + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1 + }) { + perm1.push_back(16 * row + col + 8 * block); + } + } + for (int j = 0; j < 4; ++j) { + for (int p : perm1) { + perm.push_back(p + 256 * j); + } + } + } + + // Interleave the perm array + perm = interleave_perms(perm); + + std::vector scale_perm; + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + scale_perm.push_back(i + 8 * j); + } + } + + std::vector scale_perm_single; + for (int i = 0; i < 4; ++i) { + for (int j : {0, 1, 8, 9, 16, 17, 24, 25}) { + scale_perm_single.push_back(2 * i + j); + } + } + + return std::make_tuple(perm, scale_perm, scale_perm_single); +} + +void weight_pre_process(const int8_t *lowbit, int8_t *compressed, const int nbit, const int K, const int N) +{ + int8_t* tmp1 = new int8_t[K * N]; + const int maxq = 15; + auto [perm, scale_perm, scale_perm_single] = get_perms(); + const int tile_size = 16; + // transform the lowbit matrix to the compressed matrix + for (int i = 0; i < (K / tile_size); i += 1) + { + for (int j = 0; j < (N / tile_size); j += 1) + { + for (int k = 0; k < tile_size; k++) + { + for (int l = 0; l < tile_size; l++) + { + int idx_target = i * N * tile_size + j * tile_size * tile_size + k * tile_size + l; + int idx_source = (i * tile_size + k) * N + j * tile_size + l; + tmp1[idx_target] = lowbit[idx_source] + (maxq + 1) / 2; + } + } + } + } + // print the first 10 of tmp2 + printf("tmp1\n"); + for (int i = 0; i < 10; i++) + { + printf("%d ", tmp1[i]); + } + printf(" ... "); + for (int i = K * N - 10; i < K * N; i++) + { + printf("%d ", tmp1[i]); + } + printf("\n"); + // permute the matrix + int32_t* tmp2 = new int32_t[K * N]; + const int perm_size = perm.size(); + for (int i = 0; i < (N * K / perm_size); i++) + { + for (int j = 0; j < perm_size; j++) + { + int idx_target = i * perm_size + j; + int idx_source = i * perm_size + perm[j]; + tmp2[idx_target] = (int32_t)tmp1[idx_source]; + } + } + // print the first 10 of tmp2 + printf("tmp2\n"); + for (int i = 0; i < 10; i++) + { + printf("%d ", tmp2[i]); + } + printf(" ... "); + for (int i = K * N / (32 / nbit) - 10; i < K * N / (32 / nbit); i++) + { + printf("%d ", tmp2[i]); + } + printf("\n"); + // compress + int32_t* tmp3 = new int32_t[K * N / (32 / nbit)]; + // set zero + for (int i = 0; i < K * N / (32 / nbit); i++) + { + tmp3[i] = 0; + } + for (int i = 0; i < (K / tile_size); i++) + { + for (int j = 0; j < (N * tile_size / 8); j++) + { + for (int k = 0; k < 8; k++) + { + int idx_target = i * N * tile_size / 8 + j; + int idx_source = i * N * tile_size + j * 8 + k; + tmp3[idx_target] |= (tmp2[idx_source] << (nbit * (k % 8))); + } + } + } + // print the first 10 of tmp3 + printf("tmp3\n"); + for (int i = 0; i < 10; i++) + { + printf("%d ", tmp3[i]); + } + printf(" ... "); + for (int i = K * N / (32 / nbit) - 10; i < K * N / (32 / nbit); i++) + { + printf("%d ", tmp3[i]); + } + printf("\n"); + // copy tmp3 to compressed + for (int i = 0; i < K * N / (32 / nbit); i++) + { + ((int32_t *)(compressed))[i] = tmp3[i]; + } +} + +void scale_pre_process(const half *scale, half *scale_perm, const int K, const int N, int group_size) +{ + auto [perm, scale_perm_group, scale_perm_single] = get_perms(); + if (group_size == -1) + group_size = K; + if (group_size == K){ + const int perm_size = scale_perm_single.size(); + for (int i = 0; i < (N * K / group_size / perm_size); i++) + { + for (int j = 0; j < perm_size; j++) + { + int idx_target = i * perm_size + j; + int idx_source = i * perm_size + scale_perm_single[j]; + if (idx_target < 10){ + printf("idx_target = %d, idx_source = %d\n", idx_target, idx_source); + } + scale_perm[idx_target] = scale[idx_source]; + } + } + } + else{ + const int perm_size = scale_perm_group.size(); + for (int i = 0; i < (N * K / group_size / perm_size); i++) + { + for (int j = 0; j < perm_size; j++) + { + int idx_target = i * perm_size + j; + int idx_source = i * perm_size + scale_perm_group[j]; + scale_perm[idx_target] = scale[idx_source]; + } + } + } + // print the first 10 of tmp2 + printf("scale_perm\n"); + for (int i = 0; i < 10; i++) + { + printf("%f ", (float)scale_perm[i]); + } + printf(" ... "); + for (int i = K * N / group_size - 10; i < K * N / group_size; i++) + { + printf("%f ", (float)scale_perm[i]); + } +} + +TEST(EfficientI4MatmulTest, GEMVTest) +{ + const int prom_m = 1; + const int prom_n = 256; + const int prom_k = 256; + const int bits = 4; + const int group_size = prom_k; + + half* A = new half[prom_m * prom_k]; + int8_t* B = new int8_t[prom_k * prom_n]; + int8_t* qB_interleave = new int8_t[prom_k * prom_n / (8 / bits)]; + half* C = new half[prom_m * prom_n]; + half* s = new half[prom_n * (prom_k / group_size)]; + half* s_perm = new half[prom_n * (prom_k / group_size)]; + + // Initialize A and B + for (int i = 0; i < prom_m * prom_k; i++) + { + A[i] = __float2half(rand() / (float)RAND_MAX); + } + for (int i = 0; i < prom_k * prom_n; i++) + { + B[i] = rand() % 4 - 2; + } + for (int i = 0; i < prom_k * prom_n / group_size; i++) + { + // s[i] = __float2half(0.1); + s[i] = __float2half(rand() / (float)RAND_MAX); + } + + weight_pre_process(B, qB_interleave, bits, prom_k, prom_n); + // print the first 10 elements and last 10 elements of C + for (int i = 0; i < 10; i++) + { + printf("%d ", B[i]); + } + printf(" ... "); + for (int i = prom_k * prom_n - 10; i < prom_k * prom_n; i++) + { + printf("%d ", B[i]); + } + // print interleave of B + for (int i = 0; i < 10; i++) + { + printf("%d ", qB_interleave[i]); + } + printf(" ... "); + for (int i = prom_k * prom_n / (8 / bits) - 10; i < prom_k * prom_n / (8 / bits); i++) + { + printf("%d ", qB_interleave[i]); + } + printf("\n"); + // print last 10 of qb_interleave + for (int i = prom_k * prom_n / (8 / bits) - 10; i < prom_k * prom_n / (8 / bits); i++) + { + printf("%d ", qB_interleave[i]); + } + printf("\n"); + // print last 10 of B + for (int i = prom_k * prom_n - 10; i < prom_k * prom_n; i++) + { + printf("%d ", B[i]); + } + printf("\n"); + // print last 10 of s + for (int i = prom_n * (prom_k / group_size) - 10; i < prom_n * (prom_k / group_size); i++) + { + printf("%f ", __half2float(s[i])); + } + printf("\n"); + scale_pre_process(s, s_perm, prom_k, prom_n, group_size); + // define cuda variables + float* d_workspace = nullptr; + cudaCheckLastError(cudaMalloc((void**)&d_workspace, prom_n * prom_k * 16 * sizeof(float))); + + half* d_A; + int8_t* d_qB; + half* d_C; + half* d_s; + cudaCheckLastError(cudaMalloc((void**)&d_A, prom_m * prom_k * sizeof(half))); + cudaCheckLastError(cudaMalloc((void**)&d_qB, prom_k * prom_n / (8 / bits) * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void**)&d_C, prom_m * prom_n * sizeof(half))); + cudaCheckLastError(cudaMalloc((void**)&d_s, prom_n * (prom_k / group_size) * sizeof(half))); + // copy A and B to device + cudaCheckLastError(cudaMemcpy(d_A, A, prom_m * prom_k * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(d_qB, qB_interleave, prom_n * prom_k / (8 / bits) * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(d_s, s_perm, prom_n * (prom_k / group_size) * sizeof(half), cudaMemcpyHostToDevice)); + + // allocate workspace + // call the kernel + int ret = marlin_cuda(d_A, d_qB, d_C, d_s, prom_m, prom_n, prom_k, d_workspace, group_size == prom_k? -1: group_size); + printf("ret = %d\n", ret); + + // copy C back to host + cudaCheckLastError(cudaMemcpy(C, d_C, prom_m * prom_n * sizeof(half), cudaMemcpyDeviceToHost)); + // print the first 10 elements and last 10 elements of C + for (int i = 0; i < 10; i++) + { + printf("%f ", __half2float(C[i])); + } + printf(" ... "); + for (int i = prom_m * prom_n - 10; i < prom_m * prom_n; i++) + { + printf("%f ", __half2float(C[i])); + } + printf("\n"); + + // ref calculation + float* ref_C = new float[prom_m * prom_n]; + // zero fill + for (int i = 0; i < prom_m * prom_n; i++) + { + ref_C[i] = __float2half(0.0); + } + // + for (int i = 0; i < prom_m; i++) + { + for (int j = 0; j < prom_n; j++) + { + ref_C[i * prom_n + j] = __float2half(0.0); + for (int k = 0; k < prom_k; k++) + { + ref_C[i * prom_n + j] += float(A[i * prom_k + k]) * (float(B[k * prom_n + j]) * float(s[(k / group_size) * prom_n + j])); + } + } + } + for (int i = 0; i < 10; i++) + { + printf("%f ", __half2float(ref_C[i])); + } + printf(" ... "); + for (int i = prom_m * prom_n - 10; i < prom_m * prom_n; i++) + { + printf("%f ", __half2float(ref_C[i])); + } + printf("\n"); + + // check the result + for (int i = 0; i < prom_m * prom_n; i++) + { + EXPECT_NEAR(__half2float(C[i]), __half2float(ref_C[i]), 1e-1); + } + + // free memory + delete[] A; + delete[] B; + delete[] C; + cudaCheckLastError(cudaFree(d_A)); + cudaCheckLastError(cudaFree(d_qB)); + cudaCheckLastError(cudaFree(d_C)); +} diff --git a/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp new file mode 100644 index 000000000..ae4cef5a2 --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp @@ -0,0 +1,822 @@ +/* + * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#ifndef MARLIN_CUDA_KERNEL_CUH +#define MARLIN_CUDA_KERNEL_CUH + + +#include +#include +#include +#include + + +constexpr int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core +// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { + return elems[i]; + } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that +// are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for +// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need +// for inputs A and outputs C. +__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .b64 p;\n" + " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" + " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" :: "n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]) + ); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem) + ); +} + +// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to +// automatically recognize it in all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut) + ); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. +// We mostly follow the strategy in the link below, with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2( + *reinterpret_cast(&lo), + *reinterpret_cast(&SUB) + ); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible globally. + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier. + asm volatile ("fence.acq_rel.gpu;\n"); + asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + + +template < + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const int stages, // number of stages for the async global->shared fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale +> +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple + // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs + // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as + // possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case + // where a stripe starts in the middle of group. + if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to top + + // We can easily implement parallel problem execution by just remapping indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for synchronization. + auto init_slice = [&] () { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time constant + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major + // layout in the former and in row-major in the latter case. + if (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than + // required for a certain tilesize or when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank + // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of + // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based + // on NSight-Compute) that each warp must also write a consecutive memory segment? + auto transform_a = [&] (int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory + // accesses are static, we simply precompute both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between + // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&] () { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location. + auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i] + ); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) + cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&] () { + // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when + // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer. + auto fetch_to_registers = [&] (int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a + // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the + // compiler and correspondingly a noticable drop in performance. + if (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&] (int k) { + // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + FragB frag_b0 = dequant(b_quant); + // If there are no groups, we can just scale the final output once and can avoid doing so for each weight. + if (group_blocks != -1) + scale(frag_b0, frag_s[k % 2][j], 0); + FragB frag_b1 = dequant(b_quant_shift); + if (group_blocks != -1) + scale(frag_b1, frag_s[k % 2][j], 1); + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n + // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&] () { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations, + // e.g., for two warps we write only once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over + // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather + // small, we perform this reduction serially in L2 cache. + auto global_reduce = [&] (bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step. + // To do this, we write out results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns, + // hence we also use async-copies even though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m + ); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float( + reinterpret_cast<__half*>(&c_red)[j] + ); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = __float2half( + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] + ); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step, + // the reduction above is performed in fragment layout. + auto write_result = [&] () { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final global write patterns + auto write = [&] (int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + if (group_blocks == -1) // for per-column quantization we finally apply the scale here + res = __hmul2(res, s[0]); + ((half2*) sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&] () { + #pragma unroll + for (int i = 0; i < stages - 1; i++) + fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are + // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) + break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most + // readable, other ways of writing the loop seemed to noticeably worse performance after compliation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before write-out + if (group_blocks == -1 && last) { + if (s_sh_wr_pred) + cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + thread_block_reduce(); + if (group_blocks == -1 && last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + if (slice_count > 1) { // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + + +// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per schedule allows some more +// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles. +const int THREADS = 256; +const int STAGES = 4; // 4 pipeline stages fit into shared memory +const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if ( \ + thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS \ + ) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + SHARED_MEM \ + ); \ + Marlin< \ + THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \ + ><<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, \ + prob_m, prob_n, prob_k, \ + locks \ + ); \ + } + +const int ERR_PROB_SHAPE = 1; +const int ERR_KERN_SHAPE = 2; + +int marlin_cuda( + const void* A, + const void* B, + void* C, + void* s, + int prob_m, + int prob_n, + int prob_k, + void* workspace, + int groupsize = -1, + int dev = 0, + cudaStream_t stream = 0, + int thread_k = -1, + int thread_n = -1, + int sms = -1, + int max_par = 16 +) { + int tot_m = prob_m; + int tot_m_blocks = ceildiv(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + if (thread_k == -1 || thread_n == -1) { + if (prob_m <= 16) { + // For small batchizes, better partioning is slightly more important than better compute utilization + thread_k = 128; + thread_n = 128; + } else { + thread_k = 64; + thread_n = 256; + } + } + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0)) + return ERR_PROB_SHAPE; + if (prob_m == 0 || prob_n == 0 || prob_k == 0) + return 0; + + const int4* A_ptr = (const int4*) A; + const int4* B_ptr = (const int4*) B; + int4* C_ptr = (int4*) C; + const int4* s_ptr = (const int4*) s; + + int cols = prob_n / thread_n; + int* locks = (int*) workspace; + + int ret = 0; + for (int i = 0; i < tot_m_blocks; i += 4) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any padding + par = (16 * thread_m_blocks - pad) / 64; + if (par > max_par) + par = max_par; + prob_m = 64 * par; + i += 4 * (par - 1); + thread_m_blocks = 4; + } + + // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance) + // in our testing, however many more are, in principle, possible. + if (false) {} + CALL_IF(1, 8, 8, -1) + CALL_IF(1, 8, 8, 8) + CALL_IF(1, 16, 4, -1) + CALL_IF(1, 16, 4, 8) + CALL_IF(2, 16, 4, -1) + CALL_IF(2, 16, 4, 8) + CALL_IF(3, 16, 4, -1) + CALL_IF(3, 16, 4, 8) + CALL_IF(4, 16, 4, -1) + CALL_IF(4, 16, 4, 8) + else + ret = ERR_KERN_SHAPE; + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } + + return ret; +} + + +#endif diff --git a/testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp b/testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp new file mode 100644 index 000000000..64248b3d1 --- /dev/null +++ b/testing/cpp/efficient_i4_cuda_impl/param_permutate.cpp @@ -0,0 +1,89 @@ +#include +#include +#include +#include +#include +#include + +// Helper function to interleave the perm array +std::vector interleave_perms(const std::vector& perm) { + std::vector interleaved_perm; + std::array interleave = {0, 2, 4, 6, 1, 3, 5, 7}; + + int num_rows = perm.size() / 8; + for (int i = 0; i < num_rows; ++i) { + std::array row; + std::copy(perm.begin() + i * 8, perm.begin() + (i + 1) * 8, row.begin()); + for (int j : interleave) { + interleaved_perm.push_back(row[j]); + } + } + + return interleaved_perm; +} + +std::tuple, std::vector, std::vector> get_perms() { + std::vector perm; + + for (int i = 0; i < 32; ++i) { + std::vector perm1; + int col = i / 4; + for (int block : {0, 1}) { + for (int row : { + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1 + }) { + perm1.push_back(16 * row + col + 8 * block); + } + } + for (int j = 0; j < 4; ++j) { + for (int p : perm1) { + perm.push_back(p + 256 * j); + } + } + } + + // Interleave the perm array + perm = interleave_perms(perm); + + std::vector scale_perm; + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + scale_perm.push_back(i + 8 * j); + } + } + + std::vector scale_perm_single; + for (int i = 0; i < 4; ++i) { + for (int j : {0, 1, 8, 9, 16, 17, 24, 25}) { + scale_perm_single.push_back(2 * i + j); + } + } + + return std::make_tuple(perm, scale_perm, scale_perm_single); +} + +TEST(EfficientI4MatmulTest, ParamPermutate) +{ + auto [perm, scale_perm, scale_perm_single] = get_perms(); + + std::cout << "perm: "; + for (int i = 0; i < 10; ++i) { + std::cout << perm[i] << " "; + } + std::cout << std::endl; + + std::cout << "scale_perm: "; + for (const auto& val : scale_perm) { + std::cout << val << " "; + } + std::cout << std::endl; + + std::cout << "scale_perm_single: "; + for (const auto& val : scale_perm_single) { + std::cout << val << " "; + } + std::cout << std::endl; +} From 6b73a210ff846fb55bed253b5c0a1c089c3e95f1 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 7 Jul 2024 16:39:04 +0000 Subject: [PATCH 13/88] Refactor copyright notice in i4matmul.hpp --- THIRDPARTYNOTICES.txt | 204 ++++++++++++++++++ .../cpp/efficient_i4_cuda_impl/i4matmul.hpp | 36 ++-- 2 files changed, 224 insertions(+), 16 deletions(-) diff --git a/THIRDPARTYNOTICES.txt b/THIRDPARTYNOTICES.txt index f377e67bb..d959effbb 100644 --- a/THIRDPARTYNOTICES.txt +++ b/THIRDPARTYNOTICES.txt @@ -206,3 +206,207 @@ Notice for apache/tvm limitations under the License. ------------------------------------------------------------------------------------ +Notice for IST-DASLab/marlin/ +------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +------------------------------------------------------------------------------------ diff --git a/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp index ae4cef5a2..a12a57dcd 100644 --- a/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp +++ b/testing/cpp/efficient_i4_cuda_impl/i4matmul.hpp @@ -1,19 +1,23 @@ -/* - * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - +// Copyright 2018 The apache/tvm Authors. All Rights Reserved. +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +// Modifications Copyright (c) Microsoft. +// The code below is mostly copied from marlin_cuda in IST-DASLab/marlin. #ifndef MARLIN_CUDA_KERNEL_CUH #define MARLIN_CUDA_KERNEL_CUH From 086d208fd07984c8cf876f1529fc80ade4cff21d Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 7 Jul 2024 16:54:51 +0000 Subject: [PATCH 14/88] Refactor BitBLASLinear test module for improved readability and maintainability --- testing/python/module/test_bitblas_linear.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index 3da4a73b6..f329a146e 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -10,6 +10,7 @@ torch.manual_seed(0) bitblas.set_log_level("DEBUG") + def correctness_consistent(m, in_features, out_features, bias): linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda()) linear_bitblas = BitBLASLinear( @@ -44,6 +45,7 @@ def test_correctness_consistent(): correctness_consistent(1024, 1024, 1024, True) correctness_consistent([1, 1024], 1024, 1024, True) + def correctness_weight_only_dequantize( m, in_features, From 47a3abdb805d2d27241888dce310c9f08f0771fb Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Jul 2024 04:47:15 +0000 Subject: [PATCH 15/88] refactor test as version below python 3.9 cannot handle int32 overflow. --- .../test_int4b_fp16_convert.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/testing/python/type_conversion/test_int4b_fp16_convert.py b/testing/python/type_conversion/test_int4b_fp16_convert.py index 92b0e0788..c1ed480fb 100644 --- a/testing/python/type_conversion/test_int4b_fp16_convert.py +++ b/testing/python/type_conversion/test_int4b_fp16_convert.py @@ -44,25 +44,25 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"): if nbits == 1 and target_dtype == "int8": # special handling for 1b interleave - n16_weight = new_qweight & np.int32(0xF0F00F0F) - n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 - n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24 - n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 - n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 + n16_weight = new_qweight & np.int32(np.uint32(0xF0F00F0F)) + n16_weight |= ((new_qweight & np.int32(np.uint32(0x000000F0))) >> 4) << 16 + n16_weight |= ((new_qweight & np.int32(np.uint32(0x0000F000))) >> 12) << 24 + n16_weight |= ((new_qweight & np.int32(np.uint32(0x000F0000))) >> 16) << 4 + n16_weight |= ((new_qweight & np.int32(np.uint32(0x0F000000))) >> 24) << 12 return n16_weight.view(np.int8) elif nbits == 2 and target_dtype == "float16": - n8_weight = new_qweight & np.int32(0xFF0000FF) - n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 - n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 + n8_weight = new_qweight & np.int32(np.uint32(0xFF0000FF)) + n8_weight |= ((new_qweight & np.int32(np.uint32(0x0000FF00))) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x00FF0000))) >> 16) << 8 return n8_weight.view(np.int8) elif nbits == 1 and target_dtype == "float16": - n8_weight = new_qweight & 0xF000000F - n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 - n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 - n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24 - n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4 - n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12 - n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20 + n8_weight = new_qweight & np.int32(np.uint32(0xF000000F)) + n8_weight |= ((new_qweight & np.int32(np.uint32(0x000000F0))) >> 4) << 8 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x00000F00))) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x0000F000))) >> 12) << 24 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x000F0000))) >> 16) << 4 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x00F00000))) >> 20) << 12 + n8_weight |= ((new_qweight & np.int32(np.uint32(0x0F000000))) >> 24) << 20 return new_qweight.view(np.int8) From 024b2474b8b81a3f855be5e6ccfcd05139bd285a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Jul 2024 04:47:36 +0000 Subject: [PATCH 16/88] format lint for test --- .../test_int4b_fp16_convert.py | 56 +++++-------------- 1 file changed, 15 insertions(+), 41 deletions(-) diff --git a/testing/python/type_conversion/test_int4b_fp16_convert.py b/testing/python/type_conversion/test_int4b_fp16_convert.py index c1ed480fb..2af765047 100644 --- a/testing/python/type_conversion/test_int4b_fp16_convert.py +++ b/testing/python/type_conversion/test_int4b_fp16_convert.py @@ -21,9 +21,7 @@ def general_compress_to_int8(lowprecision_weight, source_bits=4): ) for j in range(lowprecision_weight.shape[-1] // elems_per_byte): for k in range(elems_per_byte): - int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << ( - source_bits * k - ) + int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << (source_bits * k) return int8_weight @@ -80,17 +78,11 @@ def interleave_weight(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32 with T.block("B"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) @T.prim_func - def interleave_weight_f16_2b( - A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") - ): + def interleave_weight_f16_2b(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32")): B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") @@ -98,12 +90,8 @@ def interleave_weight_f16_2b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -114,9 +102,7 @@ def interleave_weight_f16_2b( B[v0, v1] = B_tmp_1[v0, v1] | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] @T.prim_func - def interleave_weight_f16_1b( - A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") - ): + def interleave_weight_f16_1b(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32")): B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") @@ -128,12 +114,8 @@ def interleave_weight_f16_1b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -152,13 +134,10 @@ def interleave_weight_f16_1b( | B_tmp_4[v0, v1] | B_tmp_5[v0, v1] | B_tmp_6[v0, v1] - | B_tmp_7[v0, v1] - ) + | B_tmp_7[v0, v1]) @T.prim_func - def interleave_weight_int8_1b( - A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32") - ): + def interleave_weight_int8_1b(A: T.Buffer((N, QK), "int32"), B: T.Buffer((N, QK), "int32")): B_tmp_1 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_2 = T.alloc_buffer((N, QK), "int32", scope="local") B_tmp_3 = T.alloc_buffer((N, QK), "int32", scope="local") @@ -168,12 +147,8 @@ def interleave_weight_int8_1b( with T.block("B_tmp"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + ( - offset // num_groups - ) * bits - B[v0, v1] = B[v0, v1] | ( - ((A[v0, v1] >> (bits * offset)) & mask) << shift - ) + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) for ax0, ax1 in T.grid(N, QK): with T.block("B"): @@ -188,8 +163,7 @@ def interleave_weight_int8_1b( | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] | B_tmp_4[v0, v1] - | B_tmp_5[v0, v1] - ) + | B_tmp_5[v0, v1]) if target_dtype == "float16" and bits == 2: return interleave_weight_f16_2b @@ -207,7 +181,7 @@ def test_lop3_interleave_weight(): K = 16 target_dtype = "float16" torch.manual_seed(0) - uint_max = 2 ** (source_nbits) - 1 + uint_max = 2**(source_nbits) - 1 raw_data = torch.randint(0, uint_max, (N, K), dtype=torch.int8).cpu().numpy() compressed_b = general_compress_to_int8(raw_data, source_nbits) interleaved_weight = interleave_weight(compressed_b, source_nbits, target_dtype) From bfedeaa813c269330ea332cce9938e05eac809ec Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Jul 2024 04:50:49 +0000 Subject: [PATCH 17/88] Refactor test_int4b_fp16_convert.py for improved readability and maintainability --- testing/python/type_conversion/test_int4b_fp16_convert.py | 1 - 1 file changed, 1 deletion(-) diff --git a/testing/python/type_conversion/test_int4b_fp16_convert.py b/testing/python/type_conversion/test_int4b_fp16_convert.py index 2af765047..3a58a47e1 100644 --- a/testing/python/type_conversion/test_int4b_fp16_convert.py +++ b/testing/python/type_conversion/test_int4b_fp16_convert.py @@ -5,7 +5,6 @@ import torch import numpy as np from tvm.script import tir as T -import numpy as np def general_compress_to_int8(lowprecision_weight, source_bits=4): From e672a23a3910605eacfd705aede46be3c81a0021 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Jul 2024 08:47:05 +0000 Subject: [PATCH 18/88] remove unused design file --- bitblas/generator.py | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 bitblas/generator.py diff --git a/bitblas/generator.py b/bitblas/generator.py deleted file mode 100644 index 4ac6f2be2..000000000 --- a/bitblas/generator.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -class BitBLASGenerator: - - def __init__(self): - # Initialize the generator with configuration - pass - - def generate_cuda_code(self): - pass - - def generate_header(self): - pass From 21e54300313ebe803cf96fdd7fd212bb0e09bf21 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Jul 2024 08:47:40 +0000 Subject: [PATCH 19/88] move tile device from package to base --- bitblas/base/{roller => }/arch/__init__.py | 0 bitblas/base/{roller => }/arch/arch_base.py | 0 bitblas/base/{roller => }/arch/cpu.py | 0 bitblas/base/{roller => }/arch/cuda.py | 0 bitblas/base/roller/__init__.py | 2 +- bitblas/base/roller/policy/default.py | 2 +- bitblas/base/roller/policy/tensorcore.py | 2 +- bitblas/base/utils.py | 2 +- 8 files changed, 4 insertions(+), 4 deletions(-) rename bitblas/base/{roller => }/arch/__init__.py (100%) rename bitblas/base/{roller => }/arch/arch_base.py (100%) rename bitblas/base/{roller => }/arch/cpu.py (100%) rename bitblas/base/{roller => }/arch/cuda.py (100%) diff --git a/bitblas/base/roller/arch/__init__.py b/bitblas/base/arch/__init__.py similarity index 100% rename from bitblas/base/roller/arch/__init__.py rename to bitblas/base/arch/__init__.py diff --git a/bitblas/base/roller/arch/arch_base.py b/bitblas/base/arch/arch_base.py similarity index 100% rename from bitblas/base/roller/arch/arch_base.py rename to bitblas/base/arch/arch_base.py diff --git a/bitblas/base/roller/arch/cpu.py b/bitblas/base/arch/cpu.py similarity index 100% rename from bitblas/base/roller/arch/cpu.py rename to bitblas/base/arch/cpu.py diff --git a/bitblas/base/roller/arch/cuda.py b/bitblas/base/arch/cuda.py similarity index 100% rename from bitblas/base/roller/arch/cuda.py rename to bitblas/base/arch/cuda.py diff --git a/bitblas/base/roller/__init__.py b/bitblas/base/roller/__init__.py index 9afd7cff0..3f728e695 100644 --- a/bitblas/base/roller/__init__.py +++ b/bitblas/base/roller/__init__.py @@ -4,4 +4,4 @@ from .rasterization import NoRasterization, Rasterization2DRow, Rasterization2DColumn # noqa: F401 from .hint import Hint # noqa: F401 from .policy import DefaultPolicy, TensorCorePolicy # noqa: F401 -from .arch import TileDevice, CUDA # noqa: F401 +from ..arch import TileDevice, CUDA # noqa: F401 diff --git a/bitblas/base/roller/policy/default.py b/bitblas/base/roller/policy/default.py index 730c8336f..e9f7b809f 100644 --- a/bitblas/base/roller/policy/default.py +++ b/bitblas/base/roller/policy/default.py @@ -9,7 +9,7 @@ import numpy as np from bitblas import tvm -from ..arch import TileDevice +from ...arch import TileDevice from ..bestfit import BestFit from ..hint import Hint, Stride, TileDict from .common import coalesced_factor, coalesced_tensor_shape, factorize, get_all_factors diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py index ae45b5893..e69bcabc3 100644 --- a/bitblas/base/roller/policy/tensorcore.py +++ b/bitblas/base/roller/policy/tensorcore.py @@ -5,7 +5,7 @@ from typing import Dict, List, Tuple, Optional import numpy as np -from ..arch import TileDevice +from ...arch import TileDevice from ..hint import Hint, Stride, TileDict, IntrinInfo from ..node import PrimFuncNode from .common import coalesced_factor, factorize, get_all_factors diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index 4cd82fa93..1596b3c86 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -13,7 +13,7 @@ from tvm.relax.expr import Function import bitblas from .analysis import get_root_block, get_reduction_blocks, find_var_from_func -from bitblas.base.roller.arch import CUDA +from bitblas.base.arch import CUDA from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags import tempfile From fd1194063884a00678663d04cdda811e9cf0f932 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Jul 2024 08:49:18 +0000 Subject: [PATCH 20/88] dummy impl for codegen --- bitblas/codegen/base.py | 25 +++++++++++++++++++++++++ bitblas/codegen/tir.py | 4 ++++ 2 files changed, 29 insertions(+) create mode 100644 bitblas/codegen/base.py create mode 100644 bitblas/codegen/tir.py diff --git a/bitblas/codegen/base.py b/bitblas/codegen/base.py new file mode 100644 index 000000000..312091f94 --- /dev/null +++ b/bitblas/codegen/base.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas.ops.operator import OperatorConfig +from abc import ABC, abstractmethod + +class Backend(ABC): + """ + input: OperatorConfig + The duty of bakend: + - is OperatorConfig is Available For our Backend + - Generate CUDA Source for compilation + """ + def __init__(self, config): + self.config = config + @abstractmethod + def compile(self, config): + pass + + @abstractmethod + def execute(self, *args, **kwargs): + pass + + @abstractmethod + def optimize(self, *args, **kwargs): + pass \ No newline at end of file diff --git a/bitblas/codegen/tir.py b/bitblas/codegen/tir.py new file mode 100644 index 000000000..ff5b0b6be --- /dev/null +++ b/bitblas/codegen/tir.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +class \ No newline at end of file From 9ccfa85d5581cd86f50bdeec79c316564083399c Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Jul 2024 08:49:38 +0000 Subject: [PATCH 21/88] Refactor file structure for ladder_permutate module --- .../__init__.py} | 129 +-- bitblas/ops/general_matmul/backend/tir.py | 2 + .../ops/general_matmul/tirscript/__init__.py | 4 + .../tirscript/matmul_dequantize_impl.py | 967 ++++++++++++++++++ .../general_matmul/tirscript/matmul_impl.py | 356 +++++++ .../__init__.py} | 4 +- .../ladder_permutate/ladder_permutate_impl.py | 82 ++ .../__init__.py} | 4 +- .../ops/lop3_permutate/lop3_permutate_impl.py | 152 +++ bitblas/ops/matmul.py | 2 +- 10 files changed, 1638 insertions(+), 64 deletions(-) rename bitblas/ops/{general_matmul.py => general_matmul/__init__.py} (91%) create mode 100644 bitblas/ops/general_matmul/backend/tir.py create mode 100644 bitblas/ops/general_matmul/tirscript/__init__.py create mode 100644 bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py create mode 100644 bitblas/ops/general_matmul/tirscript/matmul_impl.py rename bitblas/ops/{ladder_permutate.py => ladder_permutate/__init__.py} (96%) create mode 100644 bitblas/ops/ladder_permutate/ladder_permutate_impl.py rename bitblas/ops/{lop3_permutate.py => lop3_permutate/__init__.py} (95%) create mode 100644 bitblas/ops/lop3_permutate/lop3_permutate_impl.py diff --git a/bitblas/ops/general_matmul.py b/bitblas/ops/general_matmul/__init__.py similarity index 91% rename from bitblas/ops/general_matmul.py rename to bitblas/ops/general_matmul/__init__.py index 97dd7d13f..cf16029dd 100644 --- a/bitblas/ops/general_matmul.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -4,17 +4,16 @@ from tvm.target import Target import operator from functools import reduce -from bitblas.base.roller.arch.cuda import CUDA +from bitblas.base.arch.cuda import CUDA from typing import Any, Literal, Optional, Tuple, Union -from .operator import Operator, TransformKind, OPExecutorCPU -from .impl.matmul_dequantize_impl import ( - select_implementation as weight_dequantize_implementation,) -from .impl.matmul_impl import select_implementation as consistent_implementation -from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 +from ..operator import Operator, TransformKind, OPExecutorCPU +from .tirscript.matmul_dequantize_impl import select_implementation as weight_dequantize_implementation +from .tirscript.matmul_impl import select_implementation as consistent_implementation +from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 from bitblas.utils.target_detector import auto_detect_nvidia_target from dataclasses import dataclass -from .ladder_permutate import LadderPermutate, LadderPermutateConfig -from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig +from ..ladder_permutate import LadderPermutate, LadderPermutateConfig +from ..lop3_permutate import LOP3Permutate, LOP3PermutateConfig import logging import torch @@ -252,6 +251,49 @@ def __init__( self._build_default_module(target) self.workspace = None + if source_format == "nf": + self.lut = torch.tensor( + [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ], + dtype=getattr(torch, self.A_dtype), + ).cuda() + else: + self.lut = None + + # create permutate_opertors + self.ladder_permutate_a = self._assign_ladder_permutate_a(target, enable_tuning) + self.ladder_permutate_b = self._assign_ladder_permutate_b(target, enable_tuning) + self.lop3_permutate = self._assign_lop3_permutate(target, enable_tuning) + # create cpu weight executors + self.weight_executors = self._create_weight_executors() + + if enable_tuning: + self.hardware_aware_finetune() + + # output data type + self.torch_output_dtype = getattr(torch, self.out_dtype) + + def _alloc_workspace(self): + return torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() + + def _assign_ladder_permutate_a(self, target: Target, enable_tuning: bool): + ladder_permutate_a = None if self.propagate_a: # for general purpose, we use propagate_a to control the ladder permutation. ladder_permutate_config = LadderPermutateConfig( @@ -263,14 +305,18 @@ def __init__( transpose_matrix=False, transform_kind=self.propagate_a, ) - self.ladder_permutate_a = LadderPermutate( + ladder_permutate_a = LadderPermutate( config=ladder_permutate_config, target=target, enable_tuning=enable_tuning, ) - self.workspace = torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() - else: - self.ladder_permutate_a = None + self.workspace = self._alloc_workspace() + return ladder_permutate_a + + def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool): + # unused variables + del target + del enable_tuning if self.propagate_b: ladder_permutate_config = LadderPermutateConfig( @@ -283,13 +329,16 @@ def __init__( transpose_matrix=self.layout == "nt", transform_kind=self.propagate_b, ) - self.ladder_permutate_b = LadderPermutate( + return LadderPermutate( config=ladder_permutate_config, target=tvm.target.Target("llvm"), ) - else: - self.ladder_permutate_b = None + return None + def _assign_lop3_permutate(self, target: Target, enable_tuning: bool): + # unused variables + del target + del enable_tuning if self.fast_decoding: assert self.source_format in ["int", "uint"] lop3_permutate_config = LOP3PermutateConfig( @@ -299,57 +348,19 @@ def __init__( dequantize_bits=self.bit, storage_dtype=self.storage_dtype, ) - self.lop3_permutate = LOP3Permutate( + return LOP3Permutate( config=lop3_permutate_config, target=tvm.target.Target("llvm"), ) - else: - self.lop3_permutate = None - - input_executors = OPExecutorCPU() - if self.ladder_permutate_a is not None: - input_executors.append(self.ladder_permutate_a) - self.input_executors = input_executors + return None + def _create_weight_executors(self): weight_executors = OPExecutorCPU() - if self.lop3_permutate is not None: + if self.fast_decoding: weight_executors.append(self.lop3_permutate) - - if self.ladder_permutate_b is not None: + if self.propagate_b is not TransformKind.NonTransform: weight_executors.append(self.ladder_permutate_b) - - self.weight_executors = weight_executors - - if enable_tuning: - self.hardware_aware_finetune() - - if source_format == "nf": - self.lut = torch.tensor( - [ - -1.0, - -0.6961928009986877, - -0.5250730514526367, - -0.39491748809814453, - -0.28444138169288635, - -0.18477343022823334, - -0.09105003625154495, - 0.0, - 0.07958029955625534, - 0.16093020141124725, - 0.24611230194568634, - 0.33791524171829224, - 0.44070982933044434, - 0.5626170039176941, - 0.7229568362236023, - 1.0, - ], - dtype=getattr(torch, self.A_dtype), - ).cuda() - else: - self.lut = None - - # output data type - self.torch_output_dtype = getattr(torch, self.out_dtype) + return weight_executors def _build_default_module(self, target: Target): try: diff --git a/bitblas/ops/general_matmul/backend/tir.py b/bitblas/ops/general_matmul/backend/tir.py new file mode 100644 index 000000000..59e481eb9 --- /dev/null +++ b/bitblas/ops/general_matmul/backend/tir.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/bitblas/ops/general_matmul/tirscript/__init__.py b/bitblas/ops/general_matmul/tirscript/__init__.py new file mode 100644 index 000000000..ae3038239 --- /dev/null +++ b/bitblas/ops/general_matmul/tirscript/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .matmul_dequantize_impl import select_implementation as matmul_dequantize_select_implementation # noqa: F401 +from .matmul_impl import select_implementation as matmul_select_implementation # noqa: F401 diff --git a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py new file mode 100644 index 000000000..82c83be90 --- /dev/null +++ b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py @@ -0,0 +1,967 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +from bitblas import tvm +from tvm import te, DataType +from tvm.tir import IndexMap +from bitblas.ops.operator import TransformKind +from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.quantization import ( + _tir_packed_int_to_int_convert, + _tir_packed_to_signed_convert, + _tir_packed_to_unsigned_convert, + _tir_u32_to_f4_to_f16, + _tir_u8_to_f8_e4m3_to_f16, + _tir_packed_to_unsigned_convert_with_zeros, +) + + +class MatMulNTDequantizeEmitter: + + def __init__( + self, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform, + ): + self.M = self._validate_dimension(M, "M") + self.N = N + self.K = K + self.in_dtype = in_dtype + self.out_dtype = out_dtype + self.accum_dtype = accum_dtype + self.bit = bit + self.storage_dtype = storage_dtype + self.source_format = source_format + self.with_scaling = with_scaling + self.with_zeros = with_zeros + self.group_size = group_size if group_size != -1 else K + self.fast_decoding = fast_decoding + self.with_bias = with_bias + self.zeros_mode = zeros_mode + self.propagate_a = self._legalize_transform_kind(propagate_a) + self.propagate_b = self._legalize_transform_kind(propagate_b) + + self._validate_bit() + self._validate_layout() + + @staticmethod + def _validate_dimension(dim, name): + if not isinstance(dim, int): + return tvm.te.var(name.lower()) + return dim + + def _validate_bit(self): + if self.bit not in [1, 2, 4, 8]: + raise ValueError(f"Unsupported bit: {self.bit}") + + def _validate_layout(self): + # TODO: extend the dequantize operators into General Layout + pass + + def _legalize_group_size(self): + if self.group_size == -1: + self.group_size = self.K + + def _legalize_transform_kind(self, propagate): + if propagate is None: + return TransformKind.NonTransform + if isinstance(propagate, bool): + return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform) + elif isinstance(propagate, int): + return TransformKind(propagate) + + def _create_placeholders(self): + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + in_dtype = self.in_dtype + bit = self.bit + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype) + B = te.placeholder((self.N, self.K // storage_nbit * bit), name="B", dtype=storage_dtype) + if self.propagate_a: + A = te.placeholder((self.M // l, self.K // r, l, r), name="A", dtype=in_dtype) + if self.propagate_b: + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + qr = r * bit // storage_nbit + B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr), + name="B", + dtype=storage_dtype) + + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=in_dtype) + QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * bit), + name="QZeros", + dtype=self.storage_dtype) + Bias = te.placeholder((self.N,), name="Bias", dtype=in_dtype) + return A, B, LUT, Scale, Zeros, QZeros, Bias + + def _propagate_input(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="A"): + if transform_kind == TransformKind.NonTransform: + return tensor + in_dtype = self.in_dtype + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map( + trans=False, dtype=in_dtype, matrix_name=matrix_name) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return tensor[new_index] + + return te.compute( + (self.M, self.K), + fcompute, + name=f"{matrix_name}_reindex", + ) + + def _propagage_weight(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="B"): + if transform_kind == TransformKind.NonTransform: + return tensor + in_dtype = self.in_dtype + bit = self.bit + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit())) + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map( + trans=True, dtype=in_dtype, matrix_name=matrix_name) + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inversed_index_map.initial_indices + scaling_final_indices = inversed_index_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inversed_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + qr = r * bit // storage_nbit + + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return tensor[new_index] + + return te.compute( + (self.N, self.K // storage_nbit * bit), + fcompute, + name=f"{matrix_name}_reindex", + ) + + def _decode_func(self, B, LUT, Scale, Zeros, QZeros): + bit = self.bit + in_dtype = self.in_dtype + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + + # TODO: Move the decode function into a more general place + def decode(n, k): + w = None + if self.with_zeros and self.zeros_mode == "quantized": + qzeros_dequantize = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + QZeros[k, n // n_float_per_elem], + n % n_float_per_elem, + dtype=self.storage_dtype, + ) + w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + qzeros_dequantize, + dtype=in_dtype, + ) + elif self.source_format == "uint": + if bit == 8: + w = B[n, k].astype(in_dtype) + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif self.source_format == "int": + if bit == 1: + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + if bit == 8: + w = B[n, k].astype(in_dtype) + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif self.source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif self.source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) + elif self.source_format == "nf": + index = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", + ) + w = LUT[index] + else: + raise ValueError(f"Unsupported source_format: {self.source_format}") + + assert w is not None, "w is None" + + group_size = self.group_size + zeros_mode = self.zeros_mode + + if not self.with_scaling: + return w + + if not self.with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + elif zeros_mode == "quantized": + w = w * Scale[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + return te.compute((self.N, self.K), decode, name="B_decode") + + def _compute_matmul(self, A, B_decode): + k = te.reduce_axis((0, self.K), name="k") + C = te.compute( + (self.M, self.N), + lambda i, j: te.sum( + A[i, k].astype(self.accum_dtype) * B_decode[j, k].astype(self.accum_dtype), axis=k), + name="C", + ) + return C + + def _convert_dtype(self, tensor): + if self.accum_dtype != self.out_dtype: + return te.compute((self.M, self.N), + lambda i, j: tensor[i, j].astype(self.out_dtype), + name="D") + return tensor + + def _apply_bias(self, tensor, Bias): + if self.with_bias: + return te.compute((self.M, self.N), lambda i, j: tensor[i, j] + Bias[j], name="E") + return tensor + + def emit(self): + A, B, LUT, Scale, Zeros, QZeros, Bias = self._create_placeholders() + A_reindex = self._propagate_input(A, self.propagate_a, "A") + B_reindex = self._propagage_weight(B, self.propagate_b, "B") + + B_decode = self._decode_func(B_reindex, LUT, Scale, Zeros, QZeros) + C = self._compute_matmul(A_reindex, B_decode) + D = self._convert_dtype(C) + last_output = self._apply_bias(D, Bias) + + args = [A, B] + if self.source_format == "nf": + args.append(LUT) + if self.with_scaling: + args.append(Scale) + if self.with_zeros: + args.append(QZeros if self.zeros_mode == "quantized" else Zeros) + if self.with_bias: + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": self.fast_decoding, + "source_format": { + "bits": self.bit, + "format": self.source_format, + }, + "storage_dtype": self.storage_dtype, + "target_format": self.in_dtype, + "with_zeros": self.with_zeros, + "zeros_mode": self.zeros_mode, + "with_scaling": self.with_scaling, + "group_size": self.group_size, + } + }, + ) + if self.propagate_a: + func = func.with_attr("input_transform_kind", self.propagate_a.value) + if self.propagate_b: + func = func.with_attr("weight_transform_kind", self.propagate_b.value) + return tvm.IRModule.from_expr(func) + + +def matmul_nt_dequantize_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N, K // storage_nbit * bit), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) + QZeros = te.placeholder(((K // group_size), N // storage_nbit * bit), + name="QZeros", + dtype=storage_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def qzeros_dequantize(k, n): + return _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + QZeros[k, n // n_float_per_elem], + n % n_float_per_elem, + dtype=storage_dtype, + ) + + Dequantize_qzeros = None + if with_zeros and zeros_mode == "quantized": + Dequantize_qzeros = te.compute( + (K // group_size, N), + qzeros_dequantize, + name="Dequantize_zeros", + ) + + def decode_func(n, k): + if with_zeros and zeros_mode == "quantized": + assert Dequantize_qzeros is not None, "Dequantize_zeros is None" + w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + Dequantize_qzeros[k // group_size, n], + dtype=in_dtype, + ) + elif source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "int": + if bit == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B[n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + elif zeros_mode == "quantized": + w = w * Scale[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), + name="C", + ) + + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + args = [A, B] + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_zeros: + if zeros_mode == "quantized": + args.append(QZeros) + else: + args.append(Zeros) + if with_bias: + last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_scaling": with_scaling, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "group_size": group_size, + } + }, + ) + return tvm.IRModule.from_expr(func) + + +def matmul_nt_dequantize_b_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + transform_kind: TransformKind = TransformKind.IntraWarpTransform, +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + _, inverse_indexmap = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inverse_indexmap.initial_indices + scaling_final_indices = inverse_indexmap.map_indices(initial_indices[:-1] + + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inverse_indexmap = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + qr = r * bit // storage_nbit + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inverse_indexmap.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K // storage_nbit * bit), + fcompute, + name="B_reindex", + ) + + def decode_func(n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "int": + if bit == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + args = [A, B] + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_zeros: + args.append(Zeros) + if with_bias: + last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "with_scaling": with_scaling, + "group_size": group_size, + } + }, + ) + func = func.with_attr("weight_transform_kind", transform_kind.value) + return tvm.IRModule.from_expr(func) + + +def matmul_nt_dequantize_b_propagate_a_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + transform_kind_input: TransformKind = TransformKind.IntraWarpTransform, + transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform, +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") + A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind_input >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return A[new_index] + + A_reindex = te.compute( + (M, K), + fcompute, + name="A_reindex", + ) + + _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inversed_index_map.initial_indices + scaling_final_indices = inversed_index_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inversed_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + qr = r * bit // storage_nbit + B = te.placeholder((N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind_weight >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K // storage_nbit * bit), + fcompute, + name="B_reindex", + ) + + def decode_func(n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "int": + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. + if bit == 1: + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A_reindex[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), + axis=k, + ), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + args = [A, B] + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_zeros: + args.append(Zeros) + if with_bias: + last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "with_scaling": with_scaling, + "group_size": group_size, + } + }, + ) + func = func.with_attr("input_transform_kind", transform_kind_input.value) + func = func.with_attr("weight_transform_kind", transform_kind_weight.value) + return tvm.IRModule.from_expr(func) + +# Should be refactored with Emitter +def select_implementation( + M=None, + N=1024, + K=1024, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + layout="nt", + zeros_mode="original", + propagate_a=False, + propagate_b=False, +): + if layout == "nn": + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn in Dequantize Implementation" + ) + elif layout == "nt": + if propagate_a and propagate_b: + return matmul_nt_dequantize_b_propagate_a_propagate_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + transform_kind_input=propagate_a, + transform_kind_weight=propagate_b, + ) + elif propagate_a: + raise NotImplementedError + elif propagate_b: + return matmul_nt_dequantize_b_propagate_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + transform_kind=propagate_b, + ) + else: + return matmul_nt_dequantize_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + ) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/bitblas/ops/general_matmul/tirscript/matmul_impl.py b/bitblas/ops/general_matmul/tirscript/matmul_impl.py new file mode 100644 index 000000000..b093f0d9c --- /dev/null +++ b/bitblas/ops/general_matmul/tirscript/matmul_impl.py @@ -0,0 +1,356 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +from bitblas import tvm +from tvm import te +from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.ops.operator import TransformKind + + +def matmul_nn( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + if not isinstance(M, int): + M = tvm.te.var("m") + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((K, N), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum(A[i, k].astype(accum_dtype) * B[k, j].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def matmul_nt( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + if not isinstance(M, int): + M = tvm.te.var("m") + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N, K), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum(A[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def matmul( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", +): + if layout == "nn": + return matmul_nn(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) + return matmul_nt(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) + + +def matmul_nt_propagate_a( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + transform_kind: TransformKind = TransformKind.IntraWarpTransform, +): + if not isinstance(M, int): + M = tvm.te.var("m") + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") + + A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) + B = te.placeholder((N, K), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return A[new_index] + + A_reindex = te.compute( + (M, K), + fcompute, + name="A_reindex", + ) + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A_reindex[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + func = func.with_attr("input_transform_kind", transform_kind.value) + + return tvm.IRModule.from_expr(func) + + +def matmul_nt_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + transform_kind: TransformKind = TransformKind.IntraWarpTransform, +): + if not isinstance(M, int): + M = tvm.te.var("m") + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K), + fcompute, + name="B_reindex", + ) + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + func = func.with_attr("weight_transform_kind", transform_kind.value) + + return tvm.IRModule.from_expr(func) + + +def matmul_nt_propagate_a_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + transform_kind_input: TransformKind = TransformKind.IntraWarpTransform, + transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform, +): + if not isinstance(M, int): + M = tvm.te.var("m") + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) + B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind_input >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return A[new_index] + + A_reindex = te.compute( + (M, K), + fcompute, + name="A_reindex", + ) + + _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind_weight >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K), + fcompute, + name="B_reindex", + ) + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A_reindex[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype), + axis=k, + ), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + func = func.with_attr("input_transform_kind", transform_kind_input.value) + func = func.with_attr("weight_transform_kind", transform_kind_weight.value) + + return tvm.IRModule.from_expr(func) + + +def select_implementation( + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform, +): + if layout == "nn": + if propagate_a or propagate_b: + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn") + return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + elif layout == "nt": + if propagate_a and propagate_b: + return matmul_nt_propagate_a_propagate_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + transform_kind_input=propagate_a, + transform_kind_weight=propagate_b, + ) + elif propagate_a: + return matmul_nt_propagate_a( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + transform_kind=propagate_a, + ) + elif propagate_b: + return matmul_nt_propagate_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + transform_kind=propagate_b, + ) + else: + return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/bitblas/ops/ladder_permutate.py b/bitblas/ops/ladder_permutate/__init__.py similarity index 96% rename from bitblas/ops/ladder_permutate.py rename to bitblas/ops/ladder_permutate/__init__.py index 70999b09d..6644705cd 100644 --- a/bitblas/ops/ladder_permutate.py +++ b/bitblas/ops/ladder_permutate/__init__.py @@ -2,8 +2,8 @@ # Licensed under the MIT License. from tvm.target import Target from typing import Literal, Union -from .operator import Operator -from .impl.ladder_permutate_impl import select_implementation +from ..operator import Operator +from .ladder_permutate_impl import select_implementation from dataclasses import dataclass diff --git a/bitblas/ops/ladder_permutate/ladder_permutate_impl.py b/bitblas/ops/ladder_permutate/ladder_permutate_impl.py new file mode 100644 index 000000000..76b5a01fb --- /dev/null +++ b/bitblas/ops/ladder_permutate/ladder_permutate_impl.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas.gpu.matmul_analysis import get_propagate_map +from typing import Literal +from tvm import te, IRModule, DataType +from tvm.tir import IndexMap + + +def select_implementation( + M: int, + N: int, + datatype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"] = "float16", + dequantize_bits: int = -1, + storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16", + propagate_kind: Literal["A", "B"] = "B", + transpose_matrix: bool = False, + transform_kind: int = 0, + target_instruction: Literal["nvidia-mma"] = "nvidia-mma", +): + if target_instruction != "nvidia-mma": + raise ValueError("Currently only support nvidia-mma instruction") + + # This is trick to get the basic tile size for the current datatype + # as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8 + l = r = 16 # noqa: E741 + if datatype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + intra_index_map, _ = get_propagate_map( + transpose_matrix, dtype=datatype, matrix_name=propagate_kind) + + target_dtype = DataType(datatype) + scaling_factor = 1 + if dequantize_bits > 0 and dequantize_bits < target_dtype.bits: + scaling_factor = ((target_dtype.bits // dequantize_bits) * DataType(storage_dtype).bits // + target_dtype.bits) + r = r // scaling_factor + initial_indices = intra_index_map.initial_indices + scaling_final_indices = intra_index_map.map_indices(initial_indices[:-1] + + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + intra_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + inp = te.placeholder((M, N // scaling_factor), name="inp", dtype=storage_dtype) + args = [inp] + + assert transform_kind != 0, "Permute only apply when transform_kind >= 1" + if transform_kind >= 1: + arg = args[-1] + + inter_warp = te.compute( + (M // l, (N // scaling_factor) // r, l, r), + lambda i, j, ii, jj: arg[i * l + ii, j * r + jj], + name="inter_warp_permutate", + ) + args.append(inter_warp) + if transform_kind >= 2: + arg = args[-1] + + def fcompute(*args): + warp_i, warp_j = args[-2:] + spatial_args = args[:-2] + permutate_i, permutate_j = intra_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, permutate_i, permutate_j) + return arg[new_index] + + intra_warp = te.compute( + (M // l, (N // scaling_factor) // r, l, r), + fcompute, + name="intra_warp_permutate", + ) + args.append(intra_warp) + args = [args[0], args[-1]] + + func = te.create_prim_func(args) + + return IRModule.from_expr(func) diff --git a/bitblas/ops/lop3_permutate.py b/bitblas/ops/lop3_permutate/__init__.py similarity index 95% rename from bitblas/ops/lop3_permutate.py rename to bitblas/ops/lop3_permutate/__init__.py index 867432a5e..7715be471 100644 --- a/bitblas/ops/lop3_permutate.py +++ b/bitblas/ops/lop3_permutate/__init__.py @@ -2,8 +2,8 @@ # Licensed under the MIT License. from tvm.target import Target from typing import Literal, Union -from .operator import Operator -from .impl.lop3_permutate_impl import select_implementation +from ..operator import Operator +from .lop3_permutate_impl import select_implementation from dataclasses import dataclass import torch diff --git a/bitblas/ops/lop3_permutate/lop3_permutate_impl.py b/bitblas/ops/lop3_permutate/lop3_permutate_impl.py new file mode 100644 index 000000000..07d8f4f0c --- /dev/null +++ b/bitblas/ops/lop3_permutate/lop3_permutate_impl.py @@ -0,0 +1,152 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Literal +from tvm import DataType +from tvm import IRModule +from tvm.ir import GlobalVar +from tvm.script import tir as T + + +# fmt: off +# TIR interleave weight impl-> 2D implementation +def tir_interleave_weight( + N: int = 2, + K: int = 16, + bits: int = 4, + QK: int = -1, + target_dtype: str = "float16", + storage_dtype: str = "int32", +): + if QK == -1: + QK = K * bits // 32 + bits_stride = DataType(target_dtype).bits + mask = (1 << bits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // bits + + @T.prim_func + def interleave_weight(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), storage_dtype)): + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) + + @T.prim_func + def interleave_weight_f16_2b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), + storage_dtype)): + B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xFF0000FF) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x00FF0000)) << 8) >> 16 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x0000FF00)) << 16) >> 8 + B[v0, v1] = B_tmp_1[v0, v1] | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] + + @T.prim_func + def interleave_weight_f16_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), + storage_dtype)): + B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_4 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_5 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_6 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_7 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xF000000F) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x000000F0)) >> 4) << 8 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x00000F00)) >> 8) << 16 + B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24 + B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 8 + B_tmp_6[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 20) << 12 + B_tmp_7[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 24) << 20 + B[v0, v1] = ( + B_tmp_1[v0, v1] + | B_tmp_2[v0, v1] + | B_tmp_3[v0, v1] + | B_tmp_4[v0, v1] + | B_tmp_5[v0, v1] + | B_tmp_6[v0, v1] + | B_tmp_7[v0, v1]) + + @T.prim_func + def interleave_weight_int8_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), + storage_dtype)): + B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_4 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_5 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xF0F00F0F) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x000000F0)) >> 4) << 16 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24 + B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 4 + B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x0F000000)) >> 24) << 12 + B[v0, v1] = ( + B_tmp_1[v0, v1] + | B_tmp_2[v0, v1] + | B_tmp_3[v0, v1] + | B_tmp_4[v0, v1] + | B_tmp_5[v0, v1]) + + if target_dtype == "float16" and bits == 2: + return interleave_weight_f16_2b + elif target_dtype == "float16" and bits == 1: + return interleave_weight_f16_1b + elif target_dtype == "int8" and bits == 1: + return interleave_weight_int8_1b + + return interleave_weight + + +# fmt: on + + +def select_implementation( + M: int, + N: int, + datatype: Literal["float16", "int8"] = "float16", + storage_dtype: Literal["int8", "uint8", "int32", "uint32"] = "int32", + dequantize_bits: int = 4, +): + func = tir_interleave_weight( + N=M, + K=N, + bits=dequantize_bits, + target_dtype=datatype, + storage_dtype=storage_dtype, + ) + mod = IRModule() + mod.update_func(GlobalVar("main"), func) + return mod diff --git a/bitblas/ops/matmul.py b/bitblas/ops/matmul.py index 34014abb9..7334906c8 100644 --- a/bitblas/ops/matmul.py +++ b/bitblas/ops/matmul.py @@ -148,7 +148,7 @@ def __init__( input_executors = TransformExecutorCPU() if self.ladder_permutate_a is not None: - input_executors.append(self.ladder_permutate_b) + input_executors.append(self.ladder_permutate_a) self.input_executors = input_executors From 7c7d73eeda9314c732916dd1c9e0cdbf00af4d25 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Jul 2024 08:50:28 +0000 Subject: [PATCH 22/88] Refactor backend class and fix typos in comments --- bitblas/codegen/base.py | 9 ++++++--- bitblas/codegen/tir.py | 2 -- bitblas/ops/general_matmul/tirscript/__init__.py | 4 ++-- .../general_matmul/tirscript/matmul_dequantize_impl.py | 1 + 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/bitblas/codegen/base.py b/bitblas/codegen/base.py index 312091f94..974f631be 100644 --- a/bitblas/codegen/base.py +++ b/bitblas/codegen/base.py @@ -3,15 +3,18 @@ from bitblas.ops.operator import OperatorConfig from abc import ABC, abstractmethod + class Backend(ABC): """ input: OperatorConfig - The duty of bakend: + The duty of backend: - is OperatorConfig is Available For our Backend - Generate CUDA Source for compilation """ - def __init__(self, config): + + def __init__(self, config: OperatorConfig): self.config = config + @abstractmethod def compile(self, config): pass @@ -22,4 +25,4 @@ def execute(self, *args, **kwargs): @abstractmethod def optimize(self, *args, **kwargs): - pass \ No newline at end of file + pass diff --git a/bitblas/codegen/tir.py b/bitblas/codegen/tir.py index ff5b0b6be..59e481eb9 100644 --- a/bitblas/codegen/tir.py +++ b/bitblas/codegen/tir.py @@ -1,4 +1,2 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - -class \ No newline at end of file diff --git a/bitblas/ops/general_matmul/tirscript/__init__.py b/bitblas/ops/general_matmul/tirscript/__init__.py index ae3038239..f783e05b3 100644 --- a/bitblas/ops/general_matmul/tirscript/__init__.py +++ b/bitblas/ops/general_matmul/tirscript/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .matmul_dequantize_impl import select_implementation as matmul_dequantize_select_implementation # noqa: F401 -from .matmul_impl import select_implementation as matmul_select_implementation # noqa: F401 +from .matmul_dequantize_impl import select_implementation as matmul_dequantize_select_implementation # noqa: F401 +from .matmul_impl import select_implementation as matmul_select_implementation # noqa: F401 diff --git a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py index 82c83be90..65f1c75e5 100644 --- a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py +++ b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py @@ -878,6 +878,7 @@ def decode_func(n, k): func = func.with_attr("weight_transform_kind", transform_kind_weight.value) return tvm.IRModule.from_expr(func) + # Should be refactored with Emitter def select_implementation( M=None, From 47d5fc5125953c6442ad316d01e4c048e54094ae Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 8 Jul 2024 17:26:44 +0000 Subject: [PATCH 23/88] Deep refactor Lib related code. --- bitblas/builder/lib_generator/__init__.py | 72 +++ bitblas/builder/wrapper.py | 0 .../tir.py => builder/wrapper/__init__.py} | 1 + bitblas/builder/wrapper/base.py | 12 + bitblas/builder/wrapper/tir.py | 413 ++++++++++++++++++ bitblas/cache/operator.py | 44 +- bitblas/codegen/base.py | 28 -- bitblas/ops/general_matmul/__init__.py | 24 +- bitblas/ops/general_matmul/cuda/__init__.py | 88 ++++ bitblas/ops/general_matmul/cuda/template.py | 2 + .../{backend/tir.py => tilelang/__init__.py} | 2 + bitblas/ops/matmul.py | 11 - bitblas/ops/matmul_dequantize.py | 13 +- bitblas/ops/operator.py | 178 ++++---- bitblas/utils/__init__.py | 1 + bitblas/utils/rtmod_analysis.py | 92 ++++ bitblas/wrapper/general.py | 20 +- .../builder/test_backend_tir_builder.py | 51 +++ 18 files changed, 883 insertions(+), 169 deletions(-) create mode 100644 bitblas/builder/lib_generator/__init__.py create mode 100644 bitblas/builder/wrapper.py rename bitblas/{codegen/tir.py => builder/wrapper/__init__.py} (72%) create mode 100644 bitblas/builder/wrapper/base.py create mode 100644 bitblas/builder/wrapper/tir.py delete mode 100644 bitblas/codegen/base.py create mode 100644 bitblas/ops/general_matmul/cuda/__init__.py create mode 100644 bitblas/ops/general_matmul/cuda/template.py rename bitblas/ops/general_matmul/{backend/tir.py => tilelang/__init__.py} (70%) create mode 100644 bitblas/utils/rtmod_analysis.py create mode 100644 testing/python/builder/test_backend_tir_builder.py diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py new file mode 100644 index 000000000..60b2ad6f9 --- /dev/null +++ b/bitblas/builder/lib_generator/__init__.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Optional +from bitblas import TileDevice +import ctypes +import os +import tempfile +import subprocess +import logging + +logger = logging.getLogger(__name__) + + +class LibraryGenerator(object): + srcpath: Optional[str] = None + libpath: Optional[str] = None + lib_code: Optional[str] = None + + def __init__(self, arch: TileDevice): + self.arch = arch + + def update_lib_code(self, lib_code: str): + self.lib_code = lib_code + + # Assume currently we only support CUDA compilation + def load_lib(self): + return ctypes.CDLL(self.libpath) + + def compile_lib(self, timeout: float = None): + arch = self.arch + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) + compute_version = arch.compute_capability + libpath = src.name.replace(".cu", ".so") + + command = [ + "nvcc", + "-std=c++17", + "-Xcudafe", + "--diag_suppress=177", + "--compiler-options", + "'-fPIC'", + "-lineinfo", + "--shared", + src.name, + "-lcuda", + f"-gencode=arch=compute_{compute_version},code=compute_{compute_version}", + "-o", + libpath, + ] + src.write(self.lib_code) + src.flush() + try: + ret = subprocess.run(command, timeout=timeout) + except subprocess.TimeoutExpired: + logger.warning(f"Compilation Timeout! {command}") + return None + if ret.returncode != 0: + logger.warning(f"Compilation Failed! {command}") + return None + self.srcpath = src.name + self.libpath = libpath + + def remove_lib(self): + if self.libpath: + os.remove(self.libpath) + self.libpath = None + + def get_source_path(self): + return self.srcpath + + def get_lib_path(self): + return self.libpath diff --git a/bitblas/builder/wrapper.py b/bitblas/builder/wrapper.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitblas/codegen/tir.py b/bitblas/builder/wrapper/__init__.py similarity index 72% rename from bitblas/codegen/tir.py rename to bitblas/builder/wrapper/__init__.py index 59e481eb9..316a80ecc 100644 --- a/bitblas/codegen/tir.py +++ b/bitblas/builder/wrapper/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from .tir import TIRWrapper diff --git a/bitblas/builder/wrapper/base.py b/bitblas/builder/wrapper/base.py new file mode 100644 index 000000000..200285f33 --- /dev/null +++ b/bitblas/builder/wrapper/base.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from abc import ABC, abstractmethod + + +class BaseWrapper(ABC): + def __init__(self): + pass + + @abstractmethod + def wrap(self, *args, **kwargs): + raise NotImplementedError diff --git a/bitblas/builder/wrapper/tir.py b/bitblas/builder/wrapper/tir.py new file mode 100644 index 000000000..f8e0ad1b1 --- /dev/null +++ b/bitblas/builder/wrapper/tir.py @@ -0,0 +1,413 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas import tvm +from typing import Optional, List, Dict, Union +from tvm import IRModule +from bitblas import TileDevice +from bitblas.utils import match_global_kernel +from bitblas.utils.rtmod_analysis import get_annotated_device_mod +import re +from .base import BaseWrapper +from abc import ABC, abstractmethod +from bitblas import tvm +from tvm import IRModule +from tvm.target import Target +from tvm.tir import PrimFunc +from tvm.contrib.dlpack import to_pytorch_func +from tvm._ffi.base import _LIB, raise_last_ffi_error +from tvm._ffi._ctypes.types import TVMValue, ArgTypeCode +from typing import List, Dict, Optional +import logging + +logger = logging.getLogger(__name__) + + +import logging +logger = logging.getLogger(__name__) + +class TIRCUDASourceWrapper(object): + _TYPE_MAP = { + "float32": "float", + "float16": "half", + "bfloat16": "__nv_bfloat162", + "e4m3_float8": "__nv_fp8_e4m3", + "e5m2_float8": "__nv_fp8_e5m2", + "float64": "double", + "int64": "int64_t", + "int32": "int", + "uint32": "unsigned int", + "bool": "int8_t", + "int8": "int8_t", + "uint8": "uint8_t", + "int16": "int16_t", + "uchar": "uint8_t", + } + + def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): + self.mod = optimized_mod + self.arch = arch + self.source = source + self.function_name: Optional[str] = None + self.dynamic_smem_buf: Optional[int] = None + self.block_info: Union[List[int], Dict] = [1, 1, 1] + self.grid_info: Union[List[int], Dict] = [1, 1, 1] + self.parse_source_information() + self.srcpath: Optional[str] = None + self.libpath: Optional[str] = None + self.lib_code: Optional[str] = self.update_lib_code(source) + + def parse_source_information(self): + device_mod = get_annotated_device_mod(self.mod, self.arch.target) + assert (len(device_mod.functions) == 1 + ), "Only support one function in the module for static shape kernel." + for g_var, func in device_mod.functions.items(): + self.function_name = g_var.name_hint + attrs = func.attrs + if "dyn_shared_memory_buf" in attrs: + self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) + if "thread_extent" in attrs: + thread_extent = attrs["thread_extent"] + for tag, extent in thread_extent.items(): + if "threadIdx" in tag: + self.block_info["xyz".index(tag[-1])] = extent + elif "blockIdx" in tag: + self.grid_info["xyz".index(tag[-1])] = extent + + def get_dynamic_symbolic_set(self, prim_func): + # Determine the set of dynamic symbols used in the function + dynamic_symbolic_set = set() + for param in prim_func.params: + buffer = prim_func.buffer_map[param] + for dim in buffer.shape: + if isinstance(dim, tvm.tir.Var): + dynamic_symbolic_set.add(dim.name) + return dynamic_symbolic_set + + def get_cuda_init_func(self): + # Initialize an empty string for the CUDA function call + call_str = """""" + # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call + if self.dynamic_smem_buf is not None: + call_str = """ + cudaFuncSetAttribute({}, + cudaFuncAttributeMaxDynamicSharedMemorySize, {}); + """.format(self.function_name, self.dynamic_smem_buf) + # Format the initialization function using the call_str + init_funcs = """ + extern "C" void init() {{ + {} + }} + """.format(call_str) + return init_funcs + + def update_lib_code(self, code: str): + # Update the library code with the given code string + self.lib_code = code + # Find the index of the global kernel function in the code + index = match_global_kernel(code) + # Extract the declaration of the function starting from the found index + declaration = code[index:].split(";")[0] + + function_name = self.function_name + # Get the CUDA initialization function + init_func = self.get_cuda_init_func() + + # Locate the opening brace of the function to insert arguments + index = code.index("{", index) + function_args = [] + # Populate the function arguments from the primary function's parameters and buffers + for param in self.prim_func.params: + buffer = self.prim_func.buffer_map[param] + function_args.append({ + "name": buffer.name, + "type": self._TYPE_MAP[buffer.dtype] + "* __restrict__", + }) + + dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) + # Add dynamic symbolic parameters as integers to the function arguments + for dyn_sym in dynamic_symbolic_set: + function_args.append({"name": dyn_sym, "type": "int"}) + + function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) + # Format the function arguments for declaration + def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) + + def func_call_args(s, function_args): + # Extract the function call arguments matching the function definition + pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" + matches = re.findall(pattern, s) + call_args = [] + for match in matches: + for arg in function_args: + if arg["name"] == match: + call_args.append(match) + return call_args + + call_args = ", ".join(func_call_args(declaration, function_args)) + block_info, grid_info = self.block_info, self.grid_info + + def legalize_c(p): + # Convert TIR expressions to legal C expressions + # Directly convert to string since the special case handling + # does not alter the string representation for `tvm.tir.Var` and `IntImm`. + # Replace Python's floor division operator with C's division operator + if isinstance(p, tvm.tir.IntImm): + p = int(p) + return str(p).replace("//", "/") + + # Prepare the block and grid dimensions for the CUDA kernel launch + block_str = "dim3({}, {}, {})".format( + legalize_c(block_info[0]), + legalize_c(block_info[1]), + legalize_c(block_info[2]), + ) + grid_str = "dim3({}, {}, {})".format( + legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2])) + # Determine the shared memory size, defaulting to 0 if not specified + smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf + # Format the CUDA kernel launch string + if len(dynamic_symbolic_set) != 0: + call_str = "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0]) + else: + call_str = "" + call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str, + smem_str, call_args) + # Create the host function wrapper for the CUDA kernel + host_func = """ + extern "C" void call({}) {{ + {} + }} + """.format(def_args, call_str) + # Combine the source, initialization function, and host function to form the complete library code + lib_code = self.source + init_func + host_func + return lib_code + + @property + def prim_func(self): + return self.mod["main"] + +class TIRCUDASourceWrapperWithDynamic(TIRCUDASourceWrapper): + + def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): + super().__init__(optimized_mod, source, arch) + + def get_cuda_init_func(self): + # Initialize an empty string to accumulate CUDA function calls for setting dynamic shared memory + call_str = """""" + # Iterate over functions and their dynamic shared memory requirements + for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): + if dynamic_smem_buf is not None: + # Format the cudaFuncSetAttribute call for dynamic shared memory + call_str += """ + cudaFuncSetAttribute({}, + cudaFuncAttributeMaxDynamicSharedMemorySize, {}); + """.format(function_name, dynamic_smem_buf) + # Define the init function that will set the attributes for each kernel + init_funcs = """ +extern "C" void init() {{ + {} +}} + """.format(call_str) + return init_funcs + + def create_dispatch_func(self, code, function_informations): + # Extract the set of dynamic symbolic names used in the primary function + dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) + + # Find the location of the global kernel function in the code + index = match_global_kernel(code) + + # Analyze the function declaration to prepare for argument extraction + dummy_declaration = code[index:].split(";")[0] + + function_name = self.function_name + + # Identify the start of the function body to insert arguments + index = code.index("{", index) + function_args = [] + # Collect function arguments based on primary function's parameters and buffer mappings + for param in self.prim_func.params: + buffer = self.prim_func.buffer_map[param] + function_args.append({ + "name": buffer.name, + "type": self._TYPE_MAP[buffer.dtype] + "* __restrict__", + }) + # Add dynamic symbols as integer arguments + for dyn_sym in dynamic_symbolic_set: + function_args.append({"name": dyn_sym, "type": "int"}) + + function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) + + # Format the argument definitions for function declaration + def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) + + def func_call_args(s: str, function_args): + # Extract and clean the function call arguments to match the declaration + pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" + matches = re.findall(pattern, s) + call_args = [] + for match in matches: + match = re.sub(r"\d+", "", match) # Remove numbers + match = re.sub(r"_", "", match) # Remove underscores + for arg in function_args: + if arg["name"] == match: + call_args.append(match) + return call_args + + call_args = ", ".join(func_call_args(dummy_declaration, function_args)) + + def legalize_c(p): + # Convert TIR expressions to legal C expressions + # Directly convert to string since the special case handling + # does not alter the string representation for `tvm.tir.Var` and `IntImm`. + # Replace Python's floor division operator with C's division operator + if isinstance(p, tvm.tir.IntImm): + p = int(p) + return str(p).replace("//", "/") + + last_range = 0 + num_items = len(function_informations) + _call_str = """""" + for function_name, info in function_informations.items(): + # Prepare block and grid configurations for kernel launches + block_info, grid_info = info["block_info"], info["grid_info"] + block_str = "dim3({}, {}, {})".format( + legalize_c(block_info[0]), + legalize_c(block_info[1]), + legalize_c(block_info[2]), + ) + grid_str = "dim3({}, {}, {})".format( + legalize_c(grid_info[0]), + legalize_c(grid_info[1]), + legalize_c(grid_info[2]), + ) + # Handle dynamic shared memory specification + smem_str = (0 if info["dynamic_smem_buf"] is None else info["dynamic_smem_buf"]) + opt_shapes = info["opt_shapes"] + # Generate conditional kernel launch code based on dynamic symbolic ranges + (symbolic,) = list(dynamic_symbolic_set) + range_str = opt_shapes[symbolic] + if last_range == 0: + call_str = "if ({} == 0) return; \n".format(symbolic,) + call_str += "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( + symbolic, + range_str, + function_name, + grid_str, + block_str, + smem_str, + call_args, + ) + else: + call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( + symbolic, + range_str, + function_name, + grid_str, + block_str, + smem_str, + call_args, + ) + if last_range == num_items - 1: + call_str += ( + "\t\telse {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( + function_name, grid_str, block_str, smem_str, call_args)) + last_range += 1 + _call_str += call_str + + # Wrap the kernel dispatch logic in an external C function + host_func = """ +extern "C" void call({}) {{ + {} +}} + """.format(def_args, _call_str) + return host_func + + def parse_source_information(self): + # Parse device module to extract execution configurations for each function + device_mod = get_annotated_device_mod(self.mod, self.arch.target) + block_info_map = {} + grid_info_map = {} + dynamic_smem_buf_map = {} + for g_var, func in device_mod.functions.items(): + # Default block and grid configurations + block_info = [1, 1, 1] + grid_info = [1, 1, 1] + function_name = g_var.name_hint + attrs = func.attrs + dynamic_smem_buf = None + if "dyn_shared_memory_buf" in attrs: + dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) + if "thread_extent" in attrs: + # Extract block and grid sizes from thread extents + thread_extent = attrs["thread_extent"] + for tag, extent in thread_extent.items(): + if "threadIdx" in tag: + block_info["xyz".index(tag[-1])] = extent + elif "blockIdx" in tag: + grid_info["xyz".index(tag[-1])] = extent + # Map the extracted configurations to each function + block_info_map[function_name] = block_info + grid_info_map[function_name] = grid_info + dynamic_smem_buf_map[function_name] = dynamic_smem_buf + # Store the mappings for use in code generation + self.block_info = block_info_map + self.grid_info = grid_info_map + self.dynamic_smem_buf = dynamic_smem_buf_map + + def update_lib_code(self, code: str): + # Organize function information for code generation + function_informations = {} + for g_var, func in self.mod.functions.items(): + if g_var.name_hint == "main": + continue + function_name = g_var.name_hint + attrs = func.attrs + assert "opt_shapes" in attrs + opt_shapes = attrs["opt_shapes"] + function_informations[function_name] = { + "function_name": function_name, + "opt_shapes": opt_shapes, + "block_info": self.block_info[function_name], + "grid_info": self.grid_info[function_name], + "dynamic_smem_buf": self.dynamic_smem_buf[function_name], + } + + def compare_map_objects(map_obj): + comparable_representation = list(map_obj.values()) + return comparable_representation + + function_informations = dict( + sorted( + function_informations.items(), + key=lambda item: compare_map_objects(item[1]["opt_shapes"]))) + + self.lib_code = code + + # Generate the initialization and dispatch functions + init_func = self.get_cuda_init_func() + host_func = self.create_dispatch_func(code, function_informations) + # Concatenate source code with generated code segments + lib_code = self.source + init_func + host_func + return lib_code + + @property + def prim_func(self): + return self.mod["main"] + +class TIRWrapper(BaseWrapper): + + def __init__(self, arch: TileDevice): + super().__init__() + self.optimized_mod = None + self.arch = arch + self.lib = None + + def assign_optimized_module(self, optimized_mod: IRModule): + self.optimized_mod = optimized_mod + + # Get Scheduled Rt Module and return source to be compiled + def wrap(self, c_source:str, is_dynamic: bool = False): + wrapper_class = TIRCUDASourceWrapper if not is_dynamic else TIRCUDASourceWrapperWithDynamic + wrapper = wrapper_class(self.optimized_mod, c_source, self.arch) + return wrapper.lib_code diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py index 0c41ab686..ac949711e 100644 --- a/bitblas/cache/operator.py +++ b/bitblas/cache/operator.py @@ -76,7 +76,11 @@ def _ensure_database_path(self, database_path): return database_path def _determine_arch_str(self, op_inst, target): - return (target if target else "-".join(list(op_inst.target.keys) + [op_inst.target.arch])) + return ( + target + if target + else "-".join(list(op_inst.target.keys) + [op_inst.target.arch]) + ) def _ensure_directory(self, path): os.makedirs(path, exist_ok=True) @@ -107,21 +111,25 @@ def _save_operator_config_and_artifact(self, config, op_inst, config_path): with open(optimized_file_path, "w") as optimized_file: if op_inst.optimized_func is not None: optimized_file.write(op_inst.optimized_func.script(show_meta=False)) - if op_inst.wrapper.lib_name is not None: + if op_inst.wrapper.libpath is not None: # copy lib name to the same directory as the artifact - src_name = op_inst.wrapper.src_name + srcpath = op_inst.wrapper.srcpath shutil.copy( - src_name, + srcpath, os.path.join(config_path, os.path.basename("wrapper_source.cu")), ) - lib_name = op_inst.wrapper.lib_name + libpath = op_inst.wrapper.libpath shutil.copy( - lib_name, + libpath, os.path.join(config_path, os.path.basename("wrapper_compiled.so")), ) def _determine_target_arch_str(self, target): - return (target if isinstance(target, str) else "-".join(list(target.keys) + [target.arch])) + return ( + target + if isinstance(target, str) + else "-".join(list(target.keys) + [target.arch]) + ) def _load_operators_from_arch_path(self, arch_path, target): for root, dirs, _ in os.walk(arch_path): @@ -130,7 +138,7 @@ def _load_operators_from_arch_path(self, arch_path, target): self._load_operator(config_path, target) def _load_operator(self, config_path, target): - mapping, config, rt_mod, src_name, lib_name = None, None, None, None, None + mapping, config, rt_mod, srcpath, libpath = None, None, None, None, None for file in os.listdir(config_path): full_path = os.path.join(config_path, file) if file == "mapping.json": @@ -142,19 +150,27 @@ def _load_operator(self, config_path, target): elif file.endswith(".tar"): rt_mod = tvm.runtime.load_module(full_path) elif file == "wrapper_compiled.so": - lib_name = full_path + libpath = full_path elif file == "wrapper_source.cu": - src_name = full_path + srcpath = full_path if mapping and config and rt_mod: - self._instantiate_and_add_operator(mapping, config, rt_mod, src_name, lib_name, target) + self._instantiate_and_add_operator( + mapping, config, rt_mod, srcpath, libpath, target + ) - def _instantiate_and_add_operator(self, mapping, config, rt_mod, src_name, lib_name, target): + def _instantiate_and_add_operator( + self, mapping, config, rt_mod, srcpath, libpath, target + ): config_cls = getattr(bitblas, mapping["config_type"]) operator_cls = getattr(bitblas, mapping["operator_type"]) op_inst = operator_cls( - config=config_cls(**config), target=target, enable_tuning=False, from_database=True) - op_inst.update_runtime_module(rt_mod, src_name=src_name, lib_name=lib_name) + config=config_cls(**config), + target=target, + enable_tuning=False, + from_database=True, + ) + op_inst.update_runtime_module(rt_mod, srcpath=srcpath, libpath=libpath) self.add(config_cls(**config), op_inst) diff --git a/bitblas/codegen/base.py b/bitblas/codegen/base.py deleted file mode 100644 index 974f631be..000000000 --- a/bitblas/codegen/base.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas.ops.operator import OperatorConfig -from abc import ABC, abstractmethod - - -class Backend(ABC): - """ - input: OperatorConfig - The duty of backend: - - is OperatorConfig is Available For our Backend - - Generate CUDA Source for compilation - """ - - def __init__(self, config: OperatorConfig): - self.config = config - - @abstractmethod - def compile(self, config): - pass - - @abstractmethod - def execute(self, *args, **kwargs): - pass - - @abstractmethod - def optimize(self, *args, **kwargs): - pass diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index cf16029dd..786b8043e 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -6,7 +6,7 @@ from functools import reduce from bitblas.base.arch.cuda import CUDA from typing import Any, Literal, Optional, Tuple, Union -from ..operator import Operator, TransformKind, OPExecutorCPU +from ..operator import OperatorConfig, Operator, TransformKind, OPExecutorCPU from .tirscript.matmul_dequantize_impl import select_implementation as weight_dequantize_implementation from .tirscript.matmul_impl import select_implementation as consistent_implementation from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 @@ -41,7 +41,7 @@ def is_native_compute(A_dtype, W_dtype) -> bool: @dataclass(frozen=True) -class MatmulConfig: +class MatmulConfig(OperatorConfig): M: Union[int, Tuple[int]] = None N: int = None K: int = None @@ -185,7 +185,7 @@ def __post_init__(self): class Matmul(Operator): - # TODO(lei): This should be improved into a general datatype. + # TODO(lei): This should be improved into a general datatype class. BITBLAS_TRICK_DTYPE_MAP = { "float64": ("fp", 64), "float32": ("fp", 32), @@ -281,6 +281,7 @@ def __init__( self.ladder_permutate_b = self._assign_ladder_permutate_b(target, enable_tuning) self.lop3_permutate = self._assign_lop3_permutate(target, enable_tuning) # create cpu weight executors + self.input_executors = self._create_input_executors() self.weight_executors = self._create_weight_executors() if enable_tuning: @@ -354,6 +355,12 @@ def _assign_lop3_permutate(self, target: Target, enable_tuning: bool): ) return None + def _create_input_executors(self): + input_executors = OPExecutorCPU() + if self.propagate_a is not TransformKind.NonTransform: + input_executors.append(self.ladder_permutate_a) + return input_executors + def _create_weight_executors(self): weight_executors = OPExecutorCPU() if self.fast_decoding: @@ -362,17 +369,6 @@ def _create_weight_executors(self): weight_executors.append(self.ladder_permutate_b) return weight_executors - def _build_default_module(self, target: Target): - try: - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) - except Exception: - self.optimized_func = None - logger.warning( - "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." - ) - - self._build_runtime_module(target) - def _select_implementation(self): if is_native_compute(self.A_dtype, self.W_dtype): return consistent_implementation( diff --git a/bitblas/ops/general_matmul/cuda/__init__.py b/bitblas/ops/general_matmul/cuda/__init__.py new file mode 100644 index 000000000..f3d1243bb --- /dev/null +++ b/bitblas/ops/general_matmul/cuda/__init__.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# TODO: Not Implemented Yet +from bitblas.ops.operator import TransformKind +from bitblas.base import TileDevice + +class MatmulDequantizeCudaEmitter: + + def __init__( + self, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform, + ): + self.N = N + self.K = K + self.in_dtype = in_dtype + self.out_dtype = out_dtype + self.accum_dtype = accum_dtype + self.bit = bit + self.storage_dtype = storage_dtype + self.source_format = source_format + self.with_scaling = with_scaling + self.with_zeros = with_zeros + self.group_size = group_size if group_size != -1 else K + self.fast_decoding = fast_decoding + self.with_bias = with_bias + self.zeros_mode = zeros_mode + self.propagate_a = self._legalize_transform_kind(propagate_a) + self.propagate_b = self._legalize_transform_kind(propagate_b) + + def _legalize_group_size(self): + if self.group_size == -1: + self.group_size = self.K + + def _legalize_transform_kind(self, propagate): + if propagate is None: + return TransformKind.NonTransform + if isinstance(propagate, bool): + return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform) + elif isinstance(propagate, int): + return TransformKind(propagate) + + def is_available(self, arch:TileDevice): + conditons = [] + # group size must be -1, 128, k + conditons.append(self.group_size in [-1, 128, self.K]) + # source format must be int + conditons.append(self.source_format == "int") + # with scaling must be true + conditons.append(self.with_scaling) + # with zeros must be false + conditons.append(not self.with_zeros) + # bit must be 4 + conditons.append(self.bit == 4) + # in_dtype must be float16 + conditons.append(self.in_dtype == "float16") + # out_dtype must be float16 + conditons.append(self.out_dtype == "float16") + # accum_dtype must be float32 + conditons.append(self.accum_dtype == "float32") + # sm version must be 80 (A100) + conditons.append(self.arch.sm_version == 80) + return all(conditons) + + def get_weight_transform(self): + raise NotImplementedError + + def get_scale_transform(self): + raise NotImplementedError + + def get_wrapped_source(self): + raise NotImplementedError \ No newline at end of file diff --git a/bitblas/ops/general_matmul/cuda/template.py b/bitblas/ops/general_matmul/cuda/template.py new file mode 100644 index 000000000..20cd28c45 --- /dev/null +++ b/bitblas/ops/general_matmul/cuda/template.py @@ -0,0 +1,2 @@ +template_source = """ +""" \ No newline at end of file diff --git a/bitblas/ops/general_matmul/backend/tir.py b/bitblas/ops/general_matmul/tilelang/__init__.py similarity index 70% rename from bitblas/ops/general_matmul/backend/tir.py rename to bitblas/ops/general_matmul/tilelang/__init__.py index 59e481eb9..197bd5e70 100644 --- a/bitblas/ops/general_matmul/backend/tir.py +++ b/bitblas/ops/general_matmul/tilelang/__init__.py @@ -1,2 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + +# TODO: Not Implemented Yet diff --git a/bitblas/ops/matmul.py b/bitblas/ops/matmul.py index 7334906c8..af0370294 100644 --- a/bitblas/ops/matmul.py +++ b/bitblas/ops/matmul.py @@ -161,17 +161,6 @@ def __init__( if enable_tuning: self.hardware_aware_finetune() - def _build_default_module(self, target: Target): - try: - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) - except Exception: - self.optimized_func = None - logger.warning( - "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." - ) - - self._build_runtime_module(target) - def _select_implementation(self): return select_implementation( M=self.M, diff --git a/bitblas/ops/matmul_dequantize.py b/bitblas/ops/matmul_dequantize.py index 7381b3f12..6971547b0 100644 --- a/bitblas/ops/matmul_dequantize.py +++ b/bitblas/ops/matmul_dequantize.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. from bitblas import tvm from tvm.target import Target -from bitblas.base.roller.arch.cuda import CUDA +from bitblas.base.arch.cuda import CUDA from typing import Any, List, Literal, Optional, Tuple, Union from .operator import Operator, TransformKind from .impl.matmul_dequantize_impl import select_implementation @@ -198,17 +198,6 @@ def __init__( if enable_tuning: self.hardware_aware_finetune() - def _build_default_module(self, target: Target): - try: - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) - except Exception: - self.optimized_func = None - logger.warning( - "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." - ) - - self._build_runtime_module(target) - def _select_implementation(self): return select_implementation( M=self.M, diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index f3d391778..f73603d93 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -6,17 +6,16 @@ from tvm.target import Target from tvm.tir import PrimFunc from tvm.contrib.dlpack import to_pytorch_func -from tvm._ffi.base import _LIB, raise_last_ffi_error -from tvm._ffi._ctypes.types import TVMValue, ArgTypeCode import bitblas import ctypes from typing import List, Dict, Any, Optional import numpy as np from ..base import fast_tune, fast_tune_with_dynamic_range from copy import deepcopy -from bitblas.base.roller.arch import get_arch +from bitblas.base.arch import get_arch from bitblas.utils.tensor_adapter import tvm_tensor_to_torch -from bitblas.wrapper import CUDASourceWrapper, CUDASourceWrapperWithDynamic +from bitblas.builder.wrapper import TIRWrapper +from bitblas.builder.lib_generator import LibraryGenerator from dataclasses import dataclass from enum import IntEnum import logging @@ -30,11 +29,11 @@ class TransformKind(IntEnum): IntraWarpTransform = 2 -@dataclass +@dataclass(frozen=True) class OperatorConfig: """Base class for operator configurations. Used for typing.""" - pass + backend: Optional[str] = "tir" # Avaliable backends: tir, cuda, tile-languange. class Operator(ABC): @@ -58,9 +57,8 @@ def __init__(self, name, config: OperatorConfig, target: Target = None): self.num_output_args: int = ( 1 # todo(lei): should be analyzed from the prim_func. ) - self.wrapper = None - self.src_name = None - self.lib_name = None + self.lib_generator = LibraryGenerator(self.arch) + self.wrapper = TIRWrapper(self.arch) self.lib = None def get_source(self, target: Target = None) -> str: @@ -101,12 +99,13 @@ def tvm_callback_cuda_postproc(code, _): try: # Use a specific TVM pass context for CUDA platforms - with tvm.transform.PassContext(config={ - "tir.use_async_copy": True, - **self.pass_context - }): - rt_mod = tvm.build(self.optimized_func, target=target, name=self.name) - except Exception: # noqa: F841 + with tvm.transform.PassContext( + config={"tir.use_async_copy": True, **self.pass_context} + ): + rt_mod = tvm.build( + self.optimized_func, target=target, name=self.name + ) + except Exception: # noqa: F841 logger.debug( "Failed to build optimized function for CUDA target with default schedule, Please consider enable hardware aware tuning!" ) @@ -119,27 +118,32 @@ def tvm_callback_cuda_postproc(code, _): self.rt_mod = rt_mod # Initialize a time evaluator with the built module, specifying the device and the number of runs self.time_evaluator = rt_mod.time_evaluator( - rt_mod.entry_name, self.arch.device, number=10) + rt_mod.entry_name, self.arch.device, number=10 + ) self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle self.torch_func = to_pytorch_func(rt_mod) if self.arch.platform == "CUDA": try: - if (self.dynamic_range is not None and len(self.optimized_func.functions) > 1): - wrapper = CUDASourceWrapperWithDynamic(self.optimized_func, - self.get_source(target), self.arch) - else: - wrapper = CUDASourceWrapper(self.optimized_func, self.get_source(target), - self.arch) - wrapper.compile_lib() - self.wrapper = wrapper - self.src_name = self.wrapper.src_name - self.lib_name = self.wrapper.lib_name - self.lib = self.wrapper.load_lib() + is_dynamic = ( + self.dynamic_range is not None + and len(self.optimized_func.functions) > 1 + ) + self.wrapper.assign_optimized_module(self.optimized_func) + wrapped_source = self.wrapper.wrap( + self.get_source(target), is_dynamic + ) + self.lib_generator.update_lib_code(wrapped_source) + self.lib_generator.compile_lib() + self.lib = self.lib_generator.load_lib() self.lib.init() + except Exception as e: build_runtime_library_error = e logger.debug( - "Failed to build runtime library {}".format(build_runtime_library_error)) + "Failed to build runtime library {}".format( + build_runtime_library_error + ) + ) return rt_mod @@ -153,20 +157,32 @@ def apply_default_schedule(self, func_mod: IRModule, target: Target) -> IRModule bitblas.gpu.Reduction(), bitblas.gpu.GeneralReduction(), bitblas.gpu.Fallback(), - )(mod_for_opt)) + )(mod_for_opt) + ) if optimized_mod is not None: return optimized_mod return None + def _build_default_module(self, target: Target): + try: + self.optimized_func = self.apply_default_schedule( + self.prim_func_mod, target + ) + except Exception: + self.optimized_func = None + logger.warning( + "[BitBLAS][Warning] Apply default schedule failed. Please perform hardware-aware tuning manually." + ) + + self._build_runtime_module(target) + def post_process(self, code: str) -> str: return code - def apply_fast_tuning(self, - func: PrimFunc, - target: Target, - topk: int = 20, - parallel_build=True) -> IRModule: + def apply_fast_tuning( + self, func: PrimFunc, target: Target, topk: int = 20, parallel_build=True + ) -> IRModule: _, best = fast_tune(func, target, topk=topk, parallel_build=parallel_build) if best is not None: return best.sch.mod @@ -181,25 +197,27 @@ def apply_fast_tuning_with_dynamic_range( dynamic_range: Dict[str, List[int]] = None, ): optimized_mod = fast_tune_with_dynamic_range( - func, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range) + func, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range + ) if optimized_mod is not None: return optimized_mod return None - def hardware_aware_finetune(self, - topk: int = 20, - target: tvm.target.Target = None, - parallel_build=True): + def hardware_aware_finetune( + self, topk: int = 20, target: tvm.target.Target = None, parallel_build=True + ): if target is None: target = self.target dynamic_range = self.dynamic_range func = self.prim_func if dynamic_range is not None: self.optimized_func = self.apply_fast_tuning_with_dynamic_range( - func, target, topk, dynamic_range) + func, target, topk, dynamic_range + ) else: self.optimized_func = self.apply_fast_tuning( - func, target, topk, parallel_build=parallel_build) + func, target, topk, parallel_build=parallel_build + ) self._build_runtime_module(self.target) def get_profile_tensors(self, dynamic_symbolic_constrains: Optional[Dict] = None): @@ -222,8 +240,8 @@ def var_warpper(v): def map_numpy_type(intype): typemap = { - 'e4m3_float8': 'float8_e4m3fn', - 'e5m2_float8': 'float8_e5m2', + "e4m3_float8": "float8_e4m3fn", + "e5m2_float8": "float8_e5m2", } if intype in typemap: return typemap[intype] @@ -239,14 +257,18 @@ def map_numpy_type(intype): numpy_dtype = map_numpy_type(arg.dtype) profile_tensors.append( tvm.nd.array( - np.random.uniform(0, 1, - [var_warpper(i) for i in arg.shape]).astype(numpy_dtype), + np.random.uniform(0, 1, [var_warpper(i) for i in arg.shape]).astype( + numpy_dtype + ), device=device, - )) + ) + ) self.profile_tensors = profile_tensors return profile_tensors - def profile_latency(self, dynamic_symbolic_constrains: Optional[Dict] = None) -> str: + def profile_latency( + self, dynamic_symbolic_constrains: Optional[Dict] = None + ) -> str: if dynamic_symbolic_constrains is None: dynamic_symbolic_constrains = {} profile_tensors = self.get_profile_tensors(dynamic_symbolic_constrains) @@ -266,16 +288,9 @@ def _tensor_adapter(self, tensor, device): else: raise RuntimeError("Not supported type: ", type(tensor)) - def _forward_from_tvm_args(self, *args): - _tvm_args = [self._tensor_adapter(arg, self.arch.device) for arg in args] - self.rt_mod(*_tvm_args) - - def _forward_from_tvm_nd_array(self, *args): - self.rt_mod(*args) - def _forward_from_torch_func(self, *args): - # torch func is not reliable as some datatypes they don't support - # like float8. + # Torch func is not reliable as the runtime overhead dlpack + # is not negaliable, ref to https://discuss.tvm.apache.org/t/strange-overhead-of-tvm-runtime-ndarray-from-dlpack/16516 self.torch_func(*args) return args[-1] @@ -284,7 +299,8 @@ def forward(self, *args): def _forward_from_prebuild_lib(self, *args, stream=0): ctypes_args = [ - ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args + ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr + for arr in args ] ctypes_args.append(ctypes.c_void_p(stream)) self.lib.call(*ctypes_args) @@ -292,38 +308,24 @@ def _forward_from_prebuild_lib(self, *args, stream=0): def call_lib(self, *args, stream=0): self.lib.call(*args, ctypes.c_void_p(stream)) - def _forward_from_tvm_lib_func(self, values): - tcodes = (ctypes.c_int * self.num_args)() - ret_val = TVMValue() - ret_tcode = ctypes.c_int() - for i in range(self.num_args): - tcodes[i] = ArgTypeCode.NDARRAY_HANDLE - if (_LIB.TVMFuncCall( - self.function_handle, - values, - tcodes, - ctypes.c_int(self.num_args), - ctypes.byref(ret_val), - ctypes.byref(ret_tcode), - ) != 0): - raise_last_ffi_error() - def __call__(self, *args: Any) -> Any: return self.forward(*args) def update_func(self, func: PrimFunc): self.prim_func_mod["main"] = func - def update_runtime_module(self, rt_mod, src_name=None, lib_name=None): + def update_runtime_module(self, rt_mod, srcpath=None, libpath=None): self.rt_mod = rt_mod - self.time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, self.arch.device, number=10) + self.time_evaluator = rt_mod.time_evaluator( + rt_mod.entry_name, self.arch.device, number=10 + ) self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle self.torch_func = to_pytorch_func(rt_mod) - if src_name is not None: - self.src_name = src_name - if lib_name is not None: - self.lib_name = lib_name - self.lib = ctypes.CDLL(lib_name) + if srcpath is not None: + self.srcpath = srcpath + if libpath is not None: + self.libpath = libpath + self.lib = ctypes.CDLL(libpath) self.lib.init() @abstractmethod @@ -334,6 +336,22 @@ def _select_implementation(self) -> IRModule: def prim_func(self): return self.prim_func_mod["main"] + @property + def backend(self): + return self.config.backend + + @property + def srcpath(self): + return self.lib_generator.get_source_path() + + @property + def libpath(self): + return self.lib_generator.get_lib_path() + + @property + def wrapped_source(self): + return self.lib_generator.lib_code + class OPExecutorCPU: """ diff --git a/bitblas/utils/__init__.py b/bitblas/utils/__init__.py index 00bddc2a5..bdf9589f7 100644 --- a/bitblas/utils/__init__.py +++ b/bitblas/utils/__init__.py @@ -3,3 +3,4 @@ from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 # noqa: F401 from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor # noqa: F401 from .target_detector import get_all_nvidia_targets, auto_detect_nvidia_target # noqa: F401 +from .rtmod_analysis import get_annotated_device_mod # noqa: F401 diff --git a/bitblas/utils/rtmod_analysis.py b/bitblas/utils/rtmod_analysis.py new file mode 100644 index 000000000..15ce1bd0c --- /dev/null +++ b/bitblas/utils/rtmod_analysis.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas import tvm +from tvm import IRModule +from tvm.runtime import ndarray +from tvm.driver import lower +from tvm.target import Target +from typing import Tuple, List + +def get_annotated_device_mod(mod: IRModule, target: Target) -> "IRModule": + """ + Lower the given IRModule and create a device module for the specified target. + + Parameters: + - mod: The input IRModule. + - target: The compilation target. + + Returns: + - A device module ready for execution. + """ + input_mod = lower(mod) + target_input_mod = {target: input_mod} + annotated_mods = {} + runtime = None + target_host = None + for tgt, mod in target_input_mod.items(): + if not isinstance(tgt, (str, Target)): + raise ValueError("The key of inputs must be str or " + "Target when inputs is dict.") + if not isinstance(mod, tvm.IRModule): + raise ValueError("inputs must be Schedule, IRModule, " + "or dict of str to IRModule.") + annotated_mods[tgt] = mod.with_attr("runtime", runtime) + annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) + if not target_host: + for tar, _ in annotated_mods.items(): + device_type = ndarray.device(tar.kind.name, 0).device_type + if device_type == ndarray.cpu(0).device_type: + target_host = tar + break + if not target_host: + target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" + annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) + for target, mod in annotated_mods.items(): + mixed_mod_passes = tvm.get_global_func("driver.mixed_mod_passes") + device_mod_passes = tvm.get_global_func("driver.device_mod_passes") + mod = mixed_mod_passes(mod, target)(mod) + device_mod = device_mod_passes(mod, target)(mod) + return device_mod + + +def get_thread_block_information(mod: IRModule) -> Tuple[List[int], List[int]]: + """ + Extracts the thread block and grid dimensions for the reduction block within a given IRModule. + + Parameters: + - mod: The input IRModule from which to extract thread block and grid information. + + Returns: + A tuple containing two lists: + - The first list contains the dimensions of the thread block (threadIdx.x, threadIdx.y, threadIdx.z). + - The second list contains the dimensions of the grid (blockIdx.x, blockIdx.y, blockIdx.z). + """ + + # Initialize the schedule from the IRModule + sch = tvm.tir.Schedule(mod) + + # Get the root block and its child blocks + root_block = sch.get_block("root") + child_blocks = sch.get_child_blocks(root_block) + + # Initialize default block and grid dimensions (1, 1, 1) + block_dims, grid_dims = [1, 1, 1], [1, 1, 1] + + for block in child_blocks: + # Get the loops surrounding the main block + loops = sch.get_loops(block) + + # Iterate over each loop to extract thread and block bindings + for loop in loops: + stmt = sch.get(loop) + thread_binding = stmt.thread_binding + extent = int(stmt.extent) + + # Skip loops without thread binding + if thread_binding: + if "threadIdx" in thread_binding.thread_tag: + block_dims["xyz".index(thread_binding.thread_tag[-1])] = extent + elif "blockIdx" in thread_binding.thread_tag: + grid_dims["xyz".index(thread_binding.thread_tag[-1])] = extent + + return block_dims, grid_dims diff --git a/bitblas/wrapper/general.py b/bitblas/wrapper/general.py index 1271329f1..aa76f6158 100644 --- a/bitblas/wrapper/general.py +++ b/bitblas/wrapper/general.py @@ -131,23 +131,23 @@ def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): self.block_info: Union[List[int], Dict] = [1, 1, 1] self.grid_info: Union[List[int], Dict] = [1, 1, 1] self.parse_source_information() - self.src_name: Optional[str] = None - self.lib_name: Optional[str] = None + self.srcpath: Optional[str] = None + self.libpath: Optional[str] = None self.lib_code: Optional[str] = self.update_lib_code(source) def load_lib(self): - return ctypes.CDLL(self.lib_name) + return ctypes.CDLL(self.libpath) def remove_lib(self): - if self.lib_name: - os.remove(self.lib_name) - self.lib_name = None + if self.libpath: + os.remove(self.libpath) + self.libpath = None def compile_lib(self, timeout: float = None): arch = self.arch src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) compute_version = arch.compute_capability - lib_name = src.name.replace(".cu", ".so") + libpath = src.name.replace(".cu", ".so") command = [ "nvcc", @@ -162,7 +162,7 @@ def compile_lib(self, timeout: float = None): "-lcuda", f"-gencode=arch=compute_{compute_version},code=compute_{compute_version}", "-o", - lib_name, + libpath, ] src.write(self.lib_code) src.flush() @@ -174,8 +174,8 @@ def compile_lib(self, timeout: float = None): if ret.returncode != 0: logger.warning(f"Compilation Failed! {command}") return None - self.src_name = src.name - self.lib_name = lib_name + self.srcpath = src.name + self.libpath = libpath def parse_source_information(self): device_mod = get_annotated_device_mod(self.mod, self.arch.target) diff --git a/testing/python/builder/test_backend_tir_builder.py b/testing/python/builder/test_backend_tir_builder.py new file mode 100644 index 000000000..4a9cecd8a --- /dev/null +++ b/testing/python/builder/test_backend_tir_builder.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import bitblas +from bitblas import MatmulConfig, Matmul +import logging +from bitblas import set_log_level +from bitblas.builder.backend.tir import TIRBackend +set_log_level(logging.DEBUG) + + +def get_codegen_result(ops): + code = ops.get_source() + return code + +def matmul_backend_code_wrap( + M, + N, + K, + A_dtype, + W_dtype, + accum_dtype, + out_dtype, + with_bias, +): + import torch + torch.random.manual_seed(0) + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + with_bias=with_bias, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False) + backend = TIRBackend(config=matmul_config, optimized_mod=matmul.optimized_func, arch=matmul.arch) + wrapped_code = backend.wrap(matmul.get_source(), is_dynamic=isinstance(M, list)) + assert "void call" in wrapped_code + +def test_matmul_transform_weight(): + matmul_backend_code_wrap(1, 768, 768, "float16", "uint4", "float16", "float16", False) + matmul_backend_code_wrap(768, 768, 768, "float16", "uint4", "float16", "float16", False) + matmul_backend_code_wrap([1, 768], 768, 768, "float16", "uint4", "float16", "float16", False) + +# fmt: on +if __name__ == "__main__": + bitblas.testing.main() + test_matmul_transform_weight() From 53dd0dd6436978d63a6e451af03bea7f1e73e629 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 10 Jul 2024 16:17:54 +0000 Subject: [PATCH 24/88] remove ci pull. --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 511b95833..8cf347e57 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,6 @@ name: CI -on: [push, pull_request] +on: [pull_request] jobs: format-check: From d58ac4349e0548d7146df2e1b8e2963c7d2bf9d1 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 10 Jul 2024 16:29:46 +0000 Subject: [PATCH 25/88] LintFix --- bitblas/builder/lib_generator/__init__.py | 4 +- bitblas/builder/wrapper/__init__.py | 2 +- bitblas/builder/wrapper/base.py | 2 - bitblas/builder/wrapper/tir.py | 18 +- bitblas/cache/operator.py | 20 +- bitblas/ops/general_matmul/__init__.py | 8 +- bitblas/ops/general_matmul/cuda/__init__.py | 46 +- bitblas/ops/general_matmul/cuda/template.py | 832 +++++++++++++++++- .../ops/general_matmul/tilelang/__init__.py | 2 +- bitblas/ops/operator.py | 86 +- bitblas/utils/rtmod_analysis.py | 1 + .../builder/test_backend_tir_builder.py | 7 +- 12 files changed, 915 insertions(+), 113 deletions(-) diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py index 60b2ad6f9..f50d2557a 100644 --- a/bitblas/builder/lib_generator/__init__.py +++ b/bitblas/builder/lib_generator/__init__.py @@ -18,7 +18,7 @@ class LibraryGenerator(object): def __init__(self, arch: TileDevice): self.arch = arch - + def update_lib_code(self, lib_code: str): self.lib_code = lib_code @@ -67,6 +67,6 @@ def remove_lib(self): def get_source_path(self): return self.srcpath - + def get_lib_path(self): return self.libpath diff --git a/bitblas/builder/wrapper/__init__.py b/bitblas/builder/wrapper/__init__.py index 316a80ecc..c864f7a4b 100644 --- a/bitblas/builder/wrapper/__init__.py +++ b/bitblas/builder/wrapper/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .tir import TIRWrapper +from .tir import TIRWrapper # noqa: F401 diff --git a/bitblas/builder/wrapper/base.py b/bitblas/builder/wrapper/base.py index 200285f33..1705af2cc 100644 --- a/bitblas/builder/wrapper/base.py +++ b/bitblas/builder/wrapper/base.py @@ -4,8 +4,6 @@ class BaseWrapper(ABC): - def __init__(self): - pass @abstractmethod def wrap(self, *args, **kwargs): diff --git a/bitblas/builder/wrapper/tir.py b/bitblas/builder/wrapper/tir.py index f8e0ad1b1..5978856a8 100644 --- a/bitblas/builder/wrapper/tir.py +++ b/bitblas/builder/wrapper/tir.py @@ -8,23 +8,11 @@ from bitblas.utils.rtmod_analysis import get_annotated_device_mod import re from .base import BaseWrapper -from abc import ABC, abstractmethod -from bitblas import tvm -from tvm import IRModule -from tvm.target import Target -from tvm.tir import PrimFunc -from tvm.contrib.dlpack import to_pytorch_func -from tvm._ffi.base import _LIB, raise_last_ffi_error -from tvm._ffi._ctypes.types import TVMValue, ArgTypeCode -from typing import List, Dict, Optional import logging logger = logging.getLogger(__name__) -import logging -logger = logging.getLogger(__name__) - class TIRCUDASourceWrapper(object): _TYPE_MAP = { "float32": "float", @@ -186,6 +174,7 @@ def legalize_c(p): def prim_func(self): return self.mod["main"] + class TIRCUDASourceWrapperWithDynamic(TIRCUDASourceWrapper): def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): @@ -395,8 +384,9 @@ def compare_map_objects(map_obj): def prim_func(self): return self.mod["main"] + class TIRWrapper(BaseWrapper): - + def __init__(self, arch: TileDevice): super().__init__() self.optimized_mod = None @@ -407,7 +397,7 @@ def assign_optimized_module(self, optimized_mod: IRModule): self.optimized_mod = optimized_mod # Get Scheduled Rt Module and return source to be compiled - def wrap(self, c_source:str, is_dynamic: bool = False): + def wrap(self, c_source: str, is_dynamic: bool = False): wrapper_class = TIRCUDASourceWrapper if not is_dynamic else TIRCUDASourceWrapperWithDynamic wrapper = wrapper_class(self.optimized_mod, c_source, self.arch) return wrapper.lib_code diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py index ac949711e..295630f5d 100644 --- a/bitblas/cache/operator.py +++ b/bitblas/cache/operator.py @@ -76,11 +76,7 @@ def _ensure_database_path(self, database_path): return database_path def _determine_arch_str(self, op_inst, target): - return ( - target - if target - else "-".join(list(op_inst.target.keys) + [op_inst.target.arch]) - ) + return (target if target else "-".join(list(op_inst.target.keys) + [op_inst.target.arch])) def _ensure_directory(self, path): os.makedirs(path, exist_ok=True) @@ -125,11 +121,7 @@ def _save_operator_config_and_artifact(self, config, op_inst, config_path): ) def _determine_target_arch_str(self, target): - return ( - target - if isinstance(target, str) - else "-".join(list(target.keys) + [target.arch]) - ) + return (target if isinstance(target, str) else "-".join(list(target.keys) + [target.arch])) def _load_operators_from_arch_path(self, arch_path, target): for root, dirs, _ in os.walk(arch_path): @@ -155,13 +147,9 @@ def _load_operator(self, config_path, target): srcpath = full_path if mapping and config and rt_mod: - self._instantiate_and_add_operator( - mapping, config, rt_mod, srcpath, libpath, target - ) + self._instantiate_and_add_operator(mapping, config, rt_mod, srcpath, libpath, target) - def _instantiate_and_add_operator( - self, mapping, config, rt_mod, srcpath, libpath, target - ): + def _instantiate_and_add_operator(self, mapping, config, rt_mod, srcpath, libpath, target): config_cls = getattr(bitblas, mapping["config_type"]) operator_cls = getattr(bitblas, mapping["operator_type"]) op_inst = operator_cls( diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 786b8043e..5c3f6d2e6 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -215,6 +215,7 @@ def __init__( target: Optional[Union[str, Target]] = None, enable_tuning: bool = True, from_database: bool = False, + backend: str = "tir", ): # if from database, we should disable default schedule # to save compilation time @@ -227,6 +228,7 @@ def __init__( self.source_format = source_format self.bit = bit + self.backend = backend super().__init__(name, config, target) if source_format == "int" and self.with_zeros: @@ -238,6 +240,10 @@ def __init__( if target.kind.name != "cuda": raise ValueError("Currently only support cuda target") + self.dispatch_tir(target, from_database, source_format, enable_tuning) + + def dispatch_tir(self, target: Target, from_database: bool = False, source_format: str = "uint", enable_tuning: bool = True): + '''Dispatch the tir script implementation''' self.arch = CUDA(target) if isinstance(self.M, Tuple): @@ -289,7 +295,7 @@ def __init__( # output data type self.torch_output_dtype = getattr(torch, self.out_dtype) - + def _alloc_workspace(self): return torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() diff --git a/bitblas/ops/general_matmul/cuda/__init__.py b/bitblas/ops/general_matmul/cuda/__init__.py index f3d1243bb..a0366abd3 100644 --- a/bitblas/ops/general_matmul/cuda/__init__.py +++ b/bitblas/ops/general_matmul/cuda/__init__.py @@ -1,12 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# TODO: Not Implemented Yet +# TODO: Not Implemented Yet from bitblas.ops.operator import TransformKind from bitblas.base import TileDevice +from .template import i4_scale_template_source + class MatmulDequantizeCudaEmitter: - + def __init__( self, M, @@ -55,34 +57,42 @@ def _legalize_transform_kind(self, propagate): return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform) elif isinstance(propagate, int): return TransformKind(propagate) - - def is_available(self, arch:TileDevice): - conditons = [] + + def is_available(self, arch: TileDevice): + conditions = [] # group size must be -1, 128, k - conditons.append(self.group_size in [-1, 128, self.K]) + conditions.append(self.group_size in [-1, 128, self.K]) # source format must be int - conditons.append(self.source_format == "int") + conditions.append(self.source_format == "int") # with scaling must be true - conditons.append(self.with_scaling) + conditions.append(self.with_scaling) # with zeros must be false - conditons.append(not self.with_zeros) + conditions.append(not self.with_zeros) # bit must be 4 - conditons.append(self.bit == 4) + conditions.append(self.bit == 4) # in_dtype must be float16 - conditons.append(self.in_dtype == "float16") + conditions.append(self.in_dtype == "float16") # out_dtype must be float16 - conditons.append(self.out_dtype == "float16") + conditions.append(self.out_dtype == "float16") # accum_dtype must be float32 - conditons.append(self.accum_dtype == "float32") + conditions.append(self.accum_dtype == "float32") # sm version must be 80 (A100) - conditons.append(self.arch.sm_version == 80) - return all(conditons) - + conditions.append(self.arch.sm_version == 80) + return all(conditions) + def get_weight_transform(self): raise NotImplementedError def get_scale_transform(self): raise NotImplementedError - + def get_wrapped_source(self): - raise NotImplementedError \ No newline at end of file + wrapped_source = f""" + extern "C" void init() {{ + + }} + extern "C" void call(half* __restrict__ A, int8_t* __restrict__ B, half* __restrict__ Scale, half* __restrict__ C, int m, void* workspace, cudaStream_t stream=cudaStreamDefault) {{ + marlin_cuda(A, B, C, Scale, m, {self.N}, {self.K}, workspace, {self.group_size}, 0, -1, -1, 108, 16); + }} + """ + return i4_scale_template_source + wrapped_source diff --git a/bitblas/ops/general_matmul/cuda/template.py b/bitblas/ops/general_matmul/cuda/template.py index 20cd28c45..a088e12fa 100644 --- a/bitblas/ops/general_matmul/cuda/template.py +++ b/bitblas/ops/general_matmul/cuda/template.py @@ -1,2 +1,830 @@ -template_source = """ -""" \ No newline at end of file +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +i4_scale_template_source = """ +// Copyright 2018 The apache/tvm Authors. All Rights Reserved. +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +// Modifications Copyright (c) Microsoft. +// The code below is mostly copied from marlin_cuda in IST-DASLab/marlin. + +#ifndef MARLIN_CUDA_KERNEL_CUH +#define MARLIN_CUDA_KERNEL_CUH + + +#include +#include +#include +#include + + +constexpr int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core +// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { + return elems[i]; + } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that +// are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for +// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need +// for inputs A and outputs C. +__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .b64 p;\n" + " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" + " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" :: "n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]) + ); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem) + ); +} + +// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to +// automatically recognize it in all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut) + ); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. +// We mostly follow the strategy in the link below, with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2( + *reinterpret_cast(&lo), + *reinterpret_cast(&SUB) + ); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible globally. + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier. + asm volatile ("fence.acq_rel.gpu;\n"); + asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + + +template < + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const int stages, // number of stages for the async global->shared fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale +> +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple + // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs + // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as + // possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case + // where a stripe starts in the middle of group. + if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to top + + // We can easily implement parallel problem execution by just remapping indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for synchronization. + auto init_slice = [&] () { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time constant + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major + // layout in the former and in row-major in the latter case. + if (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than + // required for a certain tilesize or when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank + // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of + // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based + // on NSight-Compute) that each warp must also write a consecutive memory segment? + auto transform_a = [&] (int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory + // accesses are static, we simply precompute both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependencies between + // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&] () { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location. + auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i] + ); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) + cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&] () { + // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when + // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer. + auto fetch_to_registers = [&] (int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a + // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the + // compiler and correspondingly a noticeable drop in performance. + if (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&] (int k) { + // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + FragB frag_b0 = dequant(b_quant); + // If there are no groups, we can just scale the final output once and can avoid doing so for each weight. + if (group_blocks != -1) + scale(frag_b0, frag_s[k % 2][j], 0); + FragB frag_b1 = dequant(b_quant_shift); + if (group_blocks != -1) + scale(frag_b1, frag_s[k % 2][j], 1); + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n + // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&] () { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations, + // e.g., for two warps we write only once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over + // the results. As the striped partitioning minimizes the number of such reductions and our outputs are usually rather + // small, we perform this reduction serially in L2 cache. + auto global_reduce = [&] (bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step. + // To do this, we write out results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns, + // hence we also use async-copies even though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m + ); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float( + reinterpret_cast<__half*>(&c_red)[j] + ); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = __float2half( + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] + ); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step, + // the reduction above is performed in fragment layout. + auto write_result = [&] () { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final global write patterns + auto write = [&] (int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + if (group_blocks == -1) // for per-column quantization we finally apply the scale here + res = __hmul2(res, s[0]); + ((half2*) sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&] () { + #pragma unroll + for (int i = 0; i < stages - 1; i++) + fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are + // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) + break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most + // readable, other ways of writing the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before write-out + if (group_blocks == -1 && last) { + if (s_sh_wr_pred) + cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + thread_block_reduce(); + if (group_blocks == -1 && last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + if (slice_count > 1) { // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + + +// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per schedule allows some more +// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles. +const int THREADS = 256; +const int STAGES = 4; // 4 pipeline stages fit into shared memory +const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if ( \ + thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS \ + ) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + SHARED_MEM \ + ); \ + Marlin< \ + THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \ + ><<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, \ + prob_m, prob_n, prob_k, \ + locks \ + ); \ + } + +const int ERR_PROB_SHAPE = 1; +const int ERR_KERN_SHAPE = 2; + +int marlin_cuda( + const void* A, + const void* B, + void* C, + void* s, + int prob_m, + int prob_n, + int prob_k, + void* workspace, + int groupsize = -1, + int dev = 0, + cudaStream_t stream = 0, + int thread_k = -1, + int thread_n = -1, + int sms = -1, + int max_par = 16 +) { + int tot_m = prob_m; + int tot_m_blocks = ceildiv(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + if (thread_k == -1 || thread_n == -1) { + if (prob_m <= 16) { + // For small batchizes, better partitioning is slightly more important than better compute utilization + thread_k = 128; + thread_n = 128; + } else { + thread_k = 64; + thread_n = 256; + } + } + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0)) + return ERR_PROB_SHAPE; + if (prob_m == 0 || prob_n == 0 || prob_k == 0) + return 0; + + const int4* A_ptr = (const int4*) A; + const int4* B_ptr = (const int4*) B; + int4* C_ptr = (int4*) C; + const int4* s_ptr = (const int4*) s; + + int cols = prob_n / thread_n; + int* locks = (int*) workspace; + + int ret = 0; + for (int i = 0; i < tot_m_blocks; i += 4) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any padding + par = (16 * thread_m_blocks - pad) / 64; + if (par > max_par) + par = max_par; + prob_m = 64 * par; + i += 4 * (par - 1); + thread_m_blocks = 4; + } + + // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance) + // in our testing, however many more are, in principle, possible. + if (false) {} + CALL_IF(1, 8, 8, -1) + CALL_IF(1, 8, 8, 8) + CALL_IF(1, 16, 4, -1) + CALL_IF(1, 16, 4, 8) + CALL_IF(2, 16, 4, -1) + CALL_IF(2, 16, 4, 8) + CALL_IF(3, 16, 4, -1) + CALL_IF(3, 16, 4, 8) + CALL_IF(4, 16, 4, -1) + CALL_IF(4, 16, 4, 8) + else + ret = ERR_KERN_SHAPE; + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } + + return ret; +} + + +#endif +""" diff --git a/bitblas/ops/general_matmul/tilelang/__init__.py b/bitblas/ops/general_matmul/tilelang/__init__.py index 197bd5e70..92956855c 100644 --- a/bitblas/ops/general_matmul/tilelang/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# TODO: Not Implemented Yet +# TODO: Not Implemented Yet diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index f73603d93..2e9078727 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -32,8 +32,7 @@ class TransformKind(IntEnum): @dataclass(frozen=True) class OperatorConfig: """Base class for operator configurations. Used for typing.""" - - backend: Optional[str] = "tir" # Avaliable backends: tir, cuda, tile-languange. + pass class Operator(ABC): @@ -99,12 +98,11 @@ def tvm_callback_cuda_postproc(code, _): try: # Use a specific TVM pass context for CUDA platforms - with tvm.transform.PassContext( - config={"tir.use_async_copy": True, **self.pass_context} - ): - rt_mod = tvm.build( - self.optimized_func, target=target, name=self.name - ) + with tvm.transform.PassContext(config={ + "tir.use_async_copy": True, + **self.pass_context + }): + rt_mod = tvm.build(self.optimized_func, target=target, name=self.name) except Exception: # noqa: F841 logger.debug( "Failed to build optimized function for CUDA target with default schedule, Please consider enable hardware aware tuning!" @@ -118,20 +116,15 @@ def tvm_callback_cuda_postproc(code, _): self.rt_mod = rt_mod # Initialize a time evaluator with the built module, specifying the device and the number of runs self.time_evaluator = rt_mod.time_evaluator( - rt_mod.entry_name, self.arch.device, number=10 - ) + rt_mod.entry_name, self.arch.device, number=10) self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle self.torch_func = to_pytorch_func(rt_mod) if self.arch.platform == "CUDA": try: is_dynamic = ( - self.dynamic_range is not None - and len(self.optimized_func.functions) > 1 - ) + self.dynamic_range is not None and len(self.optimized_func.functions) > 1) self.wrapper.assign_optimized_module(self.optimized_func) - wrapped_source = self.wrapper.wrap( - self.get_source(target), is_dynamic - ) + wrapped_source = self.wrapper.wrap(self.get_source(target), is_dynamic) self.lib_generator.update_lib_code(wrapped_source) self.lib_generator.compile_lib() self.lib = self.lib_generator.load_lib() @@ -140,10 +133,7 @@ def tvm_callback_cuda_postproc(code, _): except Exception as e: build_runtime_library_error = e logger.debug( - "Failed to build runtime library {}".format( - build_runtime_library_error - ) - ) + "Failed to build runtime library {}".format(build_runtime_library_error)) return rt_mod @@ -157,8 +147,7 @@ def apply_default_schedule(self, func_mod: IRModule, target: Target) -> IRModule bitblas.gpu.Reduction(), bitblas.gpu.GeneralReduction(), bitblas.gpu.Fallback(), - )(mod_for_opt) - ) + )(mod_for_opt)) if optimized_mod is not None: return optimized_mod @@ -166,9 +155,7 @@ def apply_default_schedule(self, func_mod: IRModule, target: Target) -> IRModule def _build_default_module(self, target: Target): try: - self.optimized_func = self.apply_default_schedule( - self.prim_func_mod, target - ) + self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) except Exception: self.optimized_func = None logger.warning( @@ -180,9 +167,11 @@ def _build_default_module(self, target: Target): def post_process(self, code: str) -> str: return code - def apply_fast_tuning( - self, func: PrimFunc, target: Target, topk: int = 20, parallel_build=True - ) -> IRModule: + def apply_fast_tuning(self, + func: PrimFunc, + target: Target, + topk: int = 20, + parallel_build=True) -> IRModule: _, best = fast_tune(func, target, topk=topk, parallel_build=parallel_build) if best is not None: return best.sch.mod @@ -197,27 +186,25 @@ def apply_fast_tuning_with_dynamic_range( dynamic_range: Dict[str, List[int]] = None, ): optimized_mod = fast_tune_with_dynamic_range( - func, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range - ) + func, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range) if optimized_mod is not None: return optimized_mod return None - def hardware_aware_finetune( - self, topk: int = 20, target: tvm.target.Target = None, parallel_build=True - ): + def hardware_aware_finetune(self, + topk: int = 20, + target: tvm.target.Target = None, + parallel_build=True): if target is None: target = self.target dynamic_range = self.dynamic_range func = self.prim_func if dynamic_range is not None: self.optimized_func = self.apply_fast_tuning_with_dynamic_range( - func, target, topk, dynamic_range - ) + func, target, topk, dynamic_range) else: self.optimized_func = self.apply_fast_tuning( - func, target, topk, parallel_build=parallel_build - ) + func, target, topk, parallel_build=parallel_build) self._build_runtime_module(self.target) def get_profile_tensors(self, dynamic_symbolic_constrains: Optional[Dict] = None): @@ -257,18 +244,14 @@ def map_numpy_type(intype): numpy_dtype = map_numpy_type(arg.dtype) profile_tensors.append( tvm.nd.array( - np.random.uniform(0, 1, [var_warpper(i) for i in arg.shape]).astype( - numpy_dtype - ), + np.random.uniform(0, 1, + [var_warpper(i) for i in arg.shape]).astype(numpy_dtype), device=device, - ) - ) + )) self.profile_tensors = profile_tensors return profile_tensors - def profile_latency( - self, dynamic_symbolic_constrains: Optional[Dict] = None - ) -> str: + def profile_latency(self, dynamic_symbolic_constrains: Optional[Dict] = None) -> str: if dynamic_symbolic_constrains is None: dynamic_symbolic_constrains = {} profile_tensors = self.get_profile_tensors(dynamic_symbolic_constrains) @@ -299,8 +282,7 @@ def forward(self, *args): def _forward_from_prebuild_lib(self, *args, stream=0): ctypes_args = [ - ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr - for arr in args + ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args ] ctypes_args.append(ctypes.c_void_p(stream)) self.lib.call(*ctypes_args) @@ -316,9 +298,7 @@ def update_func(self, func: PrimFunc): def update_runtime_module(self, rt_mod, srcpath=None, libpath=None): self.rt_mod = rt_mod - self.time_evaluator = rt_mod.time_evaluator( - rt_mod.entry_name, self.arch.device, number=10 - ) + self.time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, self.arch.device, number=10) self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle self.torch_func = to_pytorch_func(rt_mod) if srcpath is not None: @@ -336,10 +316,6 @@ def _select_implementation(self) -> IRModule: def prim_func(self): return self.prim_func_mod["main"] - @property - def backend(self): - return self.config.backend - @property def srcpath(self): return self.lib_generator.get_source_path() @@ -347,7 +323,7 @@ def srcpath(self): @property def libpath(self): return self.lib_generator.get_lib_path() - + @property def wrapped_source(self): return self.lib_generator.lib_code diff --git a/bitblas/utils/rtmod_analysis.py b/bitblas/utils/rtmod_analysis.py index 15ce1bd0c..69a08dfdc 100644 --- a/bitblas/utils/rtmod_analysis.py +++ b/bitblas/utils/rtmod_analysis.py @@ -7,6 +7,7 @@ from tvm.target import Target from typing import Tuple, List + def get_annotated_device_mod(mod: IRModule, target: Target) -> "IRModule": """ Lower the given IRModule and create a device module for the specified target. diff --git a/testing/python/builder/test_backend_tir_builder.py b/testing/python/builder/test_backend_tir_builder.py index 4a9cecd8a..f384a9d41 100644 --- a/testing/python/builder/test_backend_tir_builder.py +++ b/testing/python/builder/test_backend_tir_builder.py @@ -5,6 +5,7 @@ import logging from bitblas import set_log_level from bitblas.builder.backend.tir import TIRBackend + set_log_level(logging.DEBUG) @@ -12,6 +13,7 @@ def get_codegen_result(ops): code = ops.get_source() return code + def matmul_backend_code_wrap( M, N, @@ -36,15 +38,18 @@ def matmul_backend_code_wrap( with_bias=with_bias, ) matmul = Matmul(config=matmul_config, enable_tuning=False) - backend = TIRBackend(config=matmul_config, optimized_mod=matmul.optimized_func, arch=matmul.arch) + backend = TIRBackend( + config=matmul_config, optimized_mod=matmul.optimized_func, arch=matmul.arch) wrapped_code = backend.wrap(matmul.get_source(), is_dynamic=isinstance(M, list)) assert "void call" in wrapped_code + def test_matmul_transform_weight(): matmul_backend_code_wrap(1, 768, 768, "float16", "uint4", "float16", "float16", False) matmul_backend_code_wrap(768, 768, 768, "float16", "uint4", "float16", "float16", False) matmul_backend_code_wrap([1, 768], 768, 768, "float16", "uint4", "float16", "float16", False) + # fmt: on if __name__ == "__main__": bitblas.testing.main() From 37cb07cbd062a36fe19a182274ab0a76a06cdf19 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 10 Jul 2024 17:18:16 +0000 Subject: [PATCH 26/88] refactor builder for whl build --- bitblas/builder/__init__.py | 5 +++++ bitblas/builder/wrapper.py | 0 2 files changed, 5 insertions(+) create mode 100644 bitblas/builder/__init__.py delete mode 100644 bitblas/builder/wrapper.py diff --git a/bitblas/builder/__init__.py b/bitblas/builder/__init__.py new file mode 100644 index 000000000..8e4d715c9 --- /dev/null +++ b/bitblas/builder/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .lib_generator import LibraryGenerator # noqa: F401 +from .wrapper import TIRWrapper # noqa: F401 diff --git a/bitblas/builder/wrapper.py b/bitblas/builder/wrapper.py deleted file mode 100644 index e69de29bb..000000000 From f5b9999e4b26059e2e5244c95f3c8f6bad2bf32c Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 11 Jul 2024 02:52:45 +0000 Subject: [PATCH 27/88] Refactor TIRWrapper.wrap() method to include an assertion for the optimized module --- bitblas/builder/wrapper/tir.py | 1 + testing/python/builder/test_backend_tir_builder.py | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bitblas/builder/wrapper/tir.py b/bitblas/builder/wrapper/tir.py index 5978856a8..2d0162f66 100644 --- a/bitblas/builder/wrapper/tir.py +++ b/bitblas/builder/wrapper/tir.py @@ -398,6 +398,7 @@ def assign_optimized_module(self, optimized_mod: IRModule): # Get Scheduled Rt Module and return source to be compiled def wrap(self, c_source: str, is_dynamic: bool = False): + assert self.optimized_mod is not None, "Please assign optimized module first." wrapper_class = TIRCUDASourceWrapper if not is_dynamic else TIRCUDASourceWrapperWithDynamic wrapper = wrapper_class(self.optimized_mod, c_source, self.arch) return wrapper.lib_code diff --git a/testing/python/builder/test_backend_tir_builder.py b/testing/python/builder/test_backend_tir_builder.py index f384a9d41..22c134b12 100644 --- a/testing/python/builder/test_backend_tir_builder.py +++ b/testing/python/builder/test_backend_tir_builder.py @@ -4,7 +4,7 @@ from bitblas import MatmulConfig, Matmul import logging from bitblas import set_log_level -from bitblas.builder.backend.tir import TIRBackend +from bitblas.builder.wrapper import TIRWrapper set_log_level(logging.DEBUG) @@ -38,8 +38,8 @@ def matmul_backend_code_wrap( with_bias=with_bias, ) matmul = Matmul(config=matmul_config, enable_tuning=False) - backend = TIRBackend( - config=matmul_config, optimized_mod=matmul.optimized_func, arch=matmul.arch) + backend = TIRWrapper(arch=matmul.arch) + backend.assign_optimized_module(matmul.optimized_func) wrapped_code = backend.wrap(matmul.get_source(), is_dynamic=isinstance(M, list)) assert "void call" in wrapped_code @@ -53,4 +53,3 @@ def test_matmul_transform_weight(): # fmt: on if __name__ == "__main__": bitblas.testing.main() - test_matmul_transform_weight() From fb78244d3a083d9c1901edb3476c482fc3365f7d Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 11 Jul 2024 11:43:46 +0000 Subject: [PATCH 28/88] Refactor lib_generator to set library and source paths --- bitblas/builder/lib_generator/__init__.py | 6 ++++++ bitblas/ops/operator.py | 6 ++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py index f50d2557a..576e32de4 100644 --- a/bitblas/builder/lib_generator/__init__.py +++ b/bitblas/builder/lib_generator/__init__.py @@ -70,3 +70,9 @@ def get_source_path(self): def get_lib_path(self): return self.libpath + + def set_lib_path(self, libpath): + self.libpath = libpath + + def set_src_path(self, srcpath): + self.srcpath = srcpath diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 2e9078727..39fc2e785 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -302,9 +302,11 @@ def update_runtime_module(self, rt_mod, srcpath=None, libpath=None): self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle self.torch_func = to_pytorch_func(rt_mod) if srcpath is not None: - self.srcpath = srcpath + assert self.lib_generator is not None, "lib_generator is not initialized" + self.lib_generator.set_src_path(srcpath) if libpath is not None: - self.libpath = libpath + assert self.lib_generator is not None, "lib_generator is not initialized" + self.lib_generator.set_lib_path(libpath) self.lib = ctypes.CDLL(libpath) self.lib.init() From 706e2270b948eae21c41e5b80bafa79613896ab3 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 11 Jul 2024 11:49:14 +0000 Subject: [PATCH 29/88] lint fix --- bitblas/builder/lib_generator/__init__.py | 2 +- bitblas/ops/operator.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py index 576e32de4..a0800751a 100644 --- a/bitblas/builder/lib_generator/__init__.py +++ b/bitblas/builder/lib_generator/__init__.py @@ -73,6 +73,6 @@ def get_lib_path(self): def set_lib_path(self, libpath): self.libpath = libpath - + def set_src_path(self, srcpath): self.srcpath = srcpath diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 39fc2e785..9c592f9f2 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -309,6 +309,7 @@ def update_runtime_module(self, rt_mod, srcpath=None, libpath=None): self.lib_generator.set_lib_path(libpath) self.lib = ctypes.CDLL(libpath) self.lib.init() + # TODO: update the lib code from srcpath @abstractmethod def _select_implementation(self) -> IRModule: From 63f5515f3586b1fb5661e9c883be8a8be039a8c0 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 16 Jul 2024 14:28:01 +0000 Subject: [PATCH 30/88] BitNet vllm integration --- bitblas/cache/operator.py | 17 +-- integration/BitNet/.gitignore | 1 + integration/BitNet/create_bitblas_ckpt.py | 118 +++++++++++++++++++++ integration/BitNet/eval_correctness.py | 58 +++++++++-- integration/BitNet/load_from_quantized.py | 80 +++++++++++++++ integration/BitNet/modeling_bitnet.py | 120 +++++++++++++++++++++- integration/BitNet/utils_quant.py | 92 ++++++++++++----- 7 files changed, 435 insertions(+), 51 deletions(-) create mode 100644 integration/BitNet/.gitignore create mode 100644 integration/BitNet/create_bitblas_ckpt.py create mode 100644 integration/BitNet/load_from_quantized.py diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py index 295630f5d..e0c825a9a 100644 --- a/bitblas/cache/operator.py +++ b/bitblas/cache/operator.py @@ -15,7 +15,8 @@ logger = logging.getLogger(__name__) BITBLAS_DATABASE_PATH = os.path.expanduser("~/.cache/bitblas") - +BITBLAS_WRAPPED_SOURCE_NAME = "wrapper_source.cu" +BITBLAS_WRAPPED_COMPILED_NAME = "wrapper_compiled.so" class OperatorCache: """ @@ -107,17 +108,17 @@ def _save_operator_config_and_artifact(self, config, op_inst, config_path): with open(optimized_file_path, "w") as optimized_file: if op_inst.optimized_func is not None: optimized_file.write(op_inst.optimized_func.script(show_meta=False)) - if op_inst.wrapper.libpath is not None: + if op_inst.libpath is not None: # copy lib name to the same directory as the artifact - srcpath = op_inst.wrapper.srcpath + srcpath = op_inst.srcpath shutil.copy( srcpath, - os.path.join(config_path, os.path.basename("wrapper_source.cu")), + os.path.join(config_path, os.path.basename(BITBLAS_WRAPPED_SOURCE_NAME)), ) - libpath = op_inst.wrapper.libpath + libpath = op_inst.libpath shutil.copy( libpath, - os.path.join(config_path, os.path.basename("wrapper_compiled.so")), + os.path.join(config_path, os.path.basename(BITBLAS_WRAPPED_COMPILED_NAME)), ) def _determine_target_arch_str(self, target): @@ -141,9 +142,9 @@ def _load_operator(self, config_path, target): config = json.load(f) elif file.endswith(".tar"): rt_mod = tvm.runtime.load_module(full_path) - elif file == "wrapper_compiled.so": + elif file == BITBLAS_WRAPPED_COMPILED_NAME: libpath = full_path - elif file == "wrapper_source.cu": + elif file == BITBLAS_WRAPPED_SOURCE_NAME: srcpath = full_path if mapping and config and rt_mod: diff --git a/integration/BitNet/.gitignore b/integration/BitNet/.gitignore new file mode 100644 index 000000000..6ea887496 --- /dev/null +++ b/integration/BitNet/.gitignore @@ -0,0 +1 @@ +models/ \ No newline at end of file diff --git a/integration/BitNet/create_bitblas_ckpt.py b/integration/BitNet/create_bitblas_ckpt.py new file mode 100644 index 000000000..b920c19bd --- /dev/null +++ b/integration/BitNet/create_bitblas_ckpt.py @@ -0,0 +1,118 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import bitblas +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer +from transformers.utils.hub import cached_file +import os +from transformers import GenerationConfig +import time +import json + +filepath = os.path.abspath(__file__) +dirpath = os.path.dirname(filepath) + +torch.set_grad_enabled(False) +bitblas.set_log_level("INFO") + +model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits" +saved_model_path = os.path.join( + dirpath, "models", f"{model_name_or_path}_bitblas" +) + +def generate_text(model, tokenizer, prompt, max_length=100): + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device) + # Generate cos and sin values + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + generation_config = GenerationConfig( + max_length=max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) + + start_time = time.time() + output_ids = model.generate(input_ids, generation_config=generation_config) + end_time = time.time() + + generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + generation_time = end_time - start_time + num_tokens = len(output_ids[0]) + tokens_per_second = num_tokens / generation_time + + print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds") + print(f"Tokens per second: {tokens_per_second:.2f}") + + return generated_text + + +def main(): + model = ( + BitnetForCausalLM.from_pretrained( + model_name_or_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ) + .cuda() + .half() + ) + tokenizer = BitnetTokenizer.from_pretrained( + model_name_or_path, use_fast=False + ) + + # print("original model generated text:") + # print(generate_text(model, tokenizer, "Hi, ", max_length=100)) + input_ids = torch.ones((1, 1), dtype=torch.long).cuda() + # naive model inference + output = model(input_ids) + print("original model output:", output) + + model.quantize() + print("original model generated text:") + print(generate_text(model, tokenizer, "Hi, ", max_length=100)) + + model.save_pretrained(saved_model_path) + + # load quant config + quant_config_path = cached_file(model_name_or_path, "quantize_config.json") + with open(quant_config_path, "r") as f: + quant_config = json.load(f) + print("quant config:") + print(quant_config) + quant_config["checkpoint_format"] = "bitblas" + + # save quant config + quant_config_path = os.path.join(saved_model_path, "quantize_config.json") + with open(quant_config_path, "w") as f: + json.dump(quant_config, f) + print("quant config saved to:", quant_config_path) + + # copy benchmark filed into saved model path + file_list = [ + "configuration_bitnet.py", + "eval_utils.py", + "modeling_bitnet.py", + "tokenization_bitnet.py", + "utils_quant.py", + "README.md", + ] + for file in file_list: + file_path = cached_file(model_name_or_path, file) + os.system(f"cp {file_path} {saved_model_path}") + # load quantized model + qmodel = BitnetForCausalLM.from_quantized( + saved_model_path, + ).cuda().half() + print("quantized model generated text:") + print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100)) + + +if __name__ == '__main__': + main() diff --git a/integration/BitNet/eval_correctness.py b/integration/BitNet/eval_correctness.py index 578715da4..22172e1a0 100644 --- a/integration/BitNet/eval_correctness.py +++ b/integration/BitNet/eval_correctness.py @@ -3,17 +3,51 @@ import argparse import torch - +import bitblas from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer +from transformers import GenerationConfig +import time torch.set_grad_enabled(False) +bitblas.set_log_level("INFO") + +def generate_text(model, tokenizer, prompt, max_length=100): + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device) + # Generate cos and sin values + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + # position_embeddings = model.embed_positions(position_ids) + # cos = position_embeddings[:, :, 0::2].cos() + # sin = position_embeddings[:, :, 1::2].sin() + + generation_config = GenerationConfig( + max_length=max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) + + start_time = time.time() + output_ids = model.generate(input_ids, generation_config=generation_config) + # output_ids = model.generate(input_ids, generation_config=generation_config, cos=cos, sin=sin) + end_time = time.time() + + generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + generation_time = end_time - start_time + num_tokens = len(output_ids[0]) + tokens_per_second = num_tokens / generation_time -parser = argparse.ArgumentParser() -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) + print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds") + print(f"Tokens per second: {tokens_per_second:.2f}") + + return generated_text def profile(model, input_data): - import time import numpy as np model = model.cuda() @@ -35,24 +69,26 @@ def get_runtime(num_repeats=1): times = get_runtime(num_repeats) return np.mean(times) - +model_path = '1bitLLM/bitnet_b1_58-3B' def main(): model = BitnetForCausalLM.from_pretrained( - '1bitLLM/bitnet_b1_58-3B', + model_path, use_flash_attention_2=True, torch_dtype=torch.float16, ).cuda().half() with torch.no_grad(): model._post_process_weights() - input_id = torch.ones(1, 1).long().cuda() - - # test forward + # input_id = torch.ones(1, 1).long().cuda() + tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False) + input_id = tokenizer("Hello")['input_ids'] + input_id = torch.tensor(input_id).unsqueeze(0).cuda() output = model(input_id) - - # make sure the output is the same as the simulated output print(output) + print(generate_text(model, tokenizer, "Hello", max_length=100)) + + if __name__ == '__main__': main() diff --git a/integration/BitNet/load_from_quantized.py b/integration/BitNet/load_from_quantized.py new file mode 100644 index 000000000..197f7dd15 --- /dev/null +++ b/integration/BitNet/load_from_quantized.py @@ -0,0 +1,80 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import bitblas +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer +from transformers.utils.hub import cached_file +import os +from transformers import GenerationConfig +import time +import json + +filepath = os.path.abspath(__file__) +dirpath = os.path.dirname(filepath) + +torch.set_grad_enabled(False) +bitblas.set_log_level("INFO") + +model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits" +saved_model_path = os.path.join( + dirpath, "models", f"{model_name_or_path}_bitblas" +) + + +def generate_text(model, tokenizer, prompt, max_length=100): + input_ids = tokenizer.encode(prompt, return_tensors="pt").to( + model.lm_head.weight.device + ) + # Generate cos and sin values + seq_length = input_ids.size(1) + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device + ) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + generation_config = GenerationConfig( + max_length=max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) + + start_time = time.time() + output_ids = model.generate(input_ids, generation_config=generation_config) + end_time = time.time() + + generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + generation_time = end_time - start_time + num_tokens = len(output_ids[0]) + tokens_per_second = num_tokens / generation_time + + print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds") + print(f"Tokens per second: {tokens_per_second:.2f}") + + return generated_text + + +def main(): + # load quantized model + qmodel = BitnetForCausalLM.from_quantized( + saved_model_path, + ).cuda().half() + tokenizer = BitnetTokenizer.from_pretrained( + model_name_or_path, use_fast=False + ) + # print("original model generated text:") + # print(generate_text(model, tokenizer, "Hi, ", max_length=100)) + input_ids = torch.ones((1, 1), dtype=torch.long).cuda() + # naive model inference + output = qmodel(input_ids) + print("original model output:", output) + print("quantized model generated text:") + print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100)) + + +if __name__ == "__main__": + main() diff --git a/integration/BitNet/modeling_bitnet.py b/integration/BitNet/modeling_bitnet.py index 11be4059f..49da46318 100644 --- a/integration/BitNet/modeling_bitnet.py +++ b/integration/BitNet/modeling_bitnet.py @@ -49,13 +49,24 @@ replace_return_docstrings, ) from configuration_bitnet import BitnetConfig -from utils_quant import BitLinear +from utils_quant import BitLinear, BitLinearBitBLAS +from transformers.utils.hub import cached_file if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401 +def find_layers(module, layers=None, name=""): + if not layers: + layers = [nn.Linear] + for layer in layers: + if isinstance(module, layer): + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) + return res logger = logging.get_logger(__name__) @@ -538,7 +549,6 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) - LLAMA_ATTENTION_CLASSES = { "eager": BitnetAttention, "flash_attention_2": BitnetFlashAttention2, @@ -961,7 +971,7 @@ def __init__(self, config): self.model = BitnetModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.quantized = False # Initialize weights and apply final processing self.post_init() @@ -1164,13 +1174,113 @@ def _reorder_cache(past_key_values, beam_idx): ) return reordered_past - + @staticmethod + def recursive_set(model, name, attr): + ''' + set layers.25.mlp.up_proj to attr + ''' + + names = name.split('.') + obj = model + for n in names[:-1]: + obj = getattr(obj, n) + setattr(obj, names[-1], attr) + + def quantize(self): + for name, module in self.model.named_modules(): + # if is bitnet layer + if isinstance(module, BitLinear): + # create quantized version of the layer + print("Quantizing module", name) + bitblas_linear = BitLinearBitBLAS.from_bit_linear( + module + ) + print("Replacing module", name, "with a quantized version") + self.recursive_set(self.model, name, bitblas_linear) + self.quantized = True + def _post_process_weights(self): for name, module in self.model.named_modules(): if hasattr(module, "post_process_weights"): print("Post processing weights for module", name) module.post_process_weights() + def _replace_weight_param_with_qweight(self): + for name, module in self.model.named_modules(): + if hasattr(module, "replace_weight_param_with_qweight"): + print("Replacing weight param with qweight for module", name) + module.replace_weight_param_with_qweight() + + @classmethod + def from_quantized(cls, + model_name_or_path: Optional[str], + trust_remote_code: bool = False, + **kwargs, + ): + """load quantized model from local disk""" + # Parameters related to loading from Hugging Face Hub + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + # == step1: prepare configs and file names == # + config: BitnetConfig = BitnetConfig.from_pretrained( + model_name_or_path, + trust_remote_code=trust_remote_code, + **cached_file_kwargs, + ) + # only load from remote instead of local + # TODO(lei): add local support + quantize_file = cached_file( + model_name_or_path, + "quantize_config.json" + ) + assert quantize_file is not None, "quantize config file not found" + import json + # get quantize format + with open(quantize_file, "r") as f: + quant_config = json.load(f) + checkpoint_format = quant_config["checkpoint_format"] + assert checkpoint_format in ["bitblas"], "quantize format not supported" + + import accelerate + if checkpoint_format == "bitblas": + model = cls(config) + for name, module in model.named_modules(): + if isinstance(module, BitLinear): + # create quantized version of the layer + print("Quantizing module", name) + bitblas_linear = BitLinearBitBLAS.from_bit_linear( + module + ) + print("Replacing module", name, "with a quantized version") + model.recursive_set(model, name, bitblas_linear) + accelerate.utils.modeling.load_checkpoint_in_model( + model, + checkpoint=model_name_or_path, + offload_state_dict=True, + offload_buffers=True, + ) + return model + @add_start_docstrings( """ The LLaMa Model transformer with a sequence classification head on top (linear layer). @@ -1390,4 +1500,4 @@ def forward( end_logits=end_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - ) \ No newline at end of file + ) diff --git a/integration/BitNet/utils_quant.py b/integration/BitNet/utils_quant.py index 121649387..d0f00237d 100644 --- a/integration/BitNet/utils_quant.py +++ b/integration/BitNet/utils_quant.py @@ -13,7 +13,6 @@ from logging import getLogger logger = getLogger(__name__) -bitblas.set_log_level("INFO") BITBLAS_TARGET = auto_detect_nvidia_target() BITBLAS_DATABASE_PATH = get_database_path() @@ -36,14 +35,22 @@ def activation_quant(x, num_bits=8): return result.type(dtype) -# BitBLAS BitLinear -class BitLinear(nn.Linear): +class BitLinearBitBLAS(nn.Module): - def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): - super(BitLinear, self).__init__(*kargs, **kwargs) + def __init__( + self, + in_features: int, + out_features: int, + weight_bits=1, + input_bits=8, + **kwargs, + ): + super().__init__() """ RMSNorm is placed outside BitLinear """ + self.in_features = in_features + self.out_features = out_features self.weight_bits = weight_bits self.input_bits = input_bits matmul_config = MatmulConfig( @@ -64,6 +71,7 @@ def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): ENABLE_TUNING = True self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING) + self.format = "bitnet" self.Qp = 2**(self.input_bits - 1) - 1 def _get_or_create_bitblas_operator(self, config, enable_tuning): @@ -86,14 +94,45 @@ def _get_or_create_bitblas_operator(self, config, enable_tuning): print("BitBLAS Operator found in global_operator_cache.") return bitblas_matmul + def replace_weight_param_with_qweight(self): + if hasattr(self, "weight"): + del self.weight + quant_weight = torch.empty(self.bitblas_matmul.retrieve_weight_shape()) + self.qweight = nn.Parameter(quant_weight, requires_grad=False) + self.format = "bitblas" + + @classmethod + def from_bit_linear(cls, bitlinear): + bitblas_linear = cls(bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) + sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight) + bitblas_linear.register_buffer("qweight", qweight) + bitblas_linear.register_buffer("sw", sw) + if bitlinear.bias is not None: + bitblas_linear.register_buffer("bias", bitlinear.bias) + else: + bitblas_linear.bias = None + return bitblas_linear + + def create_bitblas_weights(self, weight): + sw = 1 / weight.abs().mean().clamp(min=1e-5) + quant_weight = self.weight_quant(weight).detach() + quant_weight = self.bitblas_matmul.transform_weight(quant_weight) + qweight = nn.Parameter(quant_weight, requires_grad=False) + return sw, qweight + def post_process_weights(self): sw = 1 / self.weight.abs().mean().clamp(min=1e-5) self.sw = sw quant_weight = self.weight_quant(self.weight).detach() quant_weight = self.bitblas_matmul.transform_weight(quant_weight) - self.weight = nn.Parameter(quant_weight, requires_grad=False) - - def weight_quant(self, weight): + # remove self.weight and replace it with quant_weight + if hasattr(self, "weight"): + del self.weight + self.qweight = nn.Parameter(quant_weight, requires_grad=False) + self.format = "bitblas" + + @staticmethod + def weight_quant(weight): weight = weight.float() s = 1 / weight.abs().mean().clamp(min=1e-5) result = (weight * s).round().clamp(-1, 1) @@ -139,9 +178,8 @@ def forward_fp32_simulated(self, input): def forward(self, input): # return self.forward_fp32_simulated(input) - quant_input = self.activation_quant(input, self.input_bits).detach() - fp32_out = self.bitblas_matmul(quant_input, self.weight) + fp32_out = self.bitblas_matmul(quant_input, self.qweight) sw = self.sw Qp = self.Qp si = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) @@ -154,25 +192,25 @@ def forward(self, input): return out -# # Naive BitLinear from HuggingFace -# class BitLinear(nn.Linear): +# Naive BitLinear from HuggingFace +class BitLinear(nn.Linear): -# def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): -# super(BitLinear, self).__init__(*kargs, **kwargs) -# """ -# RMSNorm is placed outside BitLinear -# """ -# self.weight_bits = weight_bits -# self.input_bits = input_bits + def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): + super(BitLinear, self).__init__(*kargs, **kwargs) + """ + RMSNorm is placed outside BitLinear + """ + self.weight_bits = weight_bits + self.input_bits = input_bits -# def forward(self, input): + def forward(self, input): -# quant_input = input + (activation_quant(input, self.input_bits) - input).detach() -# quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - -# self.weight).detach() + quant_input = input + (activation_quant(input, self.input_bits) - input).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - + self.weight).detach() -# out = nn.functional.linear(quant_input, quant_weight) -# if not self.bias is None: -# out += self.bias.view(1, -1).expand_as(out) + out = nn.functional.linear(quant_input, quant_weight) + if not self.bias is None: + out += self.bias.view(1, -1).expand_as(out) -# return out + return out From b9655fd56347df67430d2d7765d84bda38cffba7 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 16 Jul 2024 14:31:46 +0000 Subject: [PATCH 31/88] chore: update codespell to version 2.3.0 --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 085de6a4f..99c101afb 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,7 +3,7 @@ yapf==0.32.0 toml==0.10.2 tomli==2.0.1 ruff==0.1.5 -codespell==2.2.6 +codespell==2.3.0 cffi cpplint From fff385faf266d6175b729ea773d185085629bc34 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 16 Jul 2024 14:37:22 +0000 Subject: [PATCH 32/88] Lintfix --- bitblas/cache/operator.py | 1 + integration/BitNet/create_bitblas_ckpt.py | 18 +- integration/BitNet/eval_correctness.py | 7 +- integration/BitNet/load_from_quantized.py | 22 +- integration/BitNet/modeling_bitnet.py | 322 +++++++++++++--------- integration/BitNet/utils_quant.py | 8 +- 6 files changed, 206 insertions(+), 172 deletions(-) diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py index e0c825a9a..cbb2e0437 100644 --- a/bitblas/cache/operator.py +++ b/bitblas/cache/operator.py @@ -18,6 +18,7 @@ BITBLAS_WRAPPED_SOURCE_NAME = "wrapper_source.cu" BITBLAS_WRAPPED_COMPILED_NAME = "wrapper_compiled.so" + class OperatorCache: """ Manages a cache for operator instances (e.g., Matmul, Convolution) based on their configurations. diff --git a/integration/BitNet/create_bitblas_ckpt.py b/integration/BitNet/create_bitblas_ckpt.py index b920c19bd..d443b2e20 100644 --- a/integration/BitNet/create_bitblas_ckpt.py +++ b/integration/BitNet/create_bitblas_ckpt.py @@ -18,9 +18,8 @@ bitblas.set_log_level("INFO") model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits" -saved_model_path = os.path.join( - dirpath, "models", f"{model_name_or_path}_bitblas" -) +saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") + def generate_text(model, tokenizer, prompt, max_length=100): input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device) @@ -59,13 +58,8 @@ def main(): model_name_or_path, use_flash_attention_2=True, torch_dtype=torch.float16, - ) - .cuda() - .half() - ) - tokenizer = BitnetTokenizer.from_pretrained( - model_name_or_path, use_fast=False - ) + ).cuda().half()) + tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) # print("original model generated text:") # print(generate_text(model, tokenizer, "Hi, ", max_length=100)) @@ -107,9 +101,7 @@ def main(): file_path = cached_file(model_name_or_path, file) os.system(f"cp {file_path} {saved_model_path}") # load quantized model - qmodel = BitnetForCausalLM.from_quantized( - saved_model_path, - ).cuda().half() + qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half() print("quantized model generated text:") print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100)) diff --git a/integration/BitNet/eval_correctness.py b/integration/BitNet/eval_correctness.py index 22172e1a0..cef89313d 100644 --- a/integration/BitNet/eval_correctness.py +++ b/integration/BitNet/eval_correctness.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import argparse import torch import bitblas from modeling_bitnet import BitnetForCausalLM @@ -12,6 +11,7 @@ torch.set_grad_enabled(False) bitblas.set_log_level("INFO") + def generate_text(model, tokenizer, prompt, max_length=100): input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device) # Generate cos and sin values @@ -69,7 +69,10 @@ def get_runtime(num_repeats=1): times = get_runtime(num_repeats) return np.mean(times) + model_path = '1bitLLM/bitnet_b1_58-3B' + + def main(): model = BitnetForCausalLM.from_pretrained( model_path, @@ -79,7 +82,6 @@ def main(): with torch.no_grad(): model._post_process_weights() - # input_id = torch.ones(1, 1).long().cuda() tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False) input_id = tokenizer("Hello")['input_ids'] input_id = torch.tensor(input_id).unsqueeze(0).cuda() @@ -87,7 +89,6 @@ def main(): print(output) print(generate_text(model, tokenizer, "Hello", max_length=100)) - if __name__ == '__main__': diff --git a/integration/BitNet/load_from_quantized.py b/integration/BitNet/load_from_quantized.py index 197f7dd15..acea3bd0a 100644 --- a/integration/BitNet/load_from_quantized.py +++ b/integration/BitNet/load_from_quantized.py @@ -5,11 +5,9 @@ import bitblas from modeling_bitnet import BitnetForCausalLM from tokenization_bitnet import BitnetTokenizer -from transformers.utils.hub import cached_file import os from transformers import GenerationConfig import time -import json filepath = os.path.abspath(__file__) dirpath = os.path.dirname(filepath) @@ -18,20 +16,14 @@ bitblas.set_log_level("INFO") model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits" -saved_model_path = os.path.join( - dirpath, "models", f"{model_name_or_path}_bitblas" -) +saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") def generate_text(model, tokenizer, prompt, max_length=100): - input_ids = tokenizer.encode(prompt, return_tensors="pt").to( - model.lm_head.weight.device - ) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device) # Generate cos and sin values seq_length = input_ids.size(1) - position_ids = torch.arange( - seq_length, dtype=torch.long, device=input_ids.device - ) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) generation_config = GenerationConfig( @@ -60,12 +52,8 @@ def generate_text(model, tokenizer, prompt, max_length=100): def main(): # load quantized model - qmodel = BitnetForCausalLM.from_quantized( - saved_model_path, - ).cuda().half() - tokenizer = BitnetTokenizer.from_pretrained( - model_name_or_path, use_fast=False - ) + qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half() + tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) # print("original model generated text:") # print(generate_text(model, tokenizer, "Hi, ", max_length=100)) input_ids = torch.ones((1, 1), dtype=torch.long).cuda() diff --git a/integration/BitNet/modeling_bitnet.py b/integration/BitNet/modeling_bitnet.py index 49da46318..e4e1d88ea 100644 --- a/integration/BitNet/modeling_bitnet.py +++ b/integration/BitNet/modeling_bitnet.py @@ -31,7 +31,6 @@ from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache -from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -52,11 +51,11 @@ from utils_quant import BitLinear, BitLinearBitBLAS from transformers.utils.hub import cached_file - if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401 + def find_layers(module, layers=None, name=""): if not layers: layers = [nn.Linear] @@ -65,9 +64,11 @@ def find_layers(module, layers=None, name=""): return {name: module} res = {} for name1, child in module.named_children(): - res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) + res.update( + find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) return res + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "BitnetConfig" @@ -86,6 +87,7 @@ def _get_unpad_data(attention_mask): class BitnetRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): """ BitnetRMSNorm is equivalent to T5LayerNorm @@ -106,23 +108,34 @@ def forward(self, hidden_states): class BitnetRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + + def __init__(self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0): super().__init__() self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + self.base + **(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + self.register_buffer( + "_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer( + "_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) @property def sin_cached(self): @@ -143,12 +156,14 @@ def cos_cached(self): @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = self.inv_freq[None, :, + None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + device_type = device_type if isinstance(device_type, + str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) @@ -159,8 +174,8 @@ def forward(self, x, position_ids): def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) @@ -192,22 +207,32 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class BitnetMLP(nn.Module): + def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = BitLinear( - self.hidden_size, self.intermediate_size, bias=False, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.hidden_size, + self.intermediate_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self.up_proj = BitLinear( - self.hidden_size, self.intermediate_size, bias=False, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.hidden_size, + self.intermediate_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self.down_proj = BitLinear( - self.intermediate_size, self.hidden_size, bias=False, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.intermediate_size, + self.hidden_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self.act_fn = ACT2FN[config.hidden_act] self.ffn_layernorm = BitnetRMSNorm(self.intermediate_size, eps=config.rms_norm_eps) @@ -227,7 +252,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, + head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -242,8 +268,7 @@ def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) + "when creating this class.") self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -258,24 +283,35 @@ def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) + f" and `num_heads`: {self.num_heads}).") self.q_proj = BitLinear( - self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.hidden_size, + self.num_heads * self.head_dim, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self.k_proj = BitLinear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self.v_proj = BitLinear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self.o_proj = BitLinear( - self.hidden_size, self.hidden_size, bias=config.attention_bias, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self._init_rope() self.inner_attn_ln = BitnetRMSNorm(self.hidden_size, eps=config.rms_norm_eps) @@ -308,8 +344,10 @@ def forward( value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) @@ -318,27 +356,30 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + f" {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() @@ -364,7 +405,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() @@ -391,8 +432,10 @@ def forward( # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -402,7 +445,8 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -431,16 +475,14 @@ def forward( logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + f" {target_dtype}.") query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate - ) + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.inner_attn_ln(attn_output) @@ -451,9 +493,14 @@ def forward( return attn_output, attn_weights, past_key_value - def _flash_attention_forward( - self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None - ): + def _flash_attention_forward(self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. @@ -483,8 +530,7 @@ def _flash_attention_forward( if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) + query_states, key_states, value_states, attention_mask, query_length) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -505,8 +551,12 @@ def _flash_attention_forward( attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal - ) + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal) return attn_output @@ -515,29 +565,27 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) if query_length == kv_seq_len: query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k - ) + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. + batch_size + 1, dtype=torch.int32, + device=query_layer.device) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask) return ( query_layer, @@ -556,11 +604,13 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query class BitnetDecoderLayer(nn.Module): + def __init__(self, config: BitnetConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx) self.mlp = BitnetMLP(config) self.input_layernorm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -593,8 +643,8 @@ def forward( """ if "padding_mask" in kwargs: warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`", + stacklevel=2) residual = hidden_states @@ -686,8 +736,7 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = else: dtype = layer.self_attn.o_proj.weight.dtype layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) + self.config, max_batch_size, max_cache_len, device=device, dtype=dtype) def _reset_cache(self): for layer in self.model.layers: @@ -786,9 +835,9 @@ def __init__(self, config: BitnetConfig): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) + self.layers = nn.ModuleList([ + BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) + ]) self.norm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -817,8 +866,8 @@ def forward( ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -837,17 +886,17 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() if cache_position is None: if isinstance(past_key_values, StaticCache): raise ValueError("cache_position is a required argument when using StaticCache.") cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -905,10 +954,11 @@ def forward( next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache - ) + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, Cache) else next_decoder_cache) if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -933,10 +983,13 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 - ) + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + causal_mask = torch.full((sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) @@ -945,8 +998,10 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + padding_mask = causal_mask[..., :mask_length].eq( + 0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill( + padding_mask, min_dtype) elif attention_mask.dim() == 4: # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. @@ -956,9 +1011,8 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[:mask_shape[0], :mask_shape[1], + offset:mask_shape[2] + offset, :mask_shape[3]] = mask_slice return causal_mask @@ -1036,8 +1090,8 @@ def forward( ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -1083,9 +1137,13 @@ def forward( attentions=outputs.attentions, ) - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs - ): + def prepare_inputs_for_generation(self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + **kwargs): # With static cache, the `past_key_values` is None # TODO joao: standardize interface for the different Cache classes and remove of this if has_static_cache = False @@ -1096,13 +1154,13 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + past_length = cache_position[ + 0] if cache_position is not None else past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + if past_key_values.get_max_length() is not None else None) + cache_length = past_length if max_cache_length is None else torch.min( + max_cache_length, past_length) # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] @@ -1113,7 +1171,7 @@ def prepare_inputs_for_generation( # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: @@ -1121,11 +1179,8 @@ def prepare_inputs_for_generation( # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): + if (max_cache_length is not None and attention_mask is not None and + cache_length + input_ids.shape[1] > max_cache_length): attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) @@ -1134,7 +1189,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + position_ids = position_ids[:, -input_ids.shape[1]:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -1147,31 +1202,30 @@ def prepare_inputs_for_generation( input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + cache_position = torch.arange( + past_length, past_length + input_length, device=input_ids.device) else: cache_position = cache_position[-input_length:] if has_static_cache: past_key_values = None - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) + model_inputs.update({ + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + }) return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) + reordered_past += (tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past),) return reordered_past @staticmethod @@ -1192,9 +1246,7 @@ def quantize(self): if isinstance(module, BitLinear): # create quantized version of the layer print("Quantizing module", name) - bitblas_linear = BitLinearBitBLAS.from_bit_linear( - module - ) + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module) print("Replacing module", name, "with a quantized version") self.recursive_set(self.model, name, bitblas_linear) self.quantized = True @@ -1212,7 +1264,8 @@ def _replace_weight_param_with_qweight(self): module.replace_weight_param_with_qweight() @classmethod - def from_quantized(cls, + def from_quantized( + cls, model_name_or_path: Optional[str], trust_remote_code: bool = False, **kwargs, @@ -1249,10 +1302,7 @@ def from_quantized(cls, ) # only load from remote instead of local # TODO(lei): add local support - quantize_file = cached_file( - model_name_or_path, - "quantize_config.json" - ) + quantize_file = cached_file(model_name_or_path, "quantize_config.json") assert quantize_file is not None, "quantize config file not found" import json # get quantize format @@ -1261,16 +1311,14 @@ def from_quantized(cls, checkpoint_format = quant_config["checkpoint_format"] assert checkpoint_format in ["bitblas"], "quantize format not supported" - import accelerate + import accelerate if checkpoint_format == "bitblas": model = cls(config) for name, module in model.named_modules(): if isinstance(module, BitLinear): # create quantized version of the layer print("Quantizing module", name) - bitblas_linear = BitLinearBitBLAS.from_bit_linear( - module - ) + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module) print("Replacing module", name, "with a quantized version") model.recursive_set(model, name, bitblas_linear) accelerate.utils.modeling.load_checkpoint_in_model( @@ -1281,6 +1329,7 @@ def from_quantized(cls, ) return model + @add_start_docstrings( """ The LLaMa Model transformer with a sequence classification head on top (linear layer). @@ -1297,6 +1346,7 @@ def from_quantized(cls, LLAMA_START_DOCSTRING, ) class BitnetForSequenceClassification(BitnetPreTrainedModel): + def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels @@ -1360,7 +1410,8 @@ def forward( else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = torch.eq(input_ids, + self.config.pad_token_id).int().argmax(-1) - 1 sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(logits.device) else: @@ -1374,7 +1425,8 @@ def forward( if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + elif self.num_labels > 1 and (labels.dtype == torch.long or + labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" diff --git a/integration/BitNet/utils_quant.py b/integration/BitNet/utils_quant.py index d0f00237d..d9cc25ae7 100644 --- a/integration/BitNet/utils_quant.py +++ b/integration/BitNet/utils_quant.py @@ -6,7 +6,6 @@ import torch from torch import nn -import bitblas from bitblas.cache import global_operator_cache, get_database_path from bitblas import Matmul, MatmulConfig from bitblas import auto_detect_nvidia_target @@ -103,7 +102,8 @@ def replace_weight_param_with_qweight(self): @classmethod def from_bit_linear(cls, bitlinear): - bitblas_linear = cls(bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) + bitblas_linear = cls( + bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight) bitblas_linear.register_buffer("qweight", qweight) bitblas_linear.register_buffer("sw", sw) @@ -111,7 +111,7 @@ def from_bit_linear(cls, bitlinear): bitblas_linear.register_buffer("bias", bitlinear.bias) else: bitblas_linear.bias = None - return bitblas_linear + return bitblas_linear def create_bitblas_weights(self, weight): sw = 1 / weight.abs().mean().clamp(min=1e-5) @@ -210,7 +210,7 @@ def forward(self, input): self.weight).detach() out = nn.functional.linear(quant_input, quant_weight) - if not self.bias is None: + if self.bias is not None: out += self.bias.view(1, -1).expand_as(out) return out From 72a98e7ba51ef741c3f6880bc40eb257e9e2c76f Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 18 Jul 2024 05:22:56 +0000 Subject: [PATCH 33/88] Bump version to 0.0.1.dev13 --- VERSION | 2 +- bitblas/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/VERSION b/VERSION index cef7517c2..419bd5f01 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.1.dev12 \ No newline at end of file +0.0.1.dev13 \ No newline at end of file diff --git a/bitblas/__init__.py b/bitblas/__init__.py index c2f44ae05..0694e57e2 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -86,4 +86,4 @@ def _init_logger(): _init_logger() -__version__ = "0.0.1.dev12" +__version__ = "0.0.1.dev13" From 5646ab5f18e8215d151b3d48f9d21c01b40cacb6 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 18 Jul 2024 05:31:34 +0000 Subject: [PATCH 34/88] lint fix --- bitblas/__init__.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 0694e57e2..e40f17f3e 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -4,14 +4,12 @@ import os # installing tvm -install_tvm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm") +install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm") if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") sys.path.insert(0, install_tvm_path + "/python") -develop_tvm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm") +develop_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm") if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path: os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") sys.path.insert(0, develop_tvm_path + "/python") @@ -31,7 +29,6 @@ try_inline_contiguous_spatial, # noqa: F401 ) - from . import testing # noqa: F401 from .utils import auto_detect_nvidia_target # noqa: F401 from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 From b965863999c23909dd4160a0267cb5e58dd7c86f Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 21 Jul 2024 06:13:15 +0000 Subject: [PATCH 35/88] disable fast decoding [u]int4xint8 by default. --- bitblas/ops/general_matmul/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 5c3f6d2e6..867e06edb 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -128,6 +128,9 @@ def is_not_fast_decoding_supported(): conditions.append(self.W_dtype == self.A_dtype) # int8,uint8 also do not implement and also do not require fast decoding conditions.append(self.W_dtype in ["int8", "uint8"]) + # if the w_dtype is int4/uint4 and the a_dtype is int8 + # we do not require fast decoding + conditions.append(self.W_dtype in ["int4", "uint4"] and self.A_dtype in ["int8"]) return any(conditions) if fast_decoding is not None: From 1198fc777e2f9665915bba8052b6c3e16a127c08 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 21 Jul 2024 06:13:31 +0000 Subject: [PATCH 36/88] optimize from dict design in Hint --- bitblas/base/roller/hint.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/bitblas/base/roller/hint.py b/bitblas/base/roller/hint.py index f6e2fb03a..85fb294bc 100644 --- a/bitblas/base/roller/hint.py +++ b/bitblas/base/roller/hint.py @@ -210,12 +210,13 @@ def to_dict(self) -> Dict: if self.block_reduction_depth is not None: dic["block_reduction_depth"] = self.block_reduction_depth return dic - - def from_dict(self, dic: Dict) -> "Hint": - self.__init__() + + @classmethod + def from_dict(cls, dic: Dict) -> "Hint": + hint = cls() for k, v in dic.items(): - setattr(self, k, v) - return self + setattr(hint, k, v) + return hint def tensorcore_legalization(self): # only keep the last 2 axes for tensorcore From 014213c8cf52d465c29ba191b492e75a7f464f1c Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 21 Jul 2024 06:14:24 +0000 Subject: [PATCH 37/88] Implement SplitK --- .../ops/impl/matmul_dequantize_splitk_impl.py | 208 +++++++++++++++++- 1 file changed, 207 insertions(+), 1 deletion(-) diff --git a/bitblas/ops/impl/matmul_dequantize_splitk_impl.py b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py index d3ef23187..875d3477b 100644 --- a/bitblas/ops/impl/matmul_dequantize_splitk_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py @@ -2,7 +2,10 @@ # Licensed under the MIT License. # pre-transformed tir expression of matmul from bitblas import tvm -from tvm import te +from tvm import te, DataType +from tvm.tir import IndexMap +from bitblas.ops.operator import TransformKind +from bitblas.gpu.matmul_analysis import get_propagate_map from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, _tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16, _tir_u8_to_f8_e4m3_to_f16) @@ -129,6 +132,209 @@ def decode_func(n, k): return tvm.IRModule.from_expr(func) +def matmul_nt_dequantize_b_propagate_a_propagate_b( + SplitK, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + transform_kind_input: TransformKind = TransformKind.IntraWarpTransform, + transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform, +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") + A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind_input >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return A[new_index] + + A_reindex = te.compute( + (M, K), + fcompute, + name="A_reindex", + ) + + _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inversed_index_map.initial_indices + scaling_final_indices = inversed_index_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inversed_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + qr = r * bit // storage_nbit + B = te.placeholder((N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind_weight >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K // storage_nbit * bit), + fcompute, + name="B_reindex", + ) + + def decode_func(n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "int": + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. + if bit == 1: + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + # Describe the matrix multiplication in TE + RK = K // SplitK + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, RK), name="k") + C = te.compute( + (SplitK, M, N), + lambda sk, i, j: te.sum( + A_reindex[i, sk * RK + k].astype(accum_dtype) * B_decode[j, sk * RK + k].astype(accum_dtype), + axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((SplitK, M, N), lambda b, i, j: last_output[b, i, j].astype(out_dtype), name="D") + last_output = D + + args = [A, B] + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_zeros: + args.append(Zeros) + if with_bias: + E = te.compute((SplitK, M, N), lambda b, i, j: D[b, i, j] + Bias[j], name="E") + last_output = E + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "with_scaling": with_scaling, + "group_size": group_size, + } + }, + ) + func = func.with_attr("input_transform_kind", transform_kind_input.value) + func = func.with_attr("weight_transform_kind", transform_kind_weight.value) + return tvm.IRModule.from_expr(func) + + def select_implementation( SplitK=1, M=None, From e0ca752797b9a6b57b4dc47d7c17c96760909076 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 21 Jul 2024 06:14:35 +0000 Subject: [PATCH 38/88] bitnet benchmark generation. --- integration/BitNet/benchmark_generate.py | 112 +++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 integration/BitNet/benchmark_generate.py diff --git a/integration/BitNet/benchmark_generate.py b/integration/BitNet/benchmark_generate.py new file mode 100644 index 000000000..4c3cd8dca --- /dev/null +++ b/integration/BitNet/benchmark_generate.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import bitblas +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer +from transformers import GenerationConfig +import time +import argparse + +torch.set_grad_enabled(False) +bitblas.set_log_level("INFO") + + +def generate_text_batch(model, tokenizer, prompts, max_length=100): + # Encode the input prompts as a batch + input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device) + + # Generate cos and sin values (commented out as not used in generation) + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + # position_embeddings = model.embed_positions(position_ids) + # cos = position_embeddings[:, :, 0::2].cos() + # sin = position_embeddings[:, :, 1::2].sin() + + generation_config = GenerationConfig( + max_length=max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) + + start_time = time.time() + output_ids = model.generate(input_ids, generation_config=generation_config) + # output_ids = model.generate(input_ids, generation_config=generation_config, cos=cos, sin=sin) + end_time = time.time() + + # Decode the output ids to text + generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids] + + generation_time = end_time - start_time + num_tokens = sum(len(output_id) for output_id in output_ids) + tokens_per_second = num_tokens / generation_time + + print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds") + print(f"Tokens per second: {tokens_per_second:.2f}") + + return generated_texts + +def profile(model, input_data): + + import numpy as np + model = model.cuda() + model.eval() + + def get_runtime(num_repeats=1): + tic = time.time() + for _ in range(num_repeats): + _ = model(input_data) + torch.cuda.synchronize() + return (time.time() - tic) * 1000 / num_repeats + + with torch.no_grad(): + st = time.time() + while time.time() - st < 1.0: + get_runtime() # warmup + warmup_runtime = get_runtime() + num_repeats = max(1, int(1000 / warmup_runtime)) + times = get_runtime(num_repeats) + return np.mean(times) + + +model_path = '1bitLLM/bitnet_b1_58-3B' + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--bs', default=16, type=int) + parser.add_argument('--in_seq_len', default=32, type=int) + parser.add_argument('--out_seq_len', default=128, type=int) + parser.add_argument('--bitblas', action='store_true') + args = parser.parse_args() + bs = args.bs + in_seq_len = args.in_seq_len + out_seq_len = args.out_seq_len + is_bitblas = args.bitblas + model = BitnetForCausalLM.from_pretrained( + model_path, + use_flash_attention_2=True, + torch_dtype=torch.float16, + ).cuda().half() + if is_bitblas: + with torch.no_grad(): + model.quantize() + + tokenizer = BitnetTokenizer.from_pretrained(model_path) + prompt = "" + for _ in range(in_seq_len): + prompt += "Hello " + + prompts = [] + for _ in range(bs): + prompts.append(prompt) + max_length = out_seq_len + in_seq_len + print(generate_text_batch(model, tokenizer, prompts, max_length=max_length)) + + +if __name__ == '__main__': + main() From 81b9cf0eab808ec760c6120d47552b9a189279c7 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 21 Jul 2024 06:15:16 +0000 Subject: [PATCH 39/88] Add benchmark script for BitNet integration --- integration/BitNet/benchmark.sh | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100755 integration/BitNet/benchmark.sh diff --git a/integration/BitNet/benchmark.sh b/integration/BitNet/benchmark.sh new file mode 100755 index 000000000..37a21167d --- /dev/null +++ b/integration/BitNet/benchmark.sh @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +python benchmark_generate.py --bs 16 --in_seq_len 32 --out_seq_len 128 | tee b16_i32_o128.log + +python benchmark_generate.py --bs 1 --in_seq_len 512 --out_seq_len 64 | tee b1_i512_o64.log + +python benchmark_generate.py --bs 32 --in_seq_len 32 --out_seq_len 128 | tee b32_i32_o128.log + +python benchmark_generate.py --bs 16 --in_seq_len 32 --out_seq_len 128 --bitblas | tee b16_i32_o128_bitblas.log + +python benchmark_generate.py --bs 1 --in_seq_len 512 --out_seq_len 64 --bitblas | tee b1_i512_o64_bitblas.log + +python benchmark_generate.py --bs 32 --in_seq_len 32 --out_seq_len 128 --bitblas | tee b32_i32_o128_bitblas.log From 02edc0ba2024615dda4a7924905dde18ba9cd149 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 21 Jul 2024 06:20:51 +0000 Subject: [PATCH 40/88] AtomicAdd Support --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index a077796b9..2b8c136f9 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit a077796b9e2dd3b2275fbaa212786645758c360d +Subproject commit 2b8c136f9270c84ef93ae9a47b1ef7c5d2e4d639 From 1a70c2dd9ba86f62e91043029fd3ea53b189019c Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 21 Jul 2024 06:20:59 +0000 Subject: [PATCH 41/88] LintFix --- bitblas/base/roller/hint.py | 2 +- bitblas/ops/general_matmul/__init__.py | 14 ++++++---- .../ops/impl/matmul_dequantize_splitk_impl.py | 7 +++-- integration/BitNet/benchmark_generate.py | 12 +++++--- integration/BitNet/eval_correctness.py | 28 +++++++++++-------- 5 files changed, 40 insertions(+), 23 deletions(-) diff --git a/bitblas/base/roller/hint.py b/bitblas/base/roller/hint.py index 85fb294bc..191614dfa 100644 --- a/bitblas/base/roller/hint.py +++ b/bitblas/base/roller/hint.py @@ -210,7 +210,7 @@ def to_dict(self) -> Dict: if self.block_reduction_depth is not None: dic["block_reduction_depth"] = self.block_reduction_depth return dic - + @classmethod def from_dict(cls, dic: Dict) -> "Hint": hint = cls() diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 867e06edb..7ed8fbc39 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -92,7 +92,7 @@ def __initialize_propagate(self, propagate_a: Optional[TransformKind], # Currently we do not support propagate_a when propagate_b is not transformed. object.__setattr__(self, "propagate_a", TransformKind.NonTransform) elif (isinstance(self.M, int) and (self.M % MICRO_KERNEL_SIZE) == 0 and - (self.K % MICRO_KERNEL_SIZE) == 0): + (self.K % MICRO_KERNEL_SIZE) == 0): object.__setattr__(self, "propagate_a", TransformKind.IntraWarpTransform) else: object.__setattr__(self, "propagate_a", TransformKind.NonTransform) @@ -245,7 +245,11 @@ def __init__( self.dispatch_tir(target, from_database, source_format, enable_tuning) - def dispatch_tir(self, target: Target, from_database: bool = False, source_format: str = "uint", enable_tuning: bool = True): + def dispatch_tir(self, + target: Target, + from_database: bool = False, + source_format: str = "uint", + enable_tuning: bool = True): '''Dispatch the tir script implementation''' self.arch = CUDA(target) @@ -298,7 +302,7 @@ def dispatch_tir(self, target: Target, from_database: bool = False, source_forma # output data type self.torch_output_dtype = getattr(torch, self.out_dtype) - + def _alloc_workspace(self): return torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() @@ -322,7 +326,7 @@ def _assign_ladder_permutate_a(self, target: Target, enable_tuning: bool): ) self.workspace = self._alloc_workspace() return ladder_permutate_a - + def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool): # unused variables del target @@ -369,7 +373,7 @@ def _create_input_executors(self): if self.propagate_a is not TransformKind.NonTransform: input_executors.append(self.ladder_permutate_a) return input_executors - + def _create_weight_executors(self): weight_executors = OPExecutorCPU() if self.fast_decoding: diff --git a/bitblas/ops/impl/matmul_dequantize_splitk_impl.py b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py index 875d3477b..cc1b60de0 100644 --- a/bitblas/ops/impl/matmul_dequantize_splitk_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py @@ -289,13 +289,16 @@ def decode_func(n, k): C = te.compute( (SplitK, M, N), lambda sk, i, j: te.sum( - A_reindex[i, sk * RK + k].astype(accum_dtype) * B_decode[j, sk * RK + k].astype(accum_dtype), + A_reindex[i, sk * RK + k].astype(accum_dtype) * B_decode[j, sk * RK + k].astype( + accum_dtype), axis=k), name="C", ) last_output = C if accum_dtype != out_dtype: - D = te.compute((SplitK, M, N), lambda b, i, j: last_output[b, i, j].astype(out_dtype), name="D") + D = te.compute((SplitK, M, N), + lambda b, i, j: last_output[b, i, j].astype(out_dtype), + name="D") last_output = D args = [A, B] diff --git a/integration/BitNet/benchmark_generate.py b/integration/BitNet/benchmark_generate.py index 4c3cd8dca..f597b8f6d 100644 --- a/integration/BitNet/benchmark_generate.py +++ b/integration/BitNet/benchmark_generate.py @@ -15,8 +15,9 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): # Encode the input prompts as a batch - input_ids = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device) - + input_ids = tokenizer( + prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device) + # Generate cos and sin values (commented out as not used in generation) seq_length = input_ids.size(1) position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) @@ -39,7 +40,9 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): end_time = time.time() # Decode the output ids to text - generated_texts = [tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids] + generated_texts = [ + tokenizer.decode(output_id, skip_special_tokens=True) for output_id in output_ids + ] generation_time = end_time - start_time num_tokens = sum(len(output_id) for output_id in output_ids) @@ -50,6 +53,7 @@ def generate_text_batch(model, tokenizer, prompts, max_length=100): return generated_texts + def profile(model, input_data): import numpy as np @@ -100,7 +104,7 @@ def main(): prompt = "" for _ in range(in_seq_len): prompt += "Hello " - + prompts = [] for _ in range(bs): prompts.append(prompt) diff --git a/integration/BitNet/eval_correctness.py b/integration/BitNet/eval_correctness.py index cef89313d..e832441d2 100644 --- a/integration/BitNet/eval_correctness.py +++ b/integration/BitNet/eval_correctness.py @@ -74,21 +74,27 @@ def get_runtime(num_repeats=1): def main(): - model = BitnetForCausalLM.from_pretrained( + # model = BitnetForCausalLM.from_pretrained( + # model_path, + # use_flash_attention_2=True, + # torch_dtype=torch.float16, + # ).cuda().half() + + tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False) + # input_id = tokenizer("Hello")['input_ids'] + # input_id = torch.tensor(input_id).unsqueeze(0).cuda() + # # output = model(input_id) + # # print(output) + + # print(generate_text(model, tokenizer, "Hi, tell me about microsoft?", max_length=100)) + + qmodel = BitnetForCausalLM.from_pretrained( model_path, use_flash_attention_2=True, torch_dtype=torch.float16, ).cuda().half() - with torch.no_grad(): - model._post_process_weights() - - tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False) - input_id = tokenizer("Hello")['input_ids'] - input_id = torch.tensor(input_id).unsqueeze(0).cuda() - output = model(input_id) - print(output) - - print(generate_text(model, tokenizer, "Hello", max_length=100)) + qmodel.quantize() + print(generate_text(qmodel, tokenizer, "Hi, tell me about microsoft?", max_length=100)) if __name__ == '__main__': From c447a955e0fdf3eca4745fca775c0570949b477f Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 21 Jul 2024 06:29:52 +0000 Subject: [PATCH 42/88] ci fix when 3rdparty tvm is initialized. --- .github/workflows/ci.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8cf347e57..9b76866c5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,6 +26,9 @@ jobs: python -m pip install --upgrade pip if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi + - name: Update submodules recursively + run: git submodule update --init --recursive + - name: Run format check run: | source bitblas_ci/bin/activate From 79a001bb86426c6a9f1049c1b3693d9f28d2579e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 21 Jul 2024 06:42:19 +0000 Subject: [PATCH 43/88] bug fix for setup --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7e82b998a..6954fa798 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ import urllib.request from distutils.version import LooseVersion import platform +import multiprocessing # Environment variables False/True PYPI_BUILD = os.environ.get("PYPI_BUILD", "False").lower() == "true" @@ -180,7 +181,8 @@ def build_tvm(llvm_config_path): # Run CMake and make try: subprocess.check_call(["cmake", ".."]) - subprocess.check_call(["make", "-j$(nproc)"]) + num_jobs = multiprocessing.cpu_count() + subprocess.check_call(["make", f"-j{num_jobs}"]) except subprocess.CalledProcessError as error: raise RuntimeError("Failed to build TVM") from error finally: From 31813b2e108ad24d21adf3a80fe2dceb797a5239 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 21 Jul 2024 15:54:08 +0000 Subject: [PATCH 44/88] fix a bug in block reduce --- bitblas/gpu/matmul_mma_dequantize.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index 679e84395..ca4779569 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -1985,20 +1985,19 @@ def get_param_indices( k0, k1 = sch.split(k, k_factors) k0, kr = sch.split(k0, [None, reduce_k]) - sch.reorder(i0, j0, i1, j1, i2, j2, kr, i3, j3, k0, k1) - + sch.reorder(i0, j0, i1, j1, i2, j2, kr, k0, k1, i3, j3) + # sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) block_idy = sch.fuse(i0, j0) block_idx = sch.fuse(i1, j1) thread_idy = i2 thread_idz = j2 - + sch.bind(batch, "blockIdx.z") sch.bind(block_idx, "blockIdx.x") sch.bind(block_idy, "blockIdx.y") thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz) sch.bind(thread_idy, "threadIdx.y") - sch.bind(kr, "threadIdx.z") - + def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 if not enable: return @@ -2064,6 +2063,7 @@ def decode_fetch_to_shared(block, idx): block_shared = sch.cache_read(block, idx, shared_scope) sch.compute_at(block_shared, k0, preserve_unit_loops=True) + # TODO(lei): the factor should be analyzed more deeper. decode_factor = get_coalesced_veclen(sch.get(block_shared)) _, B_shared_vi, _ = sch.split( @@ -2165,6 +2165,9 @@ def get_idx(): _ = decode_fetch_to_shared(block_outer, 1) + # Put the thread binding after the shared memory prefetch + # Otherwise there's a axis mssing bug behind tvm + sch.bind(kr, "threadIdx.z") # create read cache to load matrix from shared memory to wmma fragments A_mat = sch.cache_read(block_outer, 0, "warp") B_mat = sch.cache_read(block_outer, 1, "warp") From 78b6a3dcded27f8b791ab38f2a6bdd64be2f14d1 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 21 Jul 2024 15:54:22 +0000 Subject: [PATCH 45/88] typo fix --- bitblas/gpu/matmul_mma_dequantize.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index ca4779569..f7dede4a1 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -1991,13 +1991,13 @@ def get_param_indices( block_idx = sch.fuse(i1, j1) thread_idy = i2 thread_idz = j2 - + sch.bind(batch, "blockIdx.z") sch.bind(block_idx, "blockIdx.x") sch.bind(block_idy, "blockIdx.y") thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz) sch.bind(thread_idy, "threadIdx.y") - + def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 if not enable: return @@ -2063,7 +2063,6 @@ def decode_fetch_to_shared(block, idx): block_shared = sch.cache_read(block, idx, shared_scope) sch.compute_at(block_shared, k0, preserve_unit_loops=True) - # TODO(lei): the factor should be analyzed more deeper. decode_factor = get_coalesced_veclen(sch.get(block_shared)) _, B_shared_vi, _ = sch.split( @@ -2166,7 +2165,7 @@ def get_idx(): _ = decode_fetch_to_shared(block_outer, 1) # Put the thread binding after the shared memory prefetch - # Otherwise there's a axis mssing bug behind tvm + # Otherwise there's a axis missing bug behind tvm sch.bind(kr, "threadIdx.z") # create read cache to load matrix from shared memory to wmma fragments A_mat = sch.cache_read(block_outer, 0, "warp") From 9c552185837e2d495710a560847e149832bee62b Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 06:56:27 +0000 Subject: [PATCH 46/88] BUG Fix for block reduce. --- bitblas/base/roller/policy/tensorcore.py | 28 ++++++++++++++---------- bitblas/gpu/matmul_mma_dequantize.py | 6 ++--- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py index e69bcabc3..00b5bfcc5 100644 --- a/bitblas/base/roller/policy/tensorcore.py +++ b/bitblas/base/roller/policy/tensorcore.py @@ -4,7 +4,7 @@ from bitblas import tvm from typing import Dict, List, Tuple, Optional import numpy as np - +import logging from ...arch import TileDevice from ..hint import Hint, Stride, TileDict, IntrinInfo from ..node import PrimFuncNode @@ -12,6 +12,7 @@ from .default import DefaultPolicy from ..rasterization import NoRasterization, Rasterization2DColumn +logger = logging.getLogger(__name__) class TensorCorePolicy(DefaultPolicy): @@ -47,9 +48,9 @@ def _legalize_info(self): self.use_async_copy = False # TODO: block reduction depth is not used for now. # As there still exists some performance issues for block reduction. - # block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth") - # if block_reduction_depth: - # self.block_reduction_depth = block_reduction_depth + block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth") + if block_reduction_depth: + self.block_reduction_depth = block_reduction_depth def _compute_tc_strides( self, @@ -185,12 +186,12 @@ def _enlarge(rstep_id): rstep = _optimize(node, rstep_map) rstep_map = rstep - if is_block_reduction: - # If block reduction, we should constrain the max value is 64 - # Otherwise it will introduce an issue of cuda invalid args. - MAX_REDUCE_K = 64 - for k in rstep_map: - rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K) + # if is_block_reduction: + # # If block reduction, we should constrain the max value is 64 + # # Otherwise it will introduce an issue of cuda invalid args. + # MAX_REDUCE_K = 64 + # for k in rstep_map: + # rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K) td.rstep_map = rstep_map td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) return @@ -315,7 +316,12 @@ def _score(node, thread): # small is better if intrin_info["out_dtype"] in ["float32"]: codegen_dict.shared_scope = "shared.dyn" # smem capacity - if td.smem_cost > self.arch.smem_cap: + # TODO: This is a dummy mul which avoid reusing some shared memory. + # Should be removed in the future. + if td.smem_cost > (self.arch.smem_cap * 1.3): + info_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \ + " use dynamic shared memory." + logger.info(info_message) codegen_dict.shared_scope = "shared.dyn" codegen_dict.complete_config(node) diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index f7dede4a1..f04a5d043 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -1998,6 +1998,9 @@ def get_param_indices( thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz) sch.bind(thread_idy, "threadIdx.y") + # Put the thread binding after the shared memory prefetch + # Otherwise there's a axis missing bug behind tvm + sch.bind(kr, "threadIdx.z") def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 if not enable: return @@ -2164,9 +2167,6 @@ def get_idx(): _ = decode_fetch_to_shared(block_outer, 1) - # Put the thread binding after the shared memory prefetch - # Otherwise there's a axis missing bug behind tvm - sch.bind(kr, "threadIdx.z") # create read cache to load matrix from shared memory to wmma fragments A_mat = sch.cache_read(block_outer, 0, "warp") B_mat = sch.cache_read(block_outer, 1, "warp") From 1aa886833b602fe0aa9c01f58225079231d0b77e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 06:57:03 +0000 Subject: [PATCH 47/88] Lint fix --- bitblas/base/roller/policy/tensorcore.py | 2 +- bitblas/gpu/matmul_mma_dequantize.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py index 00b5bfcc5..9e6fff9ee 100644 --- a/bitblas/base/roller/policy/tensorcore.py +++ b/bitblas/base/roller/policy/tensorcore.py @@ -14,6 +14,7 @@ logger = logging.getLogger(__name__) + class TensorCorePolicy(DefaultPolicy): def __init__(self, @@ -121,7 +122,6 @@ def _check_small_tile(td: TileDict): smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) rstep_map = td.rstep_map.copy() - is_block_reduction = self.block_reduction_depth is not None def _optimize(node, rstep): all_steps = self.get_node_reduce_step_candidates(node) diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index f04a5d043..37098b2d0 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -2001,6 +2001,7 @@ def get_param_indices( # Put the thread binding after the shared memory prefetch # Otherwise there's a axis missing bug behind tvm sch.bind(kr, "threadIdx.z") + def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 if not enable: return From 5f082a5923cbbb24259e28d8e8ae81ccef51f0cc Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 06:58:19 +0000 Subject: [PATCH 48/88] Refactor block reduce schedule template --- bitblas/gpu/matmul_mma_dequantize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index d71443124..de1b5b896 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -1986,7 +1986,7 @@ def get_param_indices( k0, kr = sch.split(k0, [None, reduce_k]) sch.reorder(i0, j0, i1, j1, i2, j2, kr, k0, k1, i3, j3) - # sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) + block_idy = sch.fuse(i0, j0) block_idx = sch.fuse(i1, j1) thread_idy = i2 From b4fb31e7383c4652960b508355cb3dc399928a62 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 07:42:46 +0000 Subject: [PATCH 49/88] transform branch from bitblas to bitblas_tl --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 2b8c136f9..59029c198 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 2b8c136f9270c84ef93ae9a47b1ef7c5d2e4d639 +Subproject commit 59029c198501f0d7fc92a945a6b9e3ead4e9e019 From 35eaa00d868274731273d32d6878cd830bce17a8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 08:19:34 +0000 Subject: [PATCH 50/88] Fix subproject commit reference in 3rdparty/tvm --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 59029c198..699efe771 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 59029c198501f0d7fc92a945a6b9e3ead4e9e019 +Subproject commit 699efe77188cf7de45d49c445ef1b198e8414aba From 254dd7423bfc2ff9cb276ffe0088c9d5ba615cc0 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 08:25:44 +0000 Subject: [PATCH 51/88] chore: update submodule branch from bitblas to bitblas_tl --- .gitmodules | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index c95978101..57576c5fe 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "3rdparty/tvm"] path = 3rdparty/tvm url = https://github.com/LeiWang1999/tvm - branch = bitblas + branch = bitblas_tl From 31a44aae4ae38eac140aa69f5be7a152dcf63e51 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 09:08:26 +0000 Subject: [PATCH 52/88] force update config.cmake --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 699efe771..3b545bfe1 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 699efe77188cf7de45d49c445ef1b198e8414aba +Subproject commit 3b545bfe17929583a98e4c3d55a2d9f3895e1e29 From 427800ed36a995a84ceb6d0dee65d05118afba53 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 12:30:15 +0000 Subject: [PATCH 53/88] Bug fix --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 3b545bfe1..049a8c5f4 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 3b545bfe17929583a98e4c3d55a2d9f3895e1e29 +Subproject commit 049a8c5f44d5c911be992f650dba78e8c7a75203 From 96db111f5893faed6e0ebb8a4d4c434aaf9cf4a4 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 13:37:24 +0000 Subject: [PATCH 54/88] Fix subproject commit reference in 3rdparty/cutlass --- 3rdparty/cutlass | 1 + 1 file changed, 1 insertion(+) create mode 160000 3rdparty/cutlass diff --git a/3rdparty/cutlass b/3rdparty/cutlass new file mode 160000 index 000000000..44c704eae --- /dev/null +++ b/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit 44c704eae85da352d277d6f092f41412772f70e4 From 38b251a345048b90b425b8f4dc7a217cd5a753fc Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 13:48:31 +0000 Subject: [PATCH 55/88] chore: Add submodule for cutlass library --- .gitmodules | 4 ++++ 3rdparty/cutlass | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index 57576c5fe..5980500e4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,3 +2,7 @@ path = 3rdparty/tvm url = https://github.com/LeiWang1999/tvm branch = bitblas_tl +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass.git + branch = v3.2.2 diff --git a/3rdparty/cutlass b/3rdparty/cutlass index 44c704eae..56b46e2d1 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit 44c704eae85da352d277d6f092f41412772f70e4 +Subproject commit 56b46e2d13875b46b8f6a03f9f5ac91e2bfdc01a From 87d1c5a47476e1c3458056f99bd09f0a8e199823 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 13:51:35 +0000 Subject: [PATCH 56/88] update tl cutlass path --- bitblas/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bitblas/__init__.py b/bitblas/__init__.py index e40f17f3e..8fa04c7b6 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -5,13 +5,17 @@ # installing tvm install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm") +install_cutlass_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass") if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") + os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include" sys.path.insert(0, install_tvm_path + "/python") develop_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm") +develop_cutlass_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass") if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path: os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") + os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include" sys.path.insert(0, develop_tvm_path + "/python") import tvm as tvm # noqa: E402 From 0ffe0b50434fb6fa4c76c8b41b09f5b86ad9ff06 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 14:43:20 +0000 Subject: [PATCH 57/88] Refactor BitBLASLinear test module for improved readability and maintainability --- 3rdparty/cutlass | 2 +- 3rdparty/tvm | 2 +- bitblas/__init__.py | 6 +- testing/python/tilelang/test_tilelang_gemm.py | 177 ++++++++++++++++++ 4 files changed, 183 insertions(+), 4 deletions(-) create mode 100644 testing/python/tilelang/test_tilelang_gemm.py diff --git a/3rdparty/cutlass b/3rdparty/cutlass index 56b46e2d1..44c704eae 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit 56b46e2d13875b46b8f6a03f9f5ac91e2bfdc01a +Subproject commit 44c704eae85da352d277d6f092f41412772f70e4 diff --git a/3rdparty/tvm b/3rdparty/tvm index 049a8c5f4..d9391a502 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 049a8c5f44d5c911be992f650dba78e8c7a75203 +Subproject commit d9391a502b5544722eb67c4a0c4dff49a3476c06 diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 8fa04c7b6..ee79bc3c9 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -5,14 +5,16 @@ # installing tvm install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm") -install_cutlass_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass") +install_cutlass_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass") if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include" sys.path.insert(0, install_tvm_path + "/python") develop_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm") -develop_cutlass_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass") +develop_cutlass_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass") if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path: os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include" diff --git a/testing/python/tilelang/test_tilelang_gemm.py b/testing/python/tilelang/test_tilelang_gemm.py new file mode 100644 index 000000000..50e39ed23 --- /dev/null +++ b/testing/python/tilelang/test_tilelang_gemm.py @@ -0,0 +1,177 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas import tvm as tvm +import bitblas.testing +from tvm import tl + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + dtypeAB, + dtypeC, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + import tvm.tl.language as T + + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB), C: T.Buffer((M, N), dtypeC) + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm( + M, + N, + K, + trans_A, + trans_B, + dtypeAB, + dtypeC, + dtypeAccum, + block_M, + block_N, + block_K, + num_stages=3, + num_threads=128, +): + program = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + dtypeAB, + dtypeC, + dtypeAccum, + num_stages, + num_threads, + ) + mod, params = tl.lower(program) + mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer) + + def ref_program(A, B): + import torch + + if trans_A: + A = A.T + if trans_B: + B = B.T + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(dtypeC)) + return C + + mod.assert_allclose(ref_program) + + +def test_gemm_f16f16f16_nn(): + run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def test_gemm_f16f16f32_nn(): + run_gemm(512, 1024, 768, False, False, "float16", "float16", "float32", 128, 128, 32) + + +def test_gemm_bf16bf16f32_nn(): + run_gemm(512, 1024, 768, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32) + + +def test_gemm_f32f32f32_nn(): + run_gemm(512, 1024, 768, False, False, "float32", "float32", "float32", 64, 128, 32) + + +def test_gemm_f64f64f64_nn(): + run_gemm(512, 1024, 768, False, False, "float64", "float64", "float64", 64, 64, 16) + + +def test_gemm_i8i8i32_nn(): + run_gemm(512, 1024, 768, False, False, "int8", "int8", "int32", 128, 128, 64) + + +def test_gemm_f16f16f16_tn(): + run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2) + + +def test_gemm_f16f16f16_nt(): + run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2) + + +def test_gemm_i8i8i32_nt(): + run_gemm(512, 1024, 768, False, True, "int8", "int8", "int32", 128, 128, 64) + + +def test_gemm_i8i8i32_tn(): + run_gemm(512, 1024, 768, True, False, "int8", "int8", "int32", 128, 128, 64) + + +def test_gemm_f64f64f64_nt(): + run_gemm(512, 1024, 768, False, True, "float64", "float64", "float64", 64, 32, 16) + + +def test_gemm_f64f64f64_tn(): + run_gemm(512, 1024, 768, True, False, "float64", "float64", "float64", 64, 32, 16) + + +def test_gemm_f32f32f32_nt(): + run_gemm(512, 1024, 768, False, True, "float32", "float32", "float32", 64, 128, 32) + + +def test_gemm_f32f32f32_tn(): + run_gemm(512, 1024, 768, True, False, "float32", "float32", "float32", 64, 128, 32) + + +def test_pad_aligned_f16f16f16_nn(): + run_gemm( + 512 - 8, 1024 - 32, 768 - 24, False, False, "float16", "float16", "float16", 128, 256, 32, 2 + ) + + +def test_pad_f16f16f16_nn(): + run_gemm( + 512 - 9, 1024 - 7, 768 - 5, False, False, "float16", "float16", "float16", 128, 256, 32, 2 + ) + + +def test_pad_f16f16f32_nn(): + run_gemm( + 512 + 19, 1024 + 17, 768 + 15, False, False, "float16", "float16", "float32", 128, 64, 32 + ) + + +if __name__ == "__main__": + bitblas.testing.main() From 8e08e77e35123327033a866117b473e10777c429 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 14:43:35 +0000 Subject: [PATCH 58/88] format fix --- testing/python/tilelang/test_tilelang_gemm.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/testing/python/tilelang/test_tilelang_gemm.py b/testing/python/tilelang/test_tilelang_gemm.py index 50e39ed23..c75e4ccc1 100644 --- a/testing/python/tilelang/test_tilelang_gemm.py +++ b/testing/python/tilelang/test_tilelang_gemm.py @@ -29,9 +29,8 @@ def matmul( import tvm.tl.language as T @T.prim_func - def main( - A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB), C: T.Buffer((M, N), dtypeC) - ): + def main(A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB), C: T.Buffer((M, N), + dtypeC)): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, dtypeAB) B_shared = T.alloc_shared(B_shared_shape, dtypeAB) @@ -156,21 +155,18 @@ def test_gemm_f32f32f32_tn(): def test_pad_aligned_f16f16f16_nn(): - run_gemm( - 512 - 8, 1024 - 32, 768 - 24, False, False, "float16", "float16", "float16", 128, 256, 32, 2 - ) + run_gemm(512 - 8, 1024 - 32, 768 - 24, False, False, "float16", "float16", "float16", 128, 256, + 32, 2) def test_pad_f16f16f16_nn(): - run_gemm( - 512 - 9, 1024 - 7, 768 - 5, False, False, "float16", "float16", "float16", 128, 256, 32, 2 - ) + run_gemm(512 - 9, 1024 - 7, 768 - 5, False, False, "float16", "float16", "float16", 128, 256, + 32, 2) def test_pad_f16f16f32_nn(): - run_gemm( - 512 + 19, 1024 + 17, 768 + 15, False, False, "float16", "float16", "float32", 128, 64, 32 - ) + run_gemm(512 + 19, 1024 + 17, 768 + 15, False, False, "float16", "float16", "float32", 128, 64, + 32) if __name__ == "__main__": From df05a642108841e959927e505b268564452cd441 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 15:53:48 +0000 Subject: [PATCH 59/88] Copy CUTLASS to the package directory --- setup.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/setup.py b/setup.py index 6954fa798..9950ac706 100644 --- a/setup.py +++ b/setup.py @@ -252,6 +252,22 @@ def run(self): os.makedirs(target_dir) shutil.copy2(source_dir, target_dir) + # Copy CUTLASS to the package directory + CUTLASS_PREBUILD_ITEMS = [ + "3rdparty/cutlass", + ] + for item in CUTLASS_PREBUILD_ITEMS: + source_dir = os.path.join(ROOT_DIR, item) + target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item) + if os.path.isdir(source_dir): + self.mkpath(target_dir) + distutils.dir_util.copy_tree(source_dir, target_dir) + else: + target_dir = os.path.dirname(target_dir) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + shutil.copy2(source_dir, target_dir) + class BitBLASSdistCommand(sdist): """Customized setuptools sdist command - includes the pyproject.toml file.""" From 4f529c5f4bff94b8183c78b02b6557a9e7c7f8ab Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 22 Jul 2024 16:45:30 +0000 Subject: [PATCH 60/88] Refactor setup.py to include additional TVM header files --- setup.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/setup.py b/setup.py index 9950ac706..f6f3fb171 100644 --- a/setup.py +++ b/setup.py @@ -239,6 +239,16 @@ def run(self): "3rdparty/tvm/mypy.ini", "3rdparty/tvm/pyproject.toml", "3rdparty/tvm/version.py", + "3rdparty/tvm/common.h", + "3rdparty/tvm/copy.h", + "3rdparty/tvm/copy_sm90.h", + "3rdparty/tvm/gemm.h", + "3rdparty/tvm/gemm_sm70.h", + "3rdparty/tvm/gemm_sm80.h", + "3rdparty/tvm/gemm_sm90.h", + "3rdparty/tvm/ldsm.h", + "3rdparty/tvm/reduce.h", + "3rdparty/tvm/threadblock_swizzle.h" ] for item in TVM_PREBUILD_ITEMS: source_dir = os.path.join(ROOT_DIR, item) From d02bbc7458db2f0e633065468eaa1352b2e909ea Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 02:59:46 +0000 Subject: [PATCH 61/88] lint fix --- setup.py | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/setup.py b/setup.py index f6f3fb171..46c95e454 100644 --- a/setup.py +++ b/setup.py @@ -226,29 +226,14 @@ def run(self): # Copy the built TVM to the package directory TVM_PREBUILD_ITEMS = [ - "3rdparty/tvm/build/libtvm_runtime.so", - "3rdparty/tvm/build/libtvm.so", - "3rdparty/tvm/build/config.cmake", - "3rdparty/tvm/python", - "3rdparty/tvm/licenses", - "3rdparty/tvm/conftest.py", - "3rdparty/tvm/CONTRIBUTORS.md", - "3rdparty/tvm/KEYS", - "3rdparty/tvm/LICENSE", - "3rdparty/tvm/README.md", - "3rdparty/tvm/mypy.ini", - "3rdparty/tvm/pyproject.toml", - "3rdparty/tvm/version.py", - "3rdparty/tvm/common.h", - "3rdparty/tvm/copy.h", - "3rdparty/tvm/copy_sm90.h", - "3rdparty/tvm/gemm.h", - "3rdparty/tvm/gemm_sm70.h", - "3rdparty/tvm/gemm_sm80.h", - "3rdparty/tvm/gemm_sm90.h", - "3rdparty/tvm/ldsm.h", - "3rdparty/tvm/reduce.h", - "3rdparty/tvm/threadblock_swizzle.h" + "3rdparty/tvm/build/libtvm_runtime.so", "3rdparty/tvm/build/libtvm.so", + "3rdparty/tvm/build/config.cmake", "3rdparty/tvm/python", "3rdparty/tvm/licenses", + "3rdparty/tvm/conftest.py", "3rdparty/tvm/CONTRIBUTORS.md", "3rdparty/tvm/KEYS", + "3rdparty/tvm/LICENSE", "3rdparty/tvm/README.md", "3rdparty/tvm/mypy.ini", + "3rdparty/tvm/pyproject.toml", "3rdparty/tvm/version.py", "3rdparty/tvm/common.h", + "3rdparty/tvm/copy.h", "3rdparty/tvm/copy_sm90.h", "3rdparty/tvm/gemm.h", + "3rdparty/tvm/gemm_sm70.h", "3rdparty/tvm/gemm_sm80.h", "3rdparty/tvm/gemm_sm90.h", + "3rdparty/tvm/ldsm.h", "3rdparty/tvm/reduce.h", "3rdparty/tvm/threadblock_swizzle.h" ] for item in TVM_PREBUILD_ITEMS: source_dir = os.path.join(ROOT_DIR, item) From cffe3fde336f78ed6b3b6560fc9d5a3c33c1f86d Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 03:47:23 +0000 Subject: [PATCH 62/88] bug fix --- setup.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/setup.py b/setup.py index 46c95e454..5fe71db40 100644 --- a/setup.py +++ b/setup.py @@ -226,14 +226,20 @@ def run(self): # Copy the built TVM to the package directory TVM_PREBUILD_ITEMS = [ - "3rdparty/tvm/build/libtvm_runtime.so", "3rdparty/tvm/build/libtvm.so", - "3rdparty/tvm/build/config.cmake", "3rdparty/tvm/python", "3rdparty/tvm/licenses", - "3rdparty/tvm/conftest.py", "3rdparty/tvm/CONTRIBUTORS.md", "3rdparty/tvm/KEYS", - "3rdparty/tvm/LICENSE", "3rdparty/tvm/README.md", "3rdparty/tvm/mypy.ini", - "3rdparty/tvm/pyproject.toml", "3rdparty/tvm/version.py", "3rdparty/tvm/common.h", - "3rdparty/tvm/copy.h", "3rdparty/tvm/copy_sm90.h", "3rdparty/tvm/gemm.h", - "3rdparty/tvm/gemm_sm70.h", "3rdparty/tvm/gemm_sm80.h", "3rdparty/tvm/gemm_sm90.h", - "3rdparty/tvm/ldsm.h", "3rdparty/tvm/reduce.h", "3rdparty/tvm/threadblock_swizzle.h" + "3rdparty/tvm/build/libtvm_runtime.so", + "3rdparty/tvm/build/libtvm.so", + "3rdparty/tvm/build/config.cmake", + "3rdparty/tvm/python", + "3rdparty/tvm/licenses", + "3rdparty/tvm/conftest.py", + "3rdparty/tvm/CONTRIBUTORS.md", + "3rdparty/tvm/KEYS", + "3rdparty/tvm/LICENSE", + "3rdparty/tvm/README.md", + "3rdparty/tvm/mypy.ini", + "3rdparty/tvm/pyproject.toml", + "3rdparty/tvm/version.py", + "3rdparty/tvm/src/tl/tl_templates", ] for item in TVM_PREBUILD_ITEMS: source_dir = os.path.join(ROOT_DIR, item) From a8bed748d8a2a86b51a2c7838ce0a5827ee33122 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 04:32:21 +0000 Subject: [PATCH 63/88] Refactor BitBLASLinear test module for improved readability and maintainability --- benchmark/operators/base.py | 67 +++++++++++++++++++++ benchmark/operators/benchmark_ops_matmul.py | 45 ++++++++++++++ bitblas/ops/__init__.py | 2 +- 3 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 benchmark/operators/base.py create mode 100644 benchmark/operators/benchmark_ops_matmul.py diff --git a/benchmark/operators/base.py b/benchmark/operators/base.py new file mode 100644 index 000000000..2bd32feb1 --- /dev/null +++ b/benchmark/operators/base.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple, Optional +from bitblas.ops import Operator, OperatorConfig +from bitblas import auto_detect_nvidia_target + +class BitblasOperatorBenchmarkBase(ABC): + + # separate benchmark sets for different operators set + benchmark_sets: Dict[str, List[Tuple[Operator, OperatorConfig]]] = {} + + # currently we only support nvidia target for benchmarking + benchmark_target: str = auto_detect_nvidia_target() + + # benchmark results + benchmark_results: Dict[str, List[Optional[float]]] = {} + + @abstractmethod + def prepare_benchmark_sets(self): + pass + + def add_benchmark_set(self, name:str, benchmark_set:List[Tuple[Operator, OperatorConfig]]): + if name in self.benchmark_sets: + self.benchmark_sets[name].extend(benchmark_set) + else: + self.benchmark_sets[name] = benchmark_set + + def run(self): + self.prepare_benchmark_sets() + self.benchmark() + print("Benchmark results:", self.benchmark_results) + self.report() + self.cleanup() + + def report(self): + return NotImplementedError + + def cleanup(self): + # clean up the benchmark sets + self.benchmark_sets.clear() + + def benchmark(self): + for name, benchmark_set in self.benchmark_sets.items(): + self.benchmark_results[name] = [] + for operator, config in benchmark_set: + self.benchmark_results[name].append(self.run_benchmark(operator, config)) + + def run_benchmark(self, operator:Operator, config:OperatorConfig) -> Optional[float]: + op_inst = operator(config, target=self.benchmark_target) + return op_inst.profile_latency() + + @abstractmethod + def get_operator(self) -> Operator: + raise NotImplementedError + + @abstractmethod + def get_operator_config(self) -> OperatorConfig: + raise NotImplementedError + + def get_benchmark_sets(self, name:Optional[str]=None) -> List[Tuple[Operator, OperatorConfig]]: + if name is None: + return self.benchmark_sets + else: + assert name in self.benchmark_sets, f"Operator {name} not found in benchmark sets" + return self.benchmark_sets[name] diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py new file mode 100644 index 000000000..50010c3f6 --- /dev/null +++ b/benchmark/operators/benchmark_ops_matmul.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from base import BitblasOperatorBenchmarkBase +from bitblas.ops import Matmul, MatmulConfig +from bitblas import set_log_level + +set_log_level("DEBUG") + +class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase): + + config_map = { + "FP16xFP16_ACCFP16_NT": { + "in_dtype": "float16", + "out_dtype": "float16", + "accum_dtype": "float16", + } + } + + def prepare_benchmark_sets(self): + self.add_benchmark_set( + "FP16xFP16_ACCFP16_NT", + [ + (Matmul, self.generate_operator_config("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384)), + ], + ) + + def generate_operator_config(self, name:str, M, N, K) -> MatmulConfig: + if name not in self.config_map: + raise ValueError(f"Operator {name} not found in config map") + return MatmulConfig( + M=M, + N=N, + K=K, + **self.config_map[name], + ) + + def get_operator(self): + return Matmul + + def get_operator_config(self): + return MatmulConfig + +if __name__ == "__main__": + BitblasMatmulOpsBenchmark().run() diff --git a/bitblas/ops/__init__.py b/bitblas/ops/__init__.py index cdacc5bad..a8704141b 100644 --- a/bitblas/ops/__init__.py +++ b/bitblas/ops/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .operator import Operator # noqa: F401 +from .operator import Operator, OperatorConfig # noqa: F401 from .matmul import Matmul, MatmulConfig # noqa: F401 from .matmul_dequantize import MatmulWeightOnlyDequantize, MatmulWeightOnlyDequantizeConfig # noqa: F401 from .ladder_permutate import LadderPermutate, LadderPermutateConfig # noqa: F401 From d4eb5fd678865e07290bab657c995a567a63d17d Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 08:39:58 +0000 Subject: [PATCH 64/88] Implement Matmul Benchmark Design --- benchmark/operators/base.py | 67 ---------- benchmark/operators/benchmark_ops_matmul.py | 107 +++++++++++++++- bitblas/benchmark/__init__.py | 4 + bitblas/benchmark/operator/__init__.py | 134 ++++++++++++++++++++ bitblas/cache/operator.py | 3 +- bitblas/ops/operator.py | 34 +++-- bitblas/utils/__init__.py | 18 +++ 7 files changed, 286 insertions(+), 81 deletions(-) delete mode 100644 benchmark/operators/base.py create mode 100644 bitblas/benchmark/__init__.py create mode 100644 bitblas/benchmark/operator/__init__.py diff --git a/benchmark/operators/base.py b/benchmark/operators/base.py deleted file mode 100644 index 2bd32feb1..000000000 --- a/benchmark/operators/base.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from abc import ABC, abstractmethod -from typing import Dict, List, Tuple, Optional -from bitblas.ops import Operator, OperatorConfig -from bitblas import auto_detect_nvidia_target - -class BitblasOperatorBenchmarkBase(ABC): - - # separate benchmark sets for different operators set - benchmark_sets: Dict[str, List[Tuple[Operator, OperatorConfig]]] = {} - - # currently we only support nvidia target for benchmarking - benchmark_target: str = auto_detect_nvidia_target() - - # benchmark results - benchmark_results: Dict[str, List[Optional[float]]] = {} - - @abstractmethod - def prepare_benchmark_sets(self): - pass - - def add_benchmark_set(self, name:str, benchmark_set:List[Tuple[Operator, OperatorConfig]]): - if name in self.benchmark_sets: - self.benchmark_sets[name].extend(benchmark_set) - else: - self.benchmark_sets[name] = benchmark_set - - def run(self): - self.prepare_benchmark_sets() - self.benchmark() - print("Benchmark results:", self.benchmark_results) - self.report() - self.cleanup() - - def report(self): - return NotImplementedError - - def cleanup(self): - # clean up the benchmark sets - self.benchmark_sets.clear() - - def benchmark(self): - for name, benchmark_set in self.benchmark_sets.items(): - self.benchmark_results[name] = [] - for operator, config in benchmark_set: - self.benchmark_results[name].append(self.run_benchmark(operator, config)) - - def run_benchmark(self, operator:Operator, config:OperatorConfig) -> Optional[float]: - op_inst = operator(config, target=self.benchmark_target) - return op_inst.profile_latency() - - @abstractmethod - def get_operator(self) -> Operator: - raise NotImplementedError - - @abstractmethod - def get_operator_config(self) -> OperatorConfig: - raise NotImplementedError - - def get_benchmark_sets(self, name:Optional[str]=None) -> List[Tuple[Operator, OperatorConfig]]: - if name is None: - return self.benchmark_sets - else: - assert name in self.benchmark_sets, f"Operator {name} not found in benchmark sets" - return self.benchmark_sets[name] diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py index 50010c3f6..8a2092d21 100644 --- a/benchmark/operators/benchmark_ops_matmul.py +++ b/benchmark/operators/benchmark_ops_matmul.py @@ -1,14 +1,24 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from base import BitblasOperatorBenchmarkBase +from bitblas.benchmark import BitblasOperatorBenchmarkBase from bitblas.ops import Matmul, MatmulConfig +from bitblas.utils import get_commit_id from bitblas import set_log_level +from tabulate import tabulate +import json +from os import path, makedirs +from typing import Tuple set_log_level("DEBUG") + class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase): + BENCHMARK_RESULTS_FILE = "benchmark_results.json" + BENCHMARK_SHAPES_FILE = "benchmark_shapes.json" + BENCHMARK_DEVICE_FILE = "benchmark_device.json" + config_map = { "FP16xFP16_ACCFP16_NT": { "in_dtype": "float16", @@ -17,15 +27,23 @@ class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase): } } + CURRENT_COMMIT_ID = get_commit_id() + def prepare_benchmark_sets(self): + """Prepare benchmark sets.""" + self.disable_tuning() self.add_benchmark_set( "FP16xFP16_ACCFP16_NT", [ - (Matmul, self.generate_operator_config("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384)), + ( + Matmul, + self.generate_operator_config("FP16xFP16_ACCFP16_NT", [1, 1024], 16384, 16384), + ), ], ) - def generate_operator_config(self, name:str, M, N, K) -> MatmulConfig: + def generate_operator_config(self, name: str, M, N, K) -> MatmulConfig: + """Generate configuration for the given operator.""" if name not in self.config_map: raise ValueError(f"Operator {name} not found in config map") return MatmulConfig( @@ -35,11 +53,94 @@ def generate_operator_config(self, name:str, M, N, K) -> MatmulConfig: **self.config_map[name], ) + def serialize_results(self) -> None: + """Serialize benchmark results into JSON files.""" + commit_id_path = f"CommitID_{self.CURRENT_COMMIT_ID}" + log_commit_path = path.join(self.log_path, commit_id_path) + + if not path.exists(log_commit_path): + makedirs(log_commit_path) + + # Save benchmark results into JSON + self._save_json( + self.benchmark_results, + path.join(log_commit_path, self.BENCHMARK_RESULTS_FILE), + ) + + # Save benchmark shapes into JSON + shapes = [(config.M, config.N, config.K) + for name, results in self.benchmark_results.items() for i, _ in enumerate(results) + for config in [self.benchmark_sets[name][i][1]]] + self._save_json(shapes, path.join(log_commit_path, self.BENCHMARK_SHAPES_FILE)) + + # Save device info into JSON + self._save_json( + {"device": self.benchmark_target}, + path.join(log_commit_path, self.BENCHMARK_DEVICE_FILE), + ) + + def _save_json(self, data, file_path): + """Helper function to save JSON data to a file.""" + with open(file_path, "w") as f: + json.dump(data, f) + + def deserialize_results(self, log_path: str) -> None: + """Deserialize benchmark results from JSON files.""" + self.benchmark_results = self._load_json(path.join(log_path, self.BENCHMARK_RESULTS_FILE)) + + shapes_file = path.join(log_path, self.BENCHMARK_SHAPES_FILE) + with open(shapes_file, "r") as f: + shapes = json.load(f) + # TODO: Reconstruction of benchmark_sets from shapes + del shapes + + self.benchmark_target = self._load_json(path.join(log_path, + self.BENCHMARK_DEVICE_FILE))["device"] + + def _load_json(self, file_path): + """Helper function to load JSON data from a file.""" + with open(file_path, "r") as f: + return json.load(f) + + def report(self): + """Generate and print a report of the benchmark results.""" + for name, results in self.benchmark_results.items(): + table_data = [ + ["TAG:", name, "Device:", self.benchmark_target], + [ + "Shape (M-N-K)", + "Time (ms)", + "Throughput (TFLOPS)", + "Tune Time (s)", + ], + ] + + for i, (latency, tuning_time) in enumerate(results): + op_config = self.benchmark_sets[name][i][1] + shape = f"{op_config.M}-{op_config.N}-{op_config.K}" + + benchmark_M = ( + sum(op_config.M) / + len(op_config.M) if isinstance(op_config.M, Tuple) else op_config.M) + + throughput = ( + f"{(2 * benchmark_M * op_config.N * op_config.K / (latency * 1e-3) / 1e12):.3f}" + if latency else "N/A") + latency_str = "N/A" if latency is None else f"{latency:.3f}" + tuning_time_str = ("N/A" if tuning_time is None else f"{tuning_time:.3f}") + + table_data.append([shape, latency_str, throughput, tuning_time_str]) + + print(tabulate(table_data, headers="firstrow", tablefmt="fancy_grid")) + def get_operator(self): + """Return the Matmul operator.""" return Matmul def get_operator_config(self): + """Return the Matmul operator configuration.""" return MatmulConfig + if __name__ == "__main__": BitblasMatmulOpsBenchmark().run() diff --git a/bitblas/benchmark/__init__.py b/bitblas/benchmark/__init__.py new file mode 100644 index 000000000..cbd6c1c3e --- /dev/null +++ b/bitblas/benchmark/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .operator import BitblasOperatorBenchmarkBase # noqa: F401 diff --git a/bitblas/benchmark/operator/__init__.py b/bitblas/benchmark/operator/__init__.py new file mode 100644 index 000000000..0e2ad0b4a --- /dev/null +++ b/bitblas/benchmark/operator/__init__.py @@ -0,0 +1,134 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from os import path, makedirs +from time import perf_counter +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple, Optional +from bitblas.ops import Operator, OperatorConfig +from bitblas.utils import get_default_cache_path +from bitblas import auto_detect_nvidia_target +from bitblas import tvm as tvm + +class BitblasOperatorBenchmarkBase(ABC): + # Separate benchmark sets for different operators + benchmark_sets: Dict[str, List[Tuple[Operator, OperatorConfig]]] = {} + + # Currently we only support NVIDIA target for benchmarking + benchmark_target: str = auto_detect_nvidia_target() + + # Benchmark results: a list of tuples, each containing latency and tuning time + benchmark_results: Dict[str, List[Tuple[Optional[float], Optional[float]]]] = {} + + # Enable hardware-aware tuning + enable_hardware_aware_tuning: bool = False + + # Log path + log_path: Optional[str] = None + + # Dynamic symbolic constraints + dynamic_symbolic_constraints: Optional[Dict] = None + + @abstractmethod + def prepare_benchmark_sets(self): + pass + + def add_benchmark_set(self, name: str, benchmark_set: List[Tuple[Operator, OperatorConfig]]): + """Add a benchmark set to the collection.""" + if name in self.benchmark_sets: + self.benchmark_sets[name].extend(benchmark_set) + else: + self.benchmark_sets[name] = benchmark_set + + def run(self, report=True, serialize=True, enable_tuning: bool = False): + """Run the benchmark process.""" + self.log_path = path.join(get_default_cache_path(), "benchmark") + + if not path.exists(self.log_path): + makedirs(self.log_path) + + if enable_tuning: + self.enable_tuning() + + self.prepare_benchmark_sets() + self.benchmark() + + if report: + self.report() + + if serialize: + self.serialize_results() + + self.cleanup() + + @abstractmethod + def report(self): + """Generate a report of the benchmark results.""" + raise NotImplementedError + + def cleanup(self): + """Clean up the benchmark sets.""" + self.benchmark_sets.clear() + + def benchmark(self): + """Run benchmarks on all benchmark sets.""" + for name, benchmark_set in self.benchmark_sets.items(): + self.benchmark_results[name] = [self.run_benchmark(op, config) for op, config in benchmark_set] + + def run_benchmark(self, operator: Operator, config: OperatorConfig) -> Optional[float]: + """Run a single benchmark.""" + op_inst = operator(config, target=self.benchmark_target) + tuning_time = None + + if self.enable_hardware_aware_tuning: + start = perf_counter() + op_inst.hardware_aware_finetune(topk=20, parallel_build=True) + tuning_time = perf_counter() - start + + latency = op_inst.profile_latency(dynamic_symbolic_constraints=self.dynamic_symbolic_constraints) + + return latency, tuning_time + + @abstractmethod + def get_operator(self) -> Operator: + """Get the operator to be benchmarked.""" + raise NotImplementedError + + @abstractmethod + def get_operator_config(self) -> OperatorConfig: + """Get the configuration for the operator.""" + raise NotImplementedError + + def get_benchmark_sets(self, name: Optional[str] = None) -> List[Tuple[Operator, OperatorConfig]]: + """Retrieve benchmark sets by name, or all if name is None.""" + if name is None: + return self.benchmark_sets + else: + assert name in self.benchmark_sets, f"Operator {name} not found in benchmark sets" + return self.benchmark_sets[name] + + @abstractmethod + def serialize_results(self) -> None: + """Serialize the benchmark results.""" + pass + + @abstractmethod + def deserialize_results(self) -> None: + """Deserialize the benchmark results.""" + pass + + def enable_tuning(self): + """Enable hardware-aware tuning.""" + self.enable_hardware_aware_tuning = True + + def disable_tuning(self): + """Disable hardware-aware tuning.""" + self.enable_hardware_aware_tuning = False + + def set_log_path(self, log_path: str): + """Set the log path.""" + self.log_path = log_path + + def set_dynamic_symbolic_constraints(self, dynamic_symbolic_constraints: Dict): + """Set dynamic symbolic constraints.""" + self.dynamic_symbolic_constraints = dynamic_symbolic_constraints diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py index 17702c6fc..6c5ea1ebe 100644 --- a/bitblas/cache/operator.py +++ b/bitblas/cache/operator.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import bitblas +from bitblas.utils import get_default_cache_path from bitblas.ops.operator import OperatorConfig, Operator from dataclasses import asdict import os @@ -14,7 +15,7 @@ logger = logging.getLogger(__name__) -BITBLAS_DATABASE_PATH = os.path.expanduser("~/.cache/bitblas") +BITBLAS_DATABASE_PATH = get_default_cache_path() BITBLAS_WRAPPED_SOURCE_NAME = "wrapper_source.cu" BITBLAS_WRAPPED_COMPILED_NAME = "wrapper_compiled.so" diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 9c592f9f2..8617d70b9 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -207,19 +207,33 @@ def hardware_aware_finetune(self, func, target, topk, parallel_build=parallel_build) self._build_runtime_module(self.target) - def get_profile_tensors(self, dynamic_symbolic_constrains: Optional[Dict] = None): - if dynamic_symbolic_constrains is None: - dynamic_symbolic_constrains = {} + def get_profile_tensors(self, dynamic_symbolic_constraints: Optional[Dict] = None): + if dynamic_symbolic_constraints is None: + dynamic_symbolic_constraints = {} func = self.prim_func device = self.arch.device def var_warpper(v): if isinstance(v, tvm.tir.Var): - if v.name in dynamic_symbolic_constrains: - return dynamic_symbolic_constrains[v.name] + if v.name in dynamic_symbolic_constraints: + return dynamic_symbolic_constraints[v.name] assert "opt_shapes" in func.attrs assert v.name in func.attrs["opt_shapes"] - return func.attrs["opt_shapes"][v.name].value + if isinstance(func.attrs["opt_shapes"][v.name], tvm.tir.IntImm): + return func.attrs["opt_shapes"][v.name].value + elif isinstance(func.attrs["opt_shapes"][v.name], tvm.ir.container.Array): + avg_shape: int = 0 + for i in func.attrs["opt_shapes"][v.name]: + avg_shape += i.value + avg_shape = avg_shape // len(func.attrs["opt_shapes"][v.name]) + _info_message = f"Doesn't provide dynamic symbolic constrains for {v.name} when do benchmarking, "\ + f"use average shape {avg_shape}" + logger.info(_info_message) + return avg_shape + else: + raise RuntimeError("Not supported type: ", + type(func.attrs["opt_shapes"][v.name])) + elif isinstance(v, tvm.tir.IntImm): return v.value else: @@ -251,10 +265,10 @@ def map_numpy_type(intype): self.profile_tensors = profile_tensors return profile_tensors - def profile_latency(self, dynamic_symbolic_constrains: Optional[Dict] = None) -> str: - if dynamic_symbolic_constrains is None: - dynamic_symbolic_constrains = {} - profile_tensors = self.get_profile_tensors(dynamic_symbolic_constrains) + def profile_latency(self, dynamic_symbolic_constraints: Optional[Dict] = None) -> str: + if dynamic_symbolic_constraints is None: + dynamic_symbolic_constraints = {} + profile_tensors = self.get_profile_tensors(dynamic_symbolic_constraints) latency = self.time_evaluator(*profile_tensors).mean * 1e3 return latency diff --git a/bitblas/utils/__init__.py b/bitblas/utils/__init__.py index bdf9589f7..d4ded65e6 100644 --- a/bitblas/utils/__init__.py +++ b/bitblas/utils/__init__.py @@ -4,3 +4,21 @@ from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor # noqa: F401 from .target_detector import get_all_nvidia_targets, auto_detect_nvidia_target # noqa: F401 from .rtmod_analysis import get_annotated_device_mod # noqa: F401 + +import os +import subprocess + +BITBLAS_DEFAULT_CACHE_PATH = os.path.expanduser("~/.cache/bitblas") + + +def get_commit_id(): + try: + commit_id = (subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8")) + return commit_id + except subprocess.CalledProcessError as e: + print(f"Error: {e.output}") + return None + + +def get_default_cache_path(): + return BITBLAS_DEFAULT_CACHE_PATH From 4c6c2c15efc7fed39542339d19a063e57196dd4d Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 09:20:02 +0000 Subject: [PATCH 65/88] chore: Update BitBLAS Matmul benchmark script --- benchmark/operators/benchmark_ops_matmul.py | 64 ++++++++++++++--- bitblas/benchmark/operator/__init__.py | 79 ++++++++++++++------- 2 files changed, 107 insertions(+), 36 deletions(-) diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py index 8a2092d21..4bdbc361b 100644 --- a/benchmark/operators/benchmark_ops_matmul.py +++ b/benchmark/operators/benchmark_ops_matmul.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. from bitblas.benchmark import BitblasOperatorBenchmarkBase -from bitblas.ops import Matmul, MatmulConfig +from bitblas import Matmul, MatmulConfig from bitblas.utils import get_commit_id from bitblas import set_log_level from tabulate import tabulate @@ -12,6 +12,12 @@ set_log_level("DEBUG") +HELPER_MESSAGE = """ +**Note**: Bitblas supports dynamic shape tensors as input, resulting in two possible formats for the \ +"Shape (M-N-K / N-K_M)" column in the report. The "M-N-K" format indicates a static shape operator, \ +while the "N-K_M" format denotes a dynamic shape operator where only the M dimension is dynamic. \ +In this context, "_M" represents the specific M shape used for dynamic profiling. +""" class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase): @@ -21,8 +27,19 @@ class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase): config_map = { "FP16xFP16_ACCFP16_NT": { - "in_dtype": "float16", - "out_dtype": "float16", + "A_dtype": "float16", + "W_dtype": "float16", + "accum_dtype": "float16", + }, + "INT8xINT8_ACCINT32_NT": { + "A_dtype": "int8", + "W_dtype": "int8", + "accum_dtype": "int32", + "out_dtype": "int8", + }, + "FP16xINT4_ACCINT32_NT": { + "A_dtype": "float16", + "W_dtype": "int4", "accum_dtype": "float16", } } @@ -31,13 +48,19 @@ class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase): def prepare_benchmark_sets(self): """Prepare benchmark sets.""" - self.disable_tuning() self.add_benchmark_set( "FP16xFP16_ACCFP16_NT", [ - ( - Matmul, - self.generate_operator_config("FP16xFP16_ACCFP16_NT", [1, 1024], 16384, 16384), + self.generate_op_unit( + self.generate_operator_config( + "FP16xFP16_ACCFP16_NT", 16384, 16384, 16384 + ), + ), + self.generate_op_unit( + self.generate_operator_config( + "FP16xFP16_ACCFP16_NT", [1, 1024], 16384, 16384 + ), + dynamic_profiling_shape={"M": 1024}, ), ], ) @@ -46,7 +69,7 @@ def generate_operator_config(self, name: str, M, N, K) -> MatmulConfig: """Generate configuration for the given operator.""" if name not in self.config_map: raise ValueError(f"Operator {name} not found in config map") - return MatmulConfig( + return self.get_operator_config()( M=M, N=N, K=K, @@ -108,16 +131,29 @@ def report(self): table_data = [ ["TAG:", name, "Device:", self.benchmark_target], [ - "Shape (M-N-K)", + "Shape (M-N-K / N-K_M)", "Time (ms)", "Throughput (TFLOPS)", "Tune Time (s)", ], ] + def legalize_shape(M, N, K, dyn_prof_shape): + if isinstance(M, int): + return f"{M}-{N}-{K}" + elif dyn_prof_shape: + return f"{N}-{K}_{dyn_prof_shape['M']}" + else: + assert isinstance(M, Tuple) + opt_m = sum(M) / len(M) + return f"{N}-{K}_{opt_m}" + for i, (latency, tuning_time) in enumerate(results): op_config = self.benchmark_sets[name][i][1] - shape = f"{op_config.M}-{op_config.N}-{op_config.K}" + dyn_prof_shape = self.benchmark_sets[name][i][2] + shape = legalize_shape( + op_config.M, op_config.N, op_config.K, dyn_prof_shape + ) benchmark_M = ( sum(op_config.M) / @@ -132,6 +168,7 @@ def report(self): table_data.append([shape, latency_str, throughput, tuning_time_str]) print(tabulate(table_data, headers="firstrow", tablefmt="fancy_grid")) + print(HELPER_MESSAGE) def get_operator(self): """Return the Matmul operator.""" @@ -141,6 +178,11 @@ def get_operator_config(self): """Return the Matmul operator configuration.""" return MatmulConfig + def make_operator(self, operator: Matmul, config: MatmulConfig) -> Matmul: + """Make an Matmul instance.""" + # Disable default tuning when do benchmark + return operator(config, target=self.benchmark_target, enable_tuning=False) + if __name__ == "__main__": - BitblasMatmulOpsBenchmark().run() + BitblasMatmulOpsBenchmark().run(enable_tuning=True) diff --git a/bitblas/benchmark/operator/__init__.py b/bitblas/benchmark/operator/__init__.py index 0e2ad0b4a..39f7bd90b 100644 --- a/bitblas/benchmark/operator/__init__.py +++ b/bitblas/benchmark/operator/__init__.py @@ -10,15 +10,21 @@ from bitblas import auto_detect_nvidia_target from bitblas import tvm as tvm + class BitblasOperatorBenchmarkBase(ABC): - # Separate benchmark sets for different operators - benchmark_sets: Dict[str, List[Tuple[Operator, OperatorConfig]]] = {} + # Separate benchmark sets for different operators, where the last key represents + # the dynamic profing shape + benchmark_sets: Dict[ + str, List[Tuple[Operator, OperatorConfig, Optional[int]]] + ] = {} # Currently we only support NVIDIA target for benchmarking benchmark_target: str = auto_detect_nvidia_target() # Benchmark results: a list of tuples, each containing latency and tuning time - benchmark_results: Dict[str, List[Tuple[Optional[float], Optional[float]]]] = {} + benchmark_results: Dict[ + str, List[Tuple[Optional[float], Optional[float]]] + ] = {} # Enable hardware-aware tuning enable_hardware_aware_tuning: bool = False @@ -26,14 +32,25 @@ class BitblasOperatorBenchmarkBase(ABC): # Log path log_path: Optional[str] = None - # Dynamic symbolic constraints - dynamic_symbolic_constraints: Optional[Dict] = None - @abstractmethod def prepare_benchmark_sets(self): pass - def add_benchmark_set(self, name: str, benchmark_set: List[Tuple[Operator, OperatorConfig]]): + def generate_op_unit( + self, + config: OperatorConfig, + dynamic_profiling_shape: Optional[Dict[str, int]] = None, + ) -> Tuple[Operator, OperatorConfig, Optional[Dict[str, int]]]: + """Generate a benchmark element for an operator.""" + return self.get_operator(), config, dynamic_profiling_shape + + def add_benchmark_set( + self, + name: str, + benchmark_set: List[ + Tuple[Operator, OperatorConfig, Optional[Dict[str, int]]] + ], + ): """Add a benchmark set to the collection.""" if name in self.benchmark_sets: self.benchmark_sets[name].extend(benchmark_set) @@ -49,16 +66,16 @@ def run(self, report=True, serialize=True, enable_tuning: bool = False): if enable_tuning: self.enable_tuning() - + self.prepare_benchmark_sets() self.benchmark() - + if report: self.report() - + if serialize: self.serialize_results() - + self.cleanup() @abstractmethod @@ -73,20 +90,32 @@ def cleanup(self): def benchmark(self): """Run benchmarks on all benchmark sets.""" for name, benchmark_set in self.benchmark_sets.items(): - self.benchmark_results[name] = [self.run_benchmark(op, config) for op, config in benchmark_set] - - def run_benchmark(self, operator: Operator, config: OperatorConfig) -> Optional[float]: + self.benchmark_results[name] = [ + self.run_benchmark(op, config, opt) for op, config, opt in benchmark_set + ] + + def make_operator( + self, operator: Operator, config: OperatorConfig + ) -> Operator: + """Make an operator instance.""" + return operator(config, target=self.benchmark_target) + + def run_benchmark( + self, operator: Operator, config: OperatorConfig, dynamic_profiling_shape: Optional[Dict[str, int]]=None, + ) -> Optional[float]: """Run a single benchmark.""" - op_inst = operator(config, target=self.benchmark_target) + op_inst = self.make_operator(operator, config) tuning_time = None - + if self.enable_hardware_aware_tuning: start = perf_counter() op_inst.hardware_aware_finetune(topk=20, parallel_build=True) tuning_time = perf_counter() - start - - latency = op_inst.profile_latency(dynamic_symbolic_constraints=self.dynamic_symbolic_constraints) - + + latency = op_inst.profile_latency( + dynamic_symbolic_constraints=dynamic_profiling_shape + ) + return latency, tuning_time @abstractmethod @@ -99,12 +128,16 @@ def get_operator_config(self) -> OperatorConfig: """Get the configuration for the operator.""" raise NotImplementedError - def get_benchmark_sets(self, name: Optional[str] = None) -> List[Tuple[Operator, OperatorConfig]]: + def get_benchmark_sets( + self, name: Optional[str] = None + ) -> List[Tuple[Operator, OperatorConfig]]: """Retrieve benchmark sets by name, or all if name is None.""" if name is None: return self.benchmark_sets else: - assert name in self.benchmark_sets, f"Operator {name} not found in benchmark sets" + assert ( + name in self.benchmark_sets + ), f"Operator {name} not found in benchmark sets" return self.benchmark_sets[name] @abstractmethod @@ -128,7 +161,3 @@ def disable_tuning(self): def set_log_path(self, log_path: str): """Set the log path.""" self.log_path = log_path - - def set_dynamic_symbolic_constraints(self, dynamic_symbolic_constraints: Dict): - """Set dynamic symbolic constraints.""" - self.dynamic_symbolic_constraints = dynamic_symbolic_constraints From 0acaca10e37a48ebbaf5d07315f126c7f05a7b87 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 09:20:33 +0000 Subject: [PATCH 66/88] lint fix --- benchmark/operators/benchmark_ops_matmul.py | 14 +++------ bitblas/benchmark/__init__.py | 2 +- bitblas/benchmark/operator/__init__.py | 34 ++++++++------------- 3 files changed, 17 insertions(+), 33 deletions(-) diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py index 4bdbc361b..6832ba2d0 100644 --- a/benchmark/operators/benchmark_ops_matmul.py +++ b/benchmark/operators/benchmark_ops_matmul.py @@ -19,6 +19,7 @@ In this context, "_M" represents the specific M shape used for dynamic profiling. """ + class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase): BENCHMARK_RESULTS_FILE = "benchmark_results.json" @@ -52,14 +53,9 @@ def prepare_benchmark_sets(self): "FP16xFP16_ACCFP16_NT", [ self.generate_op_unit( - self.generate_operator_config( - "FP16xFP16_ACCFP16_NT", 16384, 16384, 16384 - ), - ), + self.generate_operator_config("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384),), self.generate_op_unit( - self.generate_operator_config( - "FP16xFP16_ACCFP16_NT", [1, 1024], 16384, 16384 - ), + self.generate_operator_config("FP16xFP16_ACCFP16_NT", [1, 1024], 16384, 16384), dynamic_profiling_shape={"M": 1024}, ), ], @@ -151,9 +147,7 @@ def legalize_shape(M, N, K, dyn_prof_shape): for i, (latency, tuning_time) in enumerate(results): op_config = self.benchmark_sets[name][i][1] dyn_prof_shape = self.benchmark_sets[name][i][2] - shape = legalize_shape( - op_config.M, op_config.N, op_config.K, dyn_prof_shape - ) + shape = legalize_shape(op_config.M, op_config.N, op_config.K, dyn_prof_shape) benchmark_M = ( sum(op_config.M) / diff --git a/bitblas/benchmark/__init__.py b/bitblas/benchmark/__init__.py index cbd6c1c3e..d66d5da2d 100644 --- a/bitblas/benchmark/__init__.py +++ b/bitblas/benchmark/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .operator import BitblasOperatorBenchmarkBase # noqa: F401 +from .operator import BitblasOperatorBenchmarkBase # noqa: F401 diff --git a/bitblas/benchmark/operator/__init__.py b/bitblas/benchmark/operator/__init__.py index 39f7bd90b..113aab5c5 100644 --- a/bitblas/benchmark/operator/__init__.py +++ b/bitblas/benchmark/operator/__init__.py @@ -14,17 +14,13 @@ class BitblasOperatorBenchmarkBase(ABC): # Separate benchmark sets for different operators, where the last key represents # the dynamic profing shape - benchmark_sets: Dict[ - str, List[Tuple[Operator, OperatorConfig, Optional[int]]] - ] = {} + benchmark_sets: Dict[str, List[Tuple[Operator, OperatorConfig, Optional[int]]]] = {} # Currently we only support NVIDIA target for benchmarking benchmark_target: str = auto_detect_nvidia_target() # Benchmark results: a list of tuples, each containing latency and tuning time - benchmark_results: Dict[ - str, List[Tuple[Optional[float], Optional[float]]] - ] = {} + benchmark_results: Dict[str, List[Tuple[Optional[float], Optional[float]]]] = {} # Enable hardware-aware tuning enable_hardware_aware_tuning: bool = False @@ -47,9 +43,7 @@ def generate_op_unit( def add_benchmark_set( self, name: str, - benchmark_set: List[ - Tuple[Operator, OperatorConfig, Optional[Dict[str, int]]] - ], + benchmark_set: List[Tuple[Operator, OperatorConfig, Optional[Dict[str, int]]]], ): """Add a benchmark set to the collection.""" if name in self.benchmark_sets: @@ -94,14 +88,15 @@ def benchmark(self): self.run_benchmark(op, config, opt) for op, config, opt in benchmark_set ] - def make_operator( - self, operator: Operator, config: OperatorConfig - ) -> Operator: + def make_operator(self, operator: Operator, config: OperatorConfig) -> Operator: """Make an operator instance.""" return operator(config, target=self.benchmark_target) def run_benchmark( - self, operator: Operator, config: OperatorConfig, dynamic_profiling_shape: Optional[Dict[str, int]]=None, + self, + operator: Operator, + config: OperatorConfig, + dynamic_profiling_shape: Optional[Dict[str, int]] = None, ) -> Optional[float]: """Run a single benchmark.""" op_inst = self.make_operator(operator, config) @@ -112,9 +107,7 @@ def run_benchmark( op_inst.hardware_aware_finetune(topk=20, parallel_build=True) tuning_time = perf_counter() - start - latency = op_inst.profile_latency( - dynamic_symbolic_constraints=dynamic_profiling_shape - ) + latency = op_inst.profile_latency(dynamic_symbolic_constraints=dynamic_profiling_shape) return latency, tuning_time @@ -128,16 +121,13 @@ def get_operator_config(self) -> OperatorConfig: """Get the configuration for the operator.""" raise NotImplementedError - def get_benchmark_sets( - self, name: Optional[str] = None - ) -> List[Tuple[Operator, OperatorConfig]]: + def get_benchmark_sets(self, + name: Optional[str] = None) -> List[Tuple[Operator, OperatorConfig]]: """Retrieve benchmark sets by name, or all if name is None.""" if name is None: return self.benchmark_sets else: - assert ( - name in self.benchmark_sets - ), f"Operator {name} not found in benchmark sets" + assert (name in self.benchmark_sets), f"Operator {name} not found in benchmark sets" return self.benchmark_sets[name] @abstractmethod From 54d2227554169073399e87426b3aea6bd995cf72 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 09:23:06 +0000 Subject: [PATCH 67/88] Refactor BitBLASMatmulOpsBenchmark for improved readability and maintainability --- benchmark/operators/benchmark_ops_matmul.py | 78 +++++++++++++++------ 1 file changed, 58 insertions(+), 20 deletions(-) diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py index 6832ba2d0..378101099 100644 --- a/benchmark/operators/benchmark_ops_matmul.py +++ b/benchmark/operators/benchmark_ops_matmul.py @@ -42,7 +42,7 @@ class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase): "A_dtype": "float16", "W_dtype": "int4", "accum_dtype": "float16", - } + }, } CURRENT_COMMIT_ID = get_commit_id() @@ -53,9 +53,14 @@ def prepare_benchmark_sets(self): "FP16xFP16_ACCFP16_NT", [ self.generate_op_unit( - self.generate_operator_config("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384),), + self.generate_operator_config( + "FP16xFP16_ACCFP16_NT", 16384, 16384, 16384 + ), + ), self.generate_op_unit( - self.generate_operator_config("FP16xFP16_ACCFP16_NT", [1, 1024], 16384, 16384), + self.generate_operator_config( + "FP16xFP16_ACCFP16_NT", [1, 1024], 16384, 16384 + ), dynamic_profiling_shape={"M": 1024}, ), ], @@ -87,10 +92,15 @@ def serialize_results(self) -> None: ) # Save benchmark shapes into JSON - shapes = [(config.M, config.N, config.K) - for name, results in self.benchmark_results.items() for i, _ in enumerate(results) - for config in [self.benchmark_sets[name][i][1]]] - self._save_json(shapes, path.join(log_commit_path, self.BENCHMARK_SHAPES_FILE)) + shapes = [ + (config.M, config.N, config.K) + for name, results in self.benchmark_results.items() + for i, _ in enumerate(results) + for config in [self.benchmark_sets[name][i][1]] + ] + self._save_json( + shapes, path.join(log_commit_path, self.BENCHMARK_SHAPES_FILE) + ) # Save device info into JSON self._save_json( @@ -105,7 +115,9 @@ def _save_json(self, data, file_path): def deserialize_results(self, log_path: str) -> None: """Deserialize benchmark results from JSON files.""" - self.benchmark_results = self._load_json(path.join(log_path, self.BENCHMARK_RESULTS_FILE)) + self.benchmark_results = self._load_json( + path.join(log_path, self.BENCHMARK_RESULTS_FILE) + ) shapes_file = path.join(log_path, self.BENCHMARK_SHAPES_FILE) with open(shapes_file, "r") as f: @@ -113,8 +125,9 @@ def deserialize_results(self, log_path: str) -> None: # TODO: Reconstruction of benchmark_sets from shapes del shapes - self.benchmark_target = self._load_json(path.join(log_path, - self.BENCHMARK_DEVICE_FILE))["device"] + self.benchmark_target = self._load_json( + path.join(log_path, self.BENCHMARK_DEVICE_FILE) + )["device"] def _load_json(self, file_path): """Helper function to load JSON data from a file.""" @@ -135,33 +148,56 @@ def report(self): ] def legalize_shape(M, N, K, dyn_prof_shape): + """Generate a string representation of the operator shape. + + Args: + M: The M dimension (can be an int or a tuple). + N: The N dimension (must be an int). + K: The K dimension (must be an int). + dyn_prof_shape: The dynamic profiling shape (dict with 'M' key if M is dynamic). + + Returns: + A string representing the shape in either 'M-N-K' or 'N-K_M' format. + """ if isinstance(M, int): return f"{M}-{N}-{K}" - elif dyn_prof_shape: + elif dyn_prof_shape and "M" in dyn_prof_shape: return f"{N}-{K}_{dyn_prof_shape['M']}" else: - assert isinstance(M, Tuple) + # Calculate the average of tuple M opt_m = sum(M) / len(M) return f"{N}-{K}_{opt_m}" for i, (latency, tuning_time) in enumerate(results): op_config = self.benchmark_sets[name][i][1] dyn_prof_shape = self.benchmark_sets[name][i][2] - shape = legalize_shape(op_config.M, op_config.N, op_config.K, dyn_prof_shape) + shape = legalize_shape( + op_config.M, op_config.N, op_config.K, dyn_prof_shape + ) benchmark_M = ( - sum(op_config.M) / - len(op_config.M) if isinstance(op_config.M, Tuple) else op_config.M) + sum(op_config.M) / len(op_config.M) + if isinstance(op_config.M, Tuple) + else op_config.M + ) throughput = ( f"{(2 * benchmark_M * op_config.N * op_config.K / (latency * 1e-3) / 1e12):.3f}" - if latency else "N/A") + if latency + else "N/A" + ) latency_str = "N/A" if latency is None else f"{latency:.3f}" - tuning_time_str = ("N/A" if tuning_time is None else f"{tuning_time:.3f}") + tuning_time_str = ( + "N/A" if tuning_time is None else f"{tuning_time:.3f}" + ) - table_data.append([shape, latency_str, throughput, tuning_time_str]) + table_data.append( + [shape, latency_str, throughput, tuning_time_str] + ) - print(tabulate(table_data, headers="firstrow", tablefmt="fancy_grid")) + print( + tabulate(table_data, headers="firstrow", tablefmt="fancy_grid") + ) print(HELPER_MESSAGE) def get_operator(self): @@ -175,7 +211,9 @@ def get_operator_config(self): def make_operator(self, operator: Matmul, config: MatmulConfig) -> Matmul: """Make an Matmul instance.""" # Disable default tuning when do benchmark - return operator(config, target=self.benchmark_target, enable_tuning=False) + return operator( + config, target=self.benchmark_target, enable_tuning=False + ) if __name__ == "__main__": From c2edefb8eac4ead5d739d8df2d0f76316f93e083 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 09:38:39 +0000 Subject: [PATCH 68/88] Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run --- .github/workflows/benchmark.yml | 34 +++++++++++++++++++++ benchmark/operators/benchmark_ops_matmul.py | 2 +- 2 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/benchmark.yml diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml new file mode 100644 index 000000000..d6709bd20 --- /dev/null +++ b/.github/workflows/benchmark.yml @@ -0,0 +1,34 @@ +name: Benchmark + +on: + issue_comment: + types: [created] + +jobs: + benchmark_main: + if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/run-benchmark') + runs-on: self-hosted + + steps: + - name: Checkout code + uses: actions/checkout@v2 + with: + ref: main + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Create virtual environment + run: python -m venv bitblas_benchmark + + - name: Activate virtual environment and install dependencies + run: | + source bitblas_ci/bin/activate + python -m pip install --upgrade pip + if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi + + - name: Install project in wheel mode + run: | + source bitblas_ci/bin/activate + python -m pip install . \ No newline at end of file diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py index 378101099..ae41a2c60 100644 --- a/benchmark/operators/benchmark_ops_matmul.py +++ b/benchmark/operators/benchmark_ops_matmul.py @@ -217,4 +217,4 @@ def make_operator(self, operator: Matmul, config: MatmulConfig) -> Matmul: if __name__ == "__main__": - BitblasMatmulOpsBenchmark().run(enable_tuning=True) + BitblasMatmulOpsBenchmark().run(enable_tuning=False) From e0bc723be4e0fca8c20dd70fd4cd3da3ae2c643b Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 09:38:55 +0000 Subject: [PATCH 69/88] lint fix --- benchmark/operators/benchmark_ops_matmul.py | 61 ++++++--------------- 1 file changed, 17 insertions(+), 44 deletions(-) diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py index ae41a2c60..2e1ac362b 100644 --- a/benchmark/operators/benchmark_ops_matmul.py +++ b/benchmark/operators/benchmark_ops_matmul.py @@ -53,14 +53,9 @@ def prepare_benchmark_sets(self): "FP16xFP16_ACCFP16_NT", [ self.generate_op_unit( - self.generate_operator_config( - "FP16xFP16_ACCFP16_NT", 16384, 16384, 16384 - ), - ), + self.generate_operator_config("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384),), self.generate_op_unit( - self.generate_operator_config( - "FP16xFP16_ACCFP16_NT", [1, 1024], 16384, 16384 - ), + self.generate_operator_config("FP16xFP16_ACCFP16_NT", [1, 1024], 16384, 16384), dynamic_profiling_shape={"M": 1024}, ), ], @@ -92,15 +87,10 @@ def serialize_results(self) -> None: ) # Save benchmark shapes into JSON - shapes = [ - (config.M, config.N, config.K) - for name, results in self.benchmark_results.items() - for i, _ in enumerate(results) - for config in [self.benchmark_sets[name][i][1]] - ] - self._save_json( - shapes, path.join(log_commit_path, self.BENCHMARK_SHAPES_FILE) - ) + shapes = [(config.M, config.N, config.K) + for name, results in self.benchmark_results.items() for i, _ in enumerate(results) + for config in [self.benchmark_sets[name][i][1]]] + self._save_json(shapes, path.join(log_commit_path, self.BENCHMARK_SHAPES_FILE)) # Save device info into JSON self._save_json( @@ -115,9 +105,7 @@ def _save_json(self, data, file_path): def deserialize_results(self, log_path: str) -> None: """Deserialize benchmark results from JSON files.""" - self.benchmark_results = self._load_json( - path.join(log_path, self.BENCHMARK_RESULTS_FILE) - ) + self.benchmark_results = self._load_json(path.join(log_path, self.BENCHMARK_RESULTS_FILE)) shapes_file = path.join(log_path, self.BENCHMARK_SHAPES_FILE) with open(shapes_file, "r") as f: @@ -125,9 +113,8 @@ def deserialize_results(self, log_path: str) -> None: # TODO: Reconstruction of benchmark_sets from shapes del shapes - self.benchmark_target = self._load_json( - path.join(log_path, self.BENCHMARK_DEVICE_FILE) - )["device"] + self.benchmark_target = self._load_json(path.join(log_path, + self.BENCHMARK_DEVICE_FILE))["device"] def _load_json(self, file_path): """Helper function to load JSON data from a file.""" @@ -171,33 +158,21 @@ def legalize_shape(M, N, K, dyn_prof_shape): for i, (latency, tuning_time) in enumerate(results): op_config = self.benchmark_sets[name][i][1] dyn_prof_shape = self.benchmark_sets[name][i][2] - shape = legalize_shape( - op_config.M, op_config.N, op_config.K, dyn_prof_shape - ) + shape = legalize_shape(op_config.M, op_config.N, op_config.K, dyn_prof_shape) benchmark_M = ( - sum(op_config.M) / len(op_config.M) - if isinstance(op_config.M, Tuple) - else op_config.M - ) + sum(op_config.M) / + len(op_config.M) if isinstance(op_config.M, Tuple) else op_config.M) throughput = ( f"{(2 * benchmark_M * op_config.N * op_config.K / (latency * 1e-3) / 1e12):.3f}" - if latency - else "N/A" - ) + if latency else "N/A") latency_str = "N/A" if latency is None else f"{latency:.3f}" - tuning_time_str = ( - "N/A" if tuning_time is None else f"{tuning_time:.3f}" - ) + tuning_time_str = ("N/A" if tuning_time is None else f"{tuning_time:.3f}") - table_data.append( - [shape, latency_str, throughput, tuning_time_str] - ) + table_data.append([shape, latency_str, throughput, tuning_time_str]) - print( - tabulate(table_data, headers="firstrow", tablefmt="fancy_grid") - ) + print(tabulate(table_data, headers="firstrow", tablefmt="fancy_grid")) print(HELPER_MESSAGE) def get_operator(self): @@ -211,9 +186,7 @@ def get_operator_config(self): def make_operator(self, operator: Matmul, config: MatmulConfig) -> Matmul: """Make an Matmul instance.""" # Disable default tuning when do benchmark - return operator( - config, target=self.benchmark_target, enable_tuning=False - ) + return operator(config, target=self.benchmark_target, enable_tuning=False) if __name__ == "__main__": From a4e68d178f12526e79c23384a5d13af06cc5478f Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 10:13:02 +0000 Subject: [PATCH 70/88] Benchmark bot test --- .github/workflows/benchmark.yml | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index d6709bd20..da34ec096 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -24,11 +24,16 @@ jobs: - name: Activate virtual environment and install dependencies run: | - source bitblas_ci/bin/activate + source bitblas_benchmark/bin/activate python -m pip install --upgrade pip if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi - name: Install project in wheel mode run: | - source bitblas_ci/bin/activate - python -m pip install . \ No newline at end of file + source bitblas_benchmark/bin/activate + python -m pip install . + + - name: Matmul Benchmark + source bitblas_benchmark/bin/activate + cd benchmark/operators + python ./benchmark_ops_matmul.py From 1c033654d15dc98707edeaabfcd8951b3a800734 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 10:55:09 +0000 Subject: [PATCH 71/88] Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run --- .github/workflows/benchmark.yml | 55 ++++++++++++++++++-- benchmark/operators/benchmark_ops_matmul.py | 56 +++++++++++++++------ bitblas/benchmark/operator/__init__.py | 16 +++--- 3 files changed, 101 insertions(+), 26 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index da34ec096..fb8729ddd 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -5,7 +5,7 @@ on: types: [created] jobs: - benchmark_main: + benchmark: if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/run-benchmark') runs-on: self-hosted @@ -14,11 +14,37 @@ jobs: uses: actions/checkout@v2 with: ref: main + + - name: Get base branch commit ID + id: get_base_commit + run: echo "BASE_COMMIT=$(git rev-parse HEAD)" >> $GITHUB_ENV + - name: Set up Python uses: actions/setup-python@v2 with: python-version: '3.9' + - name: Activate virtual environment and install dependencies + run: | + source bitblas_benchmark/bin/activate + python -m pip install --upgrade pip + if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi + + - name: Install project in wheel mode + run: | + source bitblas_benchmark/bin/activate + python -m pip install . + + - name: Matmul Benchmark + run: | + source bitblas_benchmark/bin/activate + cd benchmark/operators + python ./benchmark_ops_matmul.py + + - name: Get PR branch commit ID + id: get_pr_commit + run: echo "PR_COMMIT=$(git rev-parse HEAD)" >> $GITHUB_ENV + - name: Create virtual environment run: python -m venv bitblas_benchmark @@ -32,8 +58,27 @@ jobs: run: | source bitblas_benchmark/bin/activate python -m pip install . - + - name: Matmul Benchmark - source bitblas_benchmark/bin/activate - cd benchmark/operators - python ./benchmark_ops_matmul.py + run: | + source bitblas_benchmark/bin/activate + cd benchmark/operators + python ./benchmark_ops_matmul.py + + - name: Install GitHub CLI + run: | + sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-key C99B11DEB97541F0 + sudo apt-add-repository https://cli.github.com/packages + sudo apt update + sudo apt install gh + + - name: Authenticate GitHub CLI + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + gh auth login --with-token <<< $GITHUB_TOKEN + + - name: Post benchmark results + run: | + cat benchmark_results.txt + gh pr comment ${{ github.event.issue.number }} --body "$(cat benchmark_results.txt)" \ No newline at end of file diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py index 2e1ac362b..cbd4a142a 100644 --- a/benchmark/operators/benchmark_ops_matmul.py +++ b/benchmark/operators/benchmark_ops_matmul.py @@ -8,7 +8,7 @@ from tabulate import tabulate import json from os import path, makedirs -from typing import Tuple +from typing import Tuple, Dict, List, Union set_log_level("DEBUG") @@ -87,9 +87,16 @@ def serialize_results(self) -> None: ) # Save benchmark shapes into JSON - shapes = [(config.M, config.N, config.K) - for name, results in self.benchmark_results.items() for i, _ in enumerate(results) - for config in [self.benchmark_sets[name][i][1]]] + shapes : Dict[List[List[Union[List, int], int, int]]] = {} + + # Iterate through the benchmark results to extract the shapes + for name, results in self.benchmark_results.items(): + shapes[name] = [] + for i, _ in enumerate(results): + config = self.benchmark_sets[name][i][1] + dyn_prof_shape = self.benchmark_sets[name][i][2] + shapes[name].append([config.M, config.N, config.K, dyn_prof_shape]) + self._save_json(shapes, path.join(log_commit_path, self.BENCHMARK_SHAPES_FILE)) # Save device info into JSON @@ -103,20 +110,41 @@ def _save_json(self, data, file_path): with open(file_path, "w") as f: json.dump(data, f) - def deserialize_results(self, log_path: str) -> None: + @classmethod + def deserialize_from_logs(cls, commit_id: str) -> None: """Deserialize benchmark results from JSON files.""" - self.benchmark_results = self._load_json(path.join(log_path, self.BENCHMARK_RESULTS_FILE)) + benchmark = cls() + commit_id_path = f"CommitID_{commit_id}" + log_commit_path = path.join(benchmark.log_path, commit_id_path) - shapes_file = path.join(log_path, self.BENCHMARK_SHAPES_FILE) - with open(shapes_file, "r") as f: - shapes = json.load(f) - # TODO: Reconstruction of benchmark_sets from shapes - del shapes + benchmark.benchmark_results = cls._load_json(path.join(log_commit_path, cls.BENCHMARK_RESULTS_FILE)) - self.benchmark_target = self._load_json(path.join(log_path, - self.BENCHMARK_DEVICE_FILE))["device"] + shapes_file = path.join(log_commit_path, cls.BENCHMARK_SHAPES_FILE) - def _load_json(self, file_path): + with open(shapes_file, "r") as f: + shapes = json.load(f) + for name, shape_list in shapes.items(): + for shape in shape_list: + M, N, K, dyn_prof_shape = shape + benchmark.add_benchmark_set( + name, + [ + benchmark.generate_op_unit( + benchmark.generate_operator_config( + name, M, N, K + ), + dynamic_profiling_shape=dyn_prof_shape, + ) + ], + ) + + benchmark.benchmark_target = cls._load_json(path.join(log_commit_path, + cls.BENCHMARK_DEVICE_FILE))["device"] + + return benchmark + + @staticmethod + def _load_json(file_path): """Helper function to load JSON data from a file.""" with open(file_path, "r") as f: return json.load(f) diff --git a/bitblas/benchmark/operator/__init__.py b/bitblas/benchmark/operator/__init__.py index 113aab5c5..e6b137431 100644 --- a/bitblas/benchmark/operator/__init__.py +++ b/bitblas/benchmark/operator/__init__.py @@ -26,7 +26,7 @@ class BitblasOperatorBenchmarkBase(ABC): enable_hardware_aware_tuning: bool = False # Log path - log_path: Optional[str] = None + log_path: Optional[str] = path.join(get_default_cache_path(), "benchmark") @abstractmethod def prepare_benchmark_sets(self): @@ -53,7 +53,6 @@ def add_benchmark_set( def run(self, report=True, serialize=True, enable_tuning: bool = False): """Run the benchmark process.""" - self.log_path = path.join(get_default_cache_path(), "benchmark") if not path.exists(self.log_path): makedirs(self.log_path) @@ -135,11 +134,6 @@ def serialize_results(self) -> None: """Serialize the benchmark results.""" pass - @abstractmethod - def deserialize_results(self) -> None: - """Deserialize the benchmark results.""" - pass - def enable_tuning(self): """Enable hardware-aware tuning.""" self.enable_hardware_aware_tuning = True @@ -151,3 +145,11 @@ def disable_tuning(self): def set_log_path(self, log_path: str): """Set the log path.""" self.log_path = log_path + + def set_benchmark_target(self, target: str): + """Set the benchmark target.""" + self.benchmark_target = target + + def set_benchmark_results(self, results: Dict[str, List[Tuple[Optional[float], Optional[float]]]]): + """Set the benchmark results.""" + self.benchmark_results = results From 4f319fc05d5dc075e09565050cb2aafc3eb4dc8e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 11:10:54 +0000 Subject: [PATCH 72/88] Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run --- benchmark/operators/benchmark_ops_matmul.py | 13 ++- benchmark/operators/compare_benchmark.py | 105 ++++++++++++++++++++ bitblas/benchmark/operator/__init__.py | 3 +- 3 files changed, 113 insertions(+), 8 deletions(-) create mode 100644 benchmark/operators/compare_benchmark.py diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py index cbd4a142a..db83be28a 100644 --- a/benchmark/operators/benchmark_ops_matmul.py +++ b/benchmark/operators/benchmark_ops_matmul.py @@ -87,7 +87,7 @@ def serialize_results(self) -> None: ) # Save benchmark shapes into JSON - shapes : Dict[List[List[Union[List, int], int, int]]] = {} + shapes: Dict[List[List[Union[List, int], int, int]]] = {} # Iterate through the benchmark results to extract the shapes for name, results in self.benchmark_results.items(): @@ -117,7 +117,8 @@ def deserialize_from_logs(cls, commit_id: str) -> None: commit_id_path = f"CommitID_{commit_id}" log_commit_path = path.join(benchmark.log_path, commit_id_path) - benchmark.benchmark_results = cls._load_json(path.join(log_commit_path, cls.BENCHMARK_RESULTS_FILE)) + benchmark.benchmark_results = cls._load_json( + path.join(log_commit_path, cls.BENCHMARK_RESULTS_FILE)) shapes_file = path.join(log_commit_path, cls.BENCHMARK_SHAPES_FILE) @@ -130,16 +131,14 @@ def deserialize_from_logs(cls, commit_id: str) -> None: name, [ benchmark.generate_op_unit( - benchmark.generate_operator_config( - name, M, N, K - ), + benchmark.generate_operator_config(name, M, N, K), dynamic_profiling_shape=dyn_prof_shape, ) ], ) - benchmark.benchmark_target = cls._load_json(path.join(log_commit_path, - cls.BENCHMARK_DEVICE_FILE))["device"] + benchmark.benchmark_target = cls._load_json( + path.join(log_commit_path, cls.BENCHMARK_DEVICE_FILE))["device"] return benchmark diff --git a/benchmark/operators/compare_benchmark.py b/benchmark/operators/compare_benchmark.py new file mode 100644 index 000000000..66433a323 --- /dev/null +++ b/benchmark/operators/compare_benchmark.py @@ -0,0 +1,105 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +from benchmark_ops_matmul import BitblasMatmulOpsBenchmark, HELPER_MESSAGE +from tabulate import tabulate +from typing import Tuple + +def compare(base: BitblasMatmulOpsBenchmark, head: BitblasMatmulOpsBenchmark): + """Generate and print a report of the benchmark results.""" + for name, results in head.benchmark_results.items(): + table_data = [ + ["TAG:", name, "Device:", head.benchmark_target], + [ + "Shape (M-N-K / N-K_M)", + "Time (ms)", + "Throughput (TFLOPS)", + "Tune Time (s)", + ], + ] + + def get_suffix(base, head): + symbol = "↑" if head > base else "↓" if head < base else "=" + ratio = f"{((head - base) / base) * 100:.2f}%" if base is not None else "N/A" + return f"{symbol}({ratio})" + + def legalize_shape(M, N, K, dyn_prof_shape): + """Generate a string representation of the operator shape. + + Args: + M: The M dimension (can be an int or a tuple). + N: The N dimension (must be an int). + K: The K dimension (must be an int). + dyn_prof_shape: The dynamic profiling shape (dict with 'M' key if M is dynamic). + + Returns: + A string representing the shape in either 'M-N-K' or 'N-K_M' format. + """ + if isinstance(M, int): + return f"{M}-{N}-{K}" + elif dyn_prof_shape and "M" in dyn_prof_shape: + return f"{N}-{K}_{dyn_prof_shape['M']}" + else: + # Calculate the average of tuple M + opt_m = sum(M) / len(M) + return f"{N}-{K}_{opt_m}" + + for i, (latency, tuning_time) in enumerate(results): + op_config = head.benchmark_sets[name][i][1] + dyn_prof_shape = head.benchmark_sets[name][i][2] + shape = legalize_shape(op_config.M, op_config.N, op_config.K, dyn_prof_shape) + + benchmark_M = ( + sum(op_config.M) / + len(op_config.M) if isinstance(op_config.M, Tuple) else op_config.M) + + base_latency = base.benchmark_results[name][i][0] + if latency is not None: + throughput = (2 * benchmark_M * op_config.N * op_config.K / (latency * 1e-3) / 1e12) + base_throughput = (2 * benchmark_M * op_config.N * op_config.K / (base_latency * 1e-3) / 1e12) + throughput = f"{throughput:.3f}{get_suffix(base_throughput, throughput)}" + else: + throughput = "N/A" + + if base_latency is not None: + latency_str = f"{latency:.3f}{get_suffix(base_latency, latency)}" + else: + latency_str = "N/A" + + base_tuning_time = base.benchmark_results[name][i][1] + if tuning_time is not None: + tuning_time_str = f"{tuning_time:.3f}{get_suffix(base_tuning_time, tuning_time)}" + else: + tuning_time_str = "N/A" + + table_data.append([shape, latency_str, throughput, tuning_time_str]) + + print(tabulate(table_data, headers="firstrow", tablefmt="fancy_grid")) + print(HELPER_MESSAGE) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--base", + default="df7e9aa61e3db411ac3f2fd98a1854a36194ef0c", + type=str, + help="the base commit id", + ) + parser.add_argument( + "--head", + default="1c033654d15dc98707edeaabfcd8951b3a800734", + type=str, + help="the head commit id", + ) + args = parser.parse_args() + + base_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs( + args.base + ) + + head_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs( + args.head + ) + + compare(base_benchmark, head_benchmark) diff --git a/bitblas/benchmark/operator/__init__.py b/bitblas/benchmark/operator/__init__.py index e6b137431..c5e7852e3 100644 --- a/bitblas/benchmark/operator/__init__.py +++ b/bitblas/benchmark/operator/__init__.py @@ -150,6 +150,7 @@ def set_benchmark_target(self, target: str): """Set the benchmark target.""" self.benchmark_target = target - def set_benchmark_results(self, results: Dict[str, List[Tuple[Optional[float], Optional[float]]]]): + def set_benchmark_results(self, results: Dict[str, List[Tuple[Optional[float], + Optional[float]]]]): """Set the benchmark results.""" self.benchmark_results = results From a8833d4577c9c22a5b3225f9e0a4e4671a1a0a6c Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 11:11:11 +0000 Subject: [PATCH 73/88] Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run --- benchmark/operators/compare_benchmark.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/benchmark/operators/compare_benchmark.py b/benchmark/operators/compare_benchmark.py index 66433a323..04e0e3e71 100644 --- a/benchmark/operators/compare_benchmark.py +++ b/benchmark/operators/compare_benchmark.py @@ -6,6 +6,7 @@ from tabulate import tabulate from typing import Tuple + def compare(base: BitblasMatmulOpsBenchmark, head: BitblasMatmulOpsBenchmark): """Generate and print a report of the benchmark results.""" for name, results in head.benchmark_results.items(): @@ -23,7 +24,7 @@ def get_suffix(base, head): symbol = "↑" if head > base else "↓" if head < base else "=" ratio = f"{((head - base) / base) * 100:.2f}%" if base is not None else "N/A" return f"{symbol}({ratio})" - + def legalize_shape(M, N, K, dyn_prof_shape): """Generate a string representation of the operator shape. @@ -57,16 +58,17 @@ def legalize_shape(M, N, K, dyn_prof_shape): base_latency = base.benchmark_results[name][i][0] if latency is not None: throughput = (2 * benchmark_M * op_config.N * op_config.K / (latency * 1e-3) / 1e12) - base_throughput = (2 * benchmark_M * op_config.N * op_config.K / (base_latency * 1e-3) / 1e12) + base_throughput = (2 * benchmark_M * op_config.N * op_config.K / + (base_latency * 1e-3) / 1e12) throughput = f"{throughput:.3f}{get_suffix(base_throughput, throughput)}" else: throughput = "N/A" - + if base_latency is not None: latency_str = f"{latency:.3f}{get_suffix(base_latency, latency)}" else: latency_str = "N/A" - + base_tuning_time = base.benchmark_results[name][i][1] if tuning_time is not None: tuning_time_str = f"{tuning_time:.3f}{get_suffix(base_tuning_time, tuning_time)}" @@ -78,6 +80,7 @@ def legalize_shape(M, N, K, dyn_prof_shape): print(tabulate(table_data, headers="firstrow", tablefmt="fancy_grid")) print(HELPER_MESSAGE) + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( @@ -94,12 +97,8 @@ def legalize_shape(M, N, K, dyn_prof_shape): ) args = parser.parse_args() - base_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs( - args.base - ) + base_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs(args.base) - head_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs( - args.head - ) + head_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs(args.head) compare(base_benchmark, head_benchmark) From 803f6c660a18c7766e856281209ae63230870f8e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 11:11:21 +0000 Subject: [PATCH 74/88] Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run --- benchmark/operators/compare_benchmark.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/benchmark/operators/compare_benchmark.py b/benchmark/operators/compare_benchmark.py index 04e0e3e71..080d49dca 100644 --- a/benchmark/operators/compare_benchmark.py +++ b/benchmark/operators/compare_benchmark.py @@ -85,13 +85,11 @@ def legalize_shape(M, N, K, dyn_prof_shape): parser = argparse.ArgumentParser() parser.add_argument( "--base", - default="df7e9aa61e3db411ac3f2fd98a1854a36194ef0c", type=str, help="the base commit id", ) parser.add_argument( "--head", - default="1c033654d15dc98707edeaabfcd8951b3a800734", type=str, help="the head commit id", ) From df4572b71d796cfe3dab6201337b473b6ad9e4b1 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 11:15:27 +0000 Subject: [PATCH 75/88] Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run --- .github/workflows/benchmark.yml | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index fb8729ddd..5bd16ebb2 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -17,7 +17,7 @@ jobs: - name: Get base branch commit ID id: get_base_commit - run: echo "BASE_COMMIT=$(git rev-parse HEAD)" >> $GITHUB_ENV + run: echo "BASE_COMMIT_ID=$(git rev-parse HEAD)" >> $GITHUB_ENV - name: Set up Python uses: actions/setup-python@v2 @@ -43,7 +43,7 @@ jobs: - name: Get PR branch commit ID id: get_pr_commit - run: echo "PR_COMMIT=$(git rev-parse HEAD)" >> $GITHUB_ENV + run: echo "PR_COMMIT_ID=$(git rev-parse HEAD)" >> $GITHUB_ENV - name: Create virtual environment run: python -m venv bitblas_benchmark @@ -65,6 +65,12 @@ jobs: cd benchmark/operators python ./benchmark_ops_matmul.py + - name: Compare benchmark results + run: | + source bitblas_benchmark/bin/activate + cd benchmark/operators + python ./compare_benchmark.py --base ${{ env.BASE_COMMIT_ID }} --head ${{ env.PR_COMMIT_ID }} 2>&1 | tee compare_results.txt + - name: Install GitHub CLI run: | sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-key C99B11DEB97541F0 @@ -80,5 +86,5 @@ jobs: - name: Post benchmark results run: | - cat benchmark_results.txt - gh pr comment ${{ github.event.issue.number }} --body "$(cat benchmark_results.txt)" \ No newline at end of file + cat compare_results.txt + gh pr comment ${{ github.event.issue.number }} --body "$(cat compare_results.txt)" From 45ded45ee15fabc15e589ef2d2dffb4899231422 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 11:18:41 +0000 Subject: [PATCH 76/88] int8 test case --- benchmark/operators/benchmark_ops_matmul.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py index db83be28a..452ea042c 100644 --- a/benchmark/operators/benchmark_ops_matmul.py +++ b/benchmark/operators/benchmark_ops_matmul.py @@ -61,6 +61,18 @@ def prepare_benchmark_sets(self): ], ) + self.add_benchmark_set( + "INT8xINT8_ACCINT32_NT", + [ + self.generate_op_unit( + self.generate_operator_config("INT8xINT8_ACCINT32_NT", 16384, 16384, 16384),), + self.generate_op_unit( + self.generate_operator_config("INT8xINT8_ACCINT32_NT", [1, 1024], 16384, 16384), + dynamic_profiling_shape={"M": 1024}, + ), + ], + ) + def generate_operator_config(self, name: str, M, N, K) -> MatmulConfig: """Generate configuration for the given operator.""" if name not in self.config_map: From 4229676dc9d5b115d9b5b3211037dc0db1826e60 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 11:21:08 +0000 Subject: [PATCH 77/88] Refactor compare_benchmark.py to handle missing benchmark results gracefully --- benchmark/operators/compare_benchmark.py | 35 ++++++++++++++++++------ 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/benchmark/operators/compare_benchmark.py b/benchmark/operators/compare_benchmark.py index 080d49dca..53ba41e98 100644 --- a/benchmark/operators/compare_benchmark.py +++ b/benchmark/operators/compare_benchmark.py @@ -55,23 +55,42 @@ def legalize_shape(M, N, K, dyn_prof_shape): sum(op_config.M) / len(op_config.M) if isinstance(op_config.M, Tuple) else op_config.M) - base_latency = base.benchmark_results[name][i][0] + try: + base_latency = base.benchmark_results[name][i][0] + except IndexError: + print(f"Operator {name} not found in benchmark sets") + base_latency = None + if latency is not None: throughput = (2 * benchmark_M * op_config.N * op_config.K / (latency * 1e-3) / 1e12) - base_throughput = (2 * benchmark_M * op_config.N * op_config.K / - (base_latency * 1e-3) / 1e12) - throughput = f"{throughput:.3f}{get_suffix(base_throughput, throughput)}" + if base_latency is not None: + base_throughput = (2 * benchmark_M * op_config.N * op_config.K / + (base_latency * 1e-3) / 1e12) + throughput = f"{throughput:.3f}{get_suffix(base_throughput, throughput)}" + else: + throughput = f"{throughput:.3f}" else: throughput = "N/A" - if base_latency is not None: - latency_str = f"{latency:.3f}{get_suffix(base_latency, latency)}" + if latency is not None: + if base_latency is not None: + latency_str = f"{latency:.3f}{get_suffix(base_latency, latency)}" + else: + latency_str = f"{latency:.3f}" else: latency_str = "N/A" - base_tuning_time = base.benchmark_results[name][i][1] + try: + base_tuning_time = base.benchmark_results[name][i][1] + except IndexError: + print(f"Operator {name} not found in benchmark sets") + base_tuning_time = None + if tuning_time is not None: - tuning_time_str = f"{tuning_time:.3f}{get_suffix(base_tuning_time, tuning_time)}" + if base_tuning_time is not None: + tuning_time_str = f"{tuning_time:.3f}{get_suffix(base_tuning_time, tuning_time)}" + else: + tuning_time_str = f"{tuning_time:.3f}" else: tuning_time_str = "N/A" From 476ffee28764030c32d76529b8c7a167f1d35779 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 11:28:27 +0000 Subject: [PATCH 78/88] ci fix --- .github/workflows/benchmark.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 5bd16ebb2..6acc309e2 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -24,6 +24,9 @@ jobs: with: python-version: '3.9' + - name: Create virtual environment + run: python -m venv bitblas_benchmark + - name: Activate virtual environment and install dependencies run: | source bitblas_benchmark/bin/activate @@ -44,9 +47,6 @@ jobs: - name: Get PR branch commit ID id: get_pr_commit run: echo "PR_COMMIT_ID=$(git rev-parse HEAD)" >> $GITHUB_ENV - - - name: Create virtual environment - run: python -m venv bitblas_benchmark - name: Activate virtual environment and install dependencies run: | From 9bd34ffc2971d6dfc8efd9e9f7100fb971a55693 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 11:28:48 +0000 Subject: [PATCH 79/88] disable ci for test benchmark --- .github/workflows/ci.yml | 110 +++++++++++++++++++-------------------- 1 file changed, 55 insertions(+), 55 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9b76866c5..6702f7116 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,70 +1,70 @@ -name: CI +# name: CI -on: [pull_request] +# on: [pull_request] -jobs: - format-check: - runs-on: self-hosted +# jobs: +# format-check: +# runs-on: self-hosted - steps: - - name: Checkout repository - uses: actions/checkout@v2 - with: - fetch-depth: 0 +# steps: +# - name: Checkout repository +# uses: actions/checkout@v2 +# with: +# fetch-depth: 0 - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.9' +# - name: Set up Python +# uses: actions/setup-python@v2 +# with: +# python-version: '3.9' - - name: Create virtual environment - run: python -m venv bitblas_ci +# - name: Create virtual environment +# run: python -m venv bitblas_ci - - name: Activate virtual environment and install dependencies - run: | - source bitblas_ci/bin/activate - python -m pip install --upgrade pip - if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi +# - name: Activate virtual environment and install dependencies +# run: | +# source bitblas_ci/bin/activate +# python -m pip install --upgrade pip +# if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi - - name: Update submodules recursively - run: git submodule update --init --recursive +# - name: Update submodules recursively +# run: git submodule update --init --recursive - - name: Run format check - run: | - source bitblas_ci/bin/activate - ./format.sh +# - name: Run format check +# run: | +# source bitblas_ci/bin/activate +# ./format.sh - build-test: - runs-on: self-hosted - needs: format-check +# build-test: +# runs-on: self-hosted +# needs: format-check - steps: - - name: Checkout repository - uses: actions/checkout@v2 - with: - fetch-depth: 0 +# steps: +# - name: Checkout repository +# uses: actions/checkout@v2 +# with: +# fetch-depth: 0 - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.9' +# - name: Set up Python +# uses: actions/setup-python@v2 +# with: +# python-version: '3.9' - - name: Create virtual environment - run: python -m venv bitblas_ci +# - name: Create virtual environment +# run: python -m venv bitblas_ci - - name: Activate virtual environment and install dependencies - run: | - source bitblas_ci/bin/activate - python -m pip install --upgrade pip - if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi +# - name: Activate virtual environment and install dependencies +# run: | +# source bitblas_ci/bin/activate +# python -m pip install --upgrade pip +# if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi - - name: Install project in wheel mode - run: | - source bitblas_ci/bin/activate - python -m pip install . +# - name: Install project in wheel mode +# run: | +# source bitblas_ci/bin/activate +# python -m pip install . - - name: Run tests - run: | - source bitblas_ci/bin/activate - cd testing/python - python -m pytest +# - name: Run tests +# run: | +# source bitblas_ci/bin/activate +# cd testing/python +# python -m pytest From 75f3dd99aee1ffa351bd66814ea321a3f1e4b230 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 11:49:12 +0000 Subject: [PATCH 80/88] Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run --- .github/workflows/benchmark.yml | 5 +++++ benchmark/operators/compare_benchmark.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 6acc309e2..974715c27 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -44,6 +44,11 @@ jobs: cd benchmark/operators python ./benchmark_ops_matmul.py + - name: Checkout PR branch code + uses: actions/checkout@v2 + with: + ref: ${{ github.event.pull_request.head.ref }} + - name: Get PR branch commit ID id: get_pr_commit run: echo "PR_COMMIT_ID=$(git rev-parse HEAD)" >> $GITHUB_ENV diff --git a/benchmark/operators/compare_benchmark.py b/benchmark/operators/compare_benchmark.py index 53ba41e98..abafb3099 100644 --- a/benchmark/operators/compare_benchmark.py +++ b/benchmark/operators/compare_benchmark.py @@ -113,7 +113,9 @@ def legalize_shape(M, N, K, dyn_prof_shape): help="the head commit id", ) args = parser.parse_args() - + + print(f"Comparing base commit {args.base} with head commit {args.head}") + base_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs(args.base) head_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs(args.head) From 79e04aa5559fe36d453874aa66617d73ee6ed645 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 12:09:23 +0000 Subject: [PATCH 81/88] remove cli installation --- .github/workflows/benchmark.yml | 7 ------- 1 file changed, 7 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 974715c27..32b5920cb 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -76,13 +76,6 @@ jobs: cd benchmark/operators python ./compare_benchmark.py --base ${{ env.BASE_COMMIT_ID }} --head ${{ env.PR_COMMIT_ID }} 2>&1 | tee compare_results.txt - - name: Install GitHub CLI - run: | - sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-key C99B11DEB97541F0 - sudo apt-add-repository https://cli.github.com/packages - sudo apt update - sudo apt install gh - - name: Authenticate GitHub CLI env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} From cdd3345b3872a5720b0375ce7d7472f0260dcf7e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 13:23:05 +0000 Subject: [PATCH 82/88] chore: Create virtual environment and install dependencies for benchmark --- .github/workflows/benchmark.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 32b5920cb..cb50ce702 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -47,12 +47,15 @@ jobs: - name: Checkout PR branch code uses: actions/checkout@v2 with: - ref: ${{ github.event.pull_request.head.ref }} + fetch-depth: 0 - name: Get PR branch commit ID id: get_pr_commit run: echo "PR_COMMIT_ID=$(git rev-parse HEAD)" >> $GITHUB_ENV - + + - name: Create virtual environment + run: python -m venv bitblas_benchmark + - name: Activate virtual environment and install dependencies run: | source bitblas_benchmark/bin/activate From f211ad434c2e7b0af96c9182ec197ba46322fa3b Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 23 Jul 2024 14:43:44 +0000 Subject: [PATCH 83/88] chore: Update benchmark workflow to include comparison step --- .github/workflows/benchmark.yml | 66 ++++++++++++++++++++++++++++++--- 1 file changed, 61 insertions(+), 5 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index cb50ce702..26d54077f 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -5,7 +5,7 @@ on: types: [created] jobs: - benchmark: + benchmark_base: if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/run-benchmark') runs-on: self-hosted @@ -17,7 +17,13 @@ jobs: - name: Get base branch commit ID id: get_base_commit - run: echo "BASE_COMMIT_ID=$(git rev-parse HEAD)" >> $GITHUB_ENV + run: echo "BASE_COMMIT_ID=$(git rev-parse HEAD)" > base_commit_id.txt + + - name: Upload base commit ID + uses: actions/upload-artifact@v3 + with: + name: base-commit-id + path: base_commit_id.txt - name: Set up Python uses: actions/setup-python@v2 @@ -44,14 +50,32 @@ jobs: cd benchmark/operators python ./benchmark_ops_matmul.py + benchmark_head: + if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/run-benchmark') + needs: benchmark_base + runs-on: self-hosted + + steps: - name: Checkout PR branch code uses: actions/checkout@v2 with: + ref: ${{ github.event.pull_request.head.ref }} fetch-depth: 0 - name: Get PR branch commit ID id: get_pr_commit - run: echo "PR_COMMIT_ID=$(git rev-parse HEAD)" >> $GITHUB_ENV + run: echo "PR_COMMIT_ID=$(git rev-parse HEAD)" > pr_commit_id.txt + + - name: Upload PR commit ID + uses: actions/upload-artifact@v3 + with: + name: pr-commit-id + path: pr_commit_id.txt + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' - name: Create virtual environment run: python -m venv bitblas_benchmark @@ -73,17 +97,49 @@ jobs: cd benchmark/operators python ./benchmark_ops_matmul.py + benchmark_compare: + if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/run-benchmark') + needs: [benchmark_base, benchmark_head] + runs-on: self-hosted + + steps: + - name: Download commit IDs + uses: actions/download-artifact@v3 + with: + name: base-commit-id + path: . + + - name: Download PR commit ID + uses: actions/download-artifact@v3 + with: + name: pr-commit-id + path: . + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + + - name: Create virtual environment + run: python -m venv bitblas_benchmark + + - name: Activate virtual environment and install dependencies + run: | + source bitblas_benchmark/bin/activate + python -m pip install --upgrade pip + if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi + - name: Compare benchmark results run: | source bitblas_benchmark/bin/activate cd benchmark/operators - python ./compare_benchmark.py --base ${{ env.BASE_COMMIT_ID }} --head ${{ env.PR_COMMIT_ID }} 2>&1 | tee compare_results.txt + python ./compare_benchmark.py --base $(cat base_commit_id.txt) --head $(cat pr_commit_id.txt) 2>&1 | tee compare_results.txt - name: Authenticate GitHub CLI env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | - gh auth login --with-token <<< $GITHUB_TOKEN + echo "${{ secrets.GITHUB_TOKEN }}" | gh auth login --with-token - name: Post benchmark results run: | From ddde02a5bb062a63f8947bf45f73968e90aa7cd8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 24 Jul 2024 04:51:44 +0000 Subject: [PATCH 84/88] Lint fix --- benchmark/operators/benchmark_ops_matmul.py | 107 ++++++++++++++++---- benchmark/operators/compare_benchmark.py | 6 +- bitblas/base/utils.py | 14 +-- bitblas/benchmark/operator/__init__.py | 2 + bitblas/ops/general_matmul/__init__.py | 8 ++ bitblas/ops/matmul.py | 1 - bitblas/ops/operator.py | 8 +- 7 files changed, 115 insertions(+), 31 deletions(-) diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py index 452ea042c..723cf035b 100644 --- a/benchmark/operators/benchmark_ops_matmul.py +++ b/benchmark/operators/benchmark_ops_matmul.py @@ -38,38 +38,109 @@ class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase): "accum_dtype": "int32", "out_dtype": "int8", }, - "FP16xINT4_ACCINT32_NT": { + "FP16xUINT4_ACCFP16_NT": { "A_dtype": "float16", - "W_dtype": "int4", + "W_dtype": "uint4", "accum_dtype": "float16", }, + "FP16xUINT2_ACCFP16_NT": { + "A_dtype": "float16", + "W_dtype": "uint2", + "accum_dtype": "float16", + }, + "INT8xUINT2_ACCINT32_NT": { + "A_dtype": "int8", + "W_dtype": "uint2", + "accum_dtype": "int32", + "out_dtype": "int8", + }, } CURRENT_COMMIT_ID = get_commit_id() + def prepare_set_group_4x(self, name: str, M, N, K) -> List: + return [ + self.generate_op_unit(self.generate_operator_config(name, 1, N, K)), + self.generate_op_unit(self.generate_operator_config(name, M, N, K)), + self.generate_op_unit( + self.generate_operator_config(name, [1, M], N, K), + dynamic_profiling_shape={"m": 1}, + ), + self.generate_op_unit( + self.generate_operator_config(name, [1, M], N, K), + dynamic_profiling_shape={"m": M}, + ), + ] + + def prepare_set_group_llm(self, name: str, N, K) -> List: + return [ + self.generate_op_unit(self.generate_operator_config(name, 1, N, K)), + self.generate_op_unit(self.generate_operator_config(name, 16, N, K)), + self.generate_op_unit(self.generate_operator_config(name, 32, N, K)), + self.generate_op_unit(self.generate_operator_config(name, 64, N, K)), + self.generate_op_unit(self.generate_operator_config(name, 128, N, K)), + self.generate_op_unit(self.generate_operator_config(name, 2048, N, K)), + self.generate_op_unit( + self.generate_operator_config(name, [1, 16], N, K), + dynamic_profiling_shape={"m": 1}, + ), + self.generate_op_unit( + self.generate_operator_config(name, [1, 32], N, K), + dynamic_profiling_shape={"m": 32}, + ), + self.generate_op_unit( + self.generate_operator_config(name, [1, 64], N, K), + dynamic_profiling_shape={"m": 64}, + ), + self.generate_op_unit( + self.generate_operator_config(name, [1, 128], N, K), + dynamic_profiling_shape={"m": 128}, + ), + self.generate_op_unit( + self.generate_operator_config(name, [1, 2048], N, K), + dynamic_profiling_shape={"m": 2048}, + ), + ] + def prepare_benchmark_sets(self): """Prepare benchmark sets.""" self.add_benchmark_set( "FP16xFP16_ACCFP16_NT", [ - self.generate_op_unit( - self.generate_operator_config("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384),), - self.generate_op_unit( - self.generate_operator_config("FP16xFP16_ACCFP16_NT", [1, 1024], 16384, 16384), - dynamic_profiling_shape={"M": 1024}, - ), + *self.prepare_set_group_4x("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 3200), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8640, 3200), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 8640), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 5120), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 13824, 5120), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 13824), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 6656), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 17920, 6656), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 17920), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 1024, 8192), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 8192), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 28672, 8192), + *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 28672), ], ) self.add_benchmark_set( "INT8xINT8_ACCINT32_NT", [ - self.generate_op_unit( - self.generate_operator_config("INT8xINT8_ACCINT32_NT", 16384, 16384, 16384),), - self.generate_op_unit( - self.generate_operator_config("INT8xINT8_ACCINT32_NT", [1, 1024], 16384, 16384), - dynamic_profiling_shape={"M": 1024}, - ), + *self.prepare_set_group_4x("INT8xINT8_ACCINT32_NT", 16384, 16384, 16384), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 3200), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8640, 3200), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 8640), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 5120), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 13824, 5120), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 13824), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 6656), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 17920, 6656), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 17920), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 1024, 8192), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 8192), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 28672, 8192), + *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 28672), ], ) @@ -180,15 +251,15 @@ def legalize_shape(M, N, K, dyn_prof_shape): M: The M dimension (can be an int or a tuple). N: The N dimension (must be an int). K: The K dimension (must be an int). - dyn_prof_shape: The dynamic profiling shape (dict with 'M' key if M is dynamic). + dyn_prof_shape: The dynamic profiling shape (dict with "m" key if M is dynamic). Returns: A string representing the shape in either 'M-N-K' or 'N-K_M' format. """ if isinstance(M, int): return f"{M}-{N}-{K}" - elif dyn_prof_shape and "M" in dyn_prof_shape: - return f"{N}-{K}_{dyn_prof_shape['M']}" + elif dyn_prof_shape and "m" in dyn_prof_shape: + return f"{N}-{K}_{dyn_prof_shape['m']}" else: # Calculate the average of tuple M opt_m = sum(M) / len(M) @@ -207,7 +278,7 @@ def legalize_shape(M, N, K, dyn_prof_shape): f"{(2 * benchmark_M * op_config.N * op_config.K / (latency * 1e-3) / 1e12):.3f}" if latency else "N/A") latency_str = "N/A" if latency is None else f"{latency:.3f}" - tuning_time_str = ("N/A" if tuning_time is None else f"{tuning_time:.3f}") + tuning_time_str = "N/A" if tuning_time is None else f"{tuning_time:.3f}" table_data.append([shape, latency_str, throughput, tuning_time_str]) diff --git a/benchmark/operators/compare_benchmark.py b/benchmark/operators/compare_benchmark.py index abafb3099..c45a5a680 100644 --- a/benchmark/operators/compare_benchmark.py +++ b/benchmark/operators/compare_benchmark.py @@ -65,7 +65,7 @@ def legalize_shape(M, N, K, dyn_prof_shape): throughput = (2 * benchmark_M * op_config.N * op_config.K / (latency * 1e-3) / 1e12) if base_latency is not None: base_throughput = (2 * benchmark_M * op_config.N * op_config.K / - (base_latency * 1e-3) / 1e12) + (base_latency * 1e-3) / 1e12) throughput = f"{throughput:.3f}{get_suffix(base_throughput, throughput)}" else: throughput = f"{throughput:.3f}" @@ -113,9 +113,9 @@ def legalize_shape(M, N, K, dyn_prof_shape): help="the head commit id", ) args = parser.parse_args() - + print(f"Comparing base commit {args.base} with head commit {args.head}") - + base_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs(args.base) head_benchmark = BitblasMatmulOpsBenchmark.deserialize_from_logs(args.head) diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index 1596b3c86..4bdbbed79 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -51,12 +51,14 @@ def __init__(self, config, sch, mod: Module): self.mod = mod self.code = mod.imported_modules[0].get_source() if mod else None self.latency = 1e9 - self.profile_tensors = [] self.time_evaluator = None - def profile(self): - profile_tensors = self.profile_tensors - return self.time_evaluator(*profile_tensors).mean * 1e3 + def profile(self, data_distribution="uniform"): + func = self.sch.mod["main"] + device = self.config.arch.device + profile_tensors = get_dummy_input_arrays(func, device, distribution=data_distribution) + latency = self.time_evaluator(*profile_tensors).mean * 1e3 + return latency def _apply_config( @@ -172,7 +174,6 @@ def apply_and_build_parallel(func, data_distribution="uniform") -> CompileResult: cpresults = [] - profile_tensors = get_dummy_input_arrays(func, arch.device, distribution=data_distribution) max_workers = min(len(configs), os.cpu_count(), max_workers) # apply config in thread parallel @@ -242,7 +243,6 @@ def tvm_callback_cuda_postproc(code, _): cpresult = CompileResult(config, sch, rt_mod) timer_cuda_mod = rt_mod.time_evaluator( rt_mod.entry_name, arch.device, number=num_repeats) - cpresult.profile_tensors = profile_tensors cpresult.time_evaluator = timer_cuda_mod cpresult.code = code cpresults.append(cpresult) @@ -256,7 +256,7 @@ def tvm_callback_cuda_postproc(code, _): for cpresult in cpresults: config = cpresult.config try: - latency = cpresult.profile() + latency = cpresult.profile(data_distribution=data_distribution) except Exception as e_mesg: logger.debug(f"Evaluation with config failed {e_mesg}") continue diff --git a/bitblas/benchmark/operator/__init__.py b/bitblas/benchmark/operator/__init__.py index c5e7852e3..f59ca34ee 100644 --- a/bitblas/benchmark/operator/__init__.py +++ b/bitblas/benchmark/operator/__init__.py @@ -108,6 +108,8 @@ def run_benchmark( latency = op_inst.profile_latency(dynamic_symbolic_constraints=dynamic_profiling_shape) + op_inst.cleanup() + return latency, tuning_time @abstractmethod diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 7ed8fbc39..184da0b0a 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -306,6 +306,11 @@ def dispatch_tir(self, def _alloc_workspace(self): return torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() + def _free_workspace(self): + # release the workspace if it is None + if self.workspace is not None: + self.workspace = None + def _assign_ladder_permutate_a(self, target: Target, enable_tuning: bool): ladder_permutate_a = None if self.propagate_a: @@ -534,6 +539,9 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: def __call__(self, *args: Any, **kwds: Any) -> Any: return self.forward(*args, **kwds) + def cleanup(self): + self._free_workspace() + @property def M(self): return self.config.M diff --git a/bitblas/ops/matmul.py b/bitblas/ops/matmul.py index af0370294..e515a264c 100644 --- a/bitblas/ops/matmul.py +++ b/bitblas/ops/matmul.py @@ -209,7 +209,6 @@ def var_warpper(v, m): [var_warpper(i, m) for i in arg.shape]).astype(arg.dtype), device=device, )) - self.profile_tensors = profile_tensors latency = self.time_evaluator(*profile_tensors).mean * 1e3 benchmark_latencies.append({"m": m, "latency": latency}) # ms diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 8617d70b9..d35476ee5 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -47,7 +47,6 @@ def __init__(self, name, config: OperatorConfig, target: Target = None): self.optimized_func = None self.rt_mod = None self.time_evaluator = None - self.profile_tensors = None self.arch = get_arch(target) if target else None self.dynamic_range = None self.pass_context: Dict = {} @@ -262,7 +261,6 @@ def map_numpy_type(intype): [var_warpper(i) for i in arg.shape]).astype(numpy_dtype), device=device, )) - self.profile_tensors = profile_tensors return profile_tensors def profile_latency(self, dynamic_symbolic_constraints: Optional[Dict] = None) -> str: @@ -270,6 +268,9 @@ def profile_latency(self, dynamic_symbolic_constraints: Optional[Dict] = None) - dynamic_symbolic_constraints = {} profile_tensors = self.get_profile_tensors(dynamic_symbolic_constraints) latency = self.time_evaluator(*profile_tensors).mean * 1e3 + # release the memory + for tensor in profile_tensors: + del tensor return latency def _tensor_adapter(self, tensor, device): @@ -325,6 +326,9 @@ def update_runtime_module(self, rt_mod, srcpath=None, libpath=None): self.lib.init() # TODO: update the lib code from srcpath + def cleanup(self): + raise NotImplementedError + @abstractmethod def _select_implementation(self) -> IRModule: pass From ef1b1582995cdd38e5cabc3be6f7d3991a9d2af7 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 25 Jul 2024 16:01:59 +0000 Subject: [PATCH 85/88] upodate tvm cmmit --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index d9391a502..8dff258d2 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit d9391a502b5544722eb67c4a0c4dff49a3476c06 +Subproject commit 8dff258d2837b2c0d24619ebf26dd596b2291912 From a8d8841c7f276e187b46ed3e90e84eea0997c12d Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 30 Jul 2024 10:26:43 +0000 Subject: [PATCH 86/88] Imporve lower warp memory pass --- 3rdparty/tvm | 2 +- bitblas/gpu/matmul_mma.py | 14 +++++++------- bitblas/gpu/matmul_mma_dequantize.py | 12 ++++++------ 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 8dff258d2..f136a2233 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 8dff258d2837b2c0d24619ebf26dd596b2291912 +Subproject commit f136a2233005931e3c997c2af74ae47f0a115b74 diff --git a/bitblas/gpu/matmul_mma.py b/bitblas/gpu/matmul_mma.py index 4bf8be4e6..629d93768 100644 --- a/bitblas/gpu/matmul_mma.py +++ b/bitblas/gpu/matmul_mma.py @@ -8,13 +8,13 @@ from tvm import tir, DataType from tvm.target import Target -from ..base.roller import Hint -from ..base.roller.rasterization import NoRasterization -from ..base import analysis -from .base import GPUScheduleRule -from .matmul_mma_dequantize import MatmulTensorizationMMAWithDequantizeInfo -from ..base.analysis import get_coalesced_veclen -from .matmul_analysis import ( +from bitblas.base.roller import Hint +from bitblas.base.roller.rasterization import NoRasterization +from bitblas.base import analysis +from bitblas.gpu.base import GPUScheduleRule +from bitblas.gpu.matmul_mma_dequantize import MatmulTensorizationMMAWithDequantizeInfo +from bitblas.base.analysis import get_coalesced_veclen +from bitblas.gpu.matmul_analysis import ( auto_inline_consumer_chain, is_transpose_block, is_identity_block, diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index de1b5b896..3a8cf048a 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -8,13 +8,13 @@ from tvm import tir, DataType -from ..base.roller.hint import Hint, IntrinInfo +from bitblas.base.roller.hint import Hint, IntrinInfo from tvm.target import Target -from ..base.roller.rasterization import NoRasterization -from ..base import analysis -from .base import GPUScheduleRule -from ..base.analysis import get_coalesced_veclen -from .matmul_analysis import ( +from bitblas.base.roller.rasterization import NoRasterization +from bitblas.base import analysis +from bitblas.gpu.base import GPUScheduleRule +from bitblas.base.analysis import get_coalesced_veclen +from bitblas.gpu.matmul_analysis import ( auto_inline_consumer_chain, auto_inline_producers, get_reduction_blocks, From 7736c38007b23a5cf0be7fe0bbc1635d27aa251b Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 30 Jul 2024 12:22:46 +0000 Subject: [PATCH 87/88] Bug fix --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index f136a2233..bcdcec2c7 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit f136a2233005931e3c997c2af74ae47f0a115b74 +Subproject commit bcdcec2c7502a565ea724fc145eebc993eebc484 From 199affcc5fc3b139d400fb87ad8e4782680bd00f Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 31 Jul 2024 05:52:16 +0000 Subject: [PATCH 88/88] Enhance to support warp schedule. --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index bcdcec2c7..0b0faa5cd 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit bcdcec2c7502a565ea724fc145eebc993eebc484 +Subproject commit 0b0faa5cd7ae077bc730c2638bf2ab29adaede5d