From 3aa943975577a18f725a542f45c0e2ed98559857 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Wed, 4 Sep 2024 10:16:40 +0800
Subject: [PATCH 1/2] chore(deps): bump actions/download-artifact in /.github/workflows (#175)

Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 3 to 4.1.7.
- [Release notes](https://github.com/actions/download-artifact/releases)
- [Commits](https://github.com/actions/download-artifact/compare/v3...v4.1.7)

---
updated-dependencies:
- dependency-name: actions/download-artifact
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 .github/workflows/benchmark.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 013345f6f..c19ed8495 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -106,13 +106,13 @@ jobs:
 
     steps:
       - name: Download commit IDs
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4.1.7
         with:
           name: base-commit-id
           path: .
 
       - name: Download PR commit ID
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4.1.7
        with:
          name: pr-commit-id
          path: .

From c15744ed2ea2cb0c673e4503c59d825525ed572b Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Wed, 4 Sep 2024 11:26:55 +0800
Subject: [PATCH 2/2] [TL] Add TL Layout and Macro utils (#174)

* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* disable failure email for ci
* remove email notifications.
* move relax pass from testing to mlc_llm
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* Lint Fix
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* bug fix for matrix support
* lint fix
* dispatch tensor core based on shapes
* update install commands
* import scripts
* remove shared mem hack
* revert change for swizzling
* bug fix
* tl examples
* Enhance Swizzle
* lint fix
* test fix
* lint fix
* optimize layout
* update tl utils.
* macro optimization
* test fix

---
 3rdparty/tvm                                  |   2 +-
 bitblas/tl/__init__.py                        |  10 +
 bitblas/tl/macro_generator.py                 | 209 ++++++++++++++++++
 bitblas/tl/utils.py                           | 118 ++++++++++
 integration/BitNet/utils_quant.py             |  10 +-
 .../tilelang/test_tilelang_dequantize_gemm.py |   3 -
 6 files changed, 341 insertions(+), 11 deletions(-)
 create mode 100644 bitblas/tl/__init__.py
 create mode 100644 bitblas/tl/macro_generator.py
 create mode 100644 bitblas/tl/utils.py

diff --git a/3rdparty/tvm b/3rdparty/tvm
index a1d78ebc6..32c5c790b 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit a1d78ebc682dbaec70e792470c6842b9ec3342c6
+Subproject commit 32c5c790baffe5fa605de52e70640ce67b30f4e6
diff --git a/bitblas/tl/__init__.py b/bitblas/tl/__init__.py
new file mode 100644
index 000000000..69e20496b
--- /dev/null
+++ b/bitblas/tl/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from .utils import (
+    get_swizzle_layout,  # noqa: F401
+    mma_store_index_map,  # noqa: F401
+    get_ldmatrix_offset,  # noqa: F401
+)
+
+from .macro_generator import TensorCorePTXMacroGenerator  # noqa: F401
diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/macro_generator.py
new file mode 100644
index 000000000..790281a1f
--- /dev/null
+++ b/bitblas/tl/macro_generator.py
@@ -0,0 +1,209 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import tvm.tl.language as T
+
+from tvm import DataType
+from tvm.runtime import convert
+from .utils import (
+    mma_store_index_map,
+    get_ldmatrix_offset,
+)
+
+lift = convert
+
+
+class TensorCorePTXMacroGenerator(object):
+    """
+    Helper to eliminate Python-level syntax within TIR macros.
+    """
+
+    M_DIM = 16
+    N_DIM = 16
+    WARP_SIZE = 32
+    dtype_abbrv = {
+        "float16": "fp16",
+        "bfloat16": "bf16",
+        "float32": "fp32",
+        "int8": "int8",
+        "int32": "int32",
+        "e4m3_float8": "e4m3",
+        "e5m2_float8": "e5m2",
+    }
+
+    def __init__(
+        self,
+        a_dtype="float16",
+        b_dtype="float16",
+        accum_dtype="float16",
+        a_transposed=False,
+        b_transposed=False,
+        block_row_warps=2,
+        block_col_warps=2,
+        warp_row_tiles=8,
+        warp_col_tiles=8,
+        chunk=16,
+        threads=128,
+    ):
+        self.a_dtype = a_dtype
+        self.b_dtype = b_dtype
+        self.accum_dtype = accum_dtype
+        self.a_transposed = a_transposed
+        self.b_transposed = b_transposed
+        # Hint information
+        self.block_row_warps = block_row_warps
+        self.block_col_warps = block_col_warps
+        self.warp_row_tiles = warp_row_tiles
+        self.warp_col_tiles = warp_col_tiles
+        self.chunk = chunk
+        self._initialize_k_dim(a_dtype)
+        self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
+        self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
+        self._initialize_mma_prefix(self.k_dim)
+        self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
+        self.warp_rows = warp_row_tiles // self.micro_size_x
+        self.warp_cols = warp_col_tiles // self.micro_size_y
+        self._initialize_thread_axis(threads, self.WARP_SIZE, block_row_warps, block_col_warps)
+
+    def _initialize_k_dim(self, a_dtype="float16"):
+        self.k_dim = 256 // DataType(a_dtype).bits
+
+    def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
+        self.local_size_a = (m_dim * k_dim) // warp_size
+        self.local_size_b = (n_dim * k_dim) // warp_size
+        self.local_size_out = (m_dim * n_dim) // warp_size
+
+    def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
+        self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
+        self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
+        self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]
+
+    def _initialize_mma_prefix(self, k_dim=16):
+        if k_dim == 16:
+            self.mma_prefix = "m16n8k16"
+        elif k_dim == 32:
+            self.mma_prefix = "m16n8k32"
+        else:
+            raise ValueError("Unsupported k_dim")
+
+    def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
+        self.micro_size_x = m_dim
+        self.micro_size_y = n_dim
+        self.micro_size_k = k_dim
+
+    def _initialize_thread_axis(self,
+                                threads=128,
+                                warp_size=32,
+                                block_row_warps=2,
+                                block_col_warps=2):
+        self.threads = threads
+        # thread_bindings = T.env_thread("threadIdx.x")
+        # self.tx = thread_bindings % warp_size
+        # self.ty = (thread_bindings // warp_size) % block_row_warps
+        # self.tz = thread_bindings // (warp_size * block_row_warps)
+
+    @staticmethod
+    @T.macro
+    def MMA(inst, A_local_buf, B_local_buf, C_local_buf):
+        for i, j in T.grid(inst.warp_rows, inst.warp_cols):
+            T.ptx_mma(
+                inst.accum_dtype,
+                inst.mma_prefix,
+                "row",
+                "col",
+                inst.a_dtype_abbrv,
+                inst.b_dtype_abbrv,
+                inst.accum_dtype_abbrv,
+                A_local_buf.data,
+                i * inst.local_size_a,
+                B_local_buf.data,
+                j * inst.local_size_b,
+                C_local_buf.data,
+                i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out,
+                T.bool(False),
+            )
+
+            T.ptx_mma(
+                inst.accum_dtype,
+                inst.mma_prefix,
+                "row",
+                "col",
+                inst.a_dtype_abbrv,
+                inst.b_dtype_abbrv,
+                inst.accum_dtype_abbrv,
+                A_local_buf.data,
+                i * inst.local_size_a,
+                B_local_buf.data,
+                j * inst.local_size_b + lift(inst.local_size_b) // 2,
+                C_local_buf.data,
+                i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out +
+                lift(inst.local_size_out) // 2,
+                T.bool(False),
+            )
+
+    @staticmethod
+    @T.macro
+    def LDMATRIX_A(
+        inst,
+        A_local_buf,
+        A_shared_buf,
+        ki,
+        thread_bindings,
+    ):
+        stride = inst.chunk
+        tx = thread_bindings % inst.WARP_SIZE
+        ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
+        for i in T.serial(inst.warp_rows):
+            T.ptx_ldmatrix(
+                "float16",
+                T.bool(False),
+                4,
+                ".b16",
+                A_local_buf.data,
+                i * inst.local_size_a,
+                T.address_of(A_shared_buf[ty * inst.warp_row_tiles + i * inst.micro_size_x,
+                                          ki * inst.micro_size_k,]),
+                get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, False),
+            )
+
+    @staticmethod
+    @T.macro
+    def LDMATRIX_B(
+        inst,
+        B_local_buf,
+        B_shared_buf,
+        ki,
+        thread_bindings,
+    ):
+        stride = inst.chunk
+        tx = thread_bindings % inst.WARP_SIZE
+        tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)
+        for j in T.serial(inst.warp_cols):
+            T.ptx_ldmatrix(
+                "float16",
+                T.bool(False),  # TODO(lei): should be optimized
+                4,
+                ".b16",
+                B_local_buf.data,
+                j * inst.local_size_b,
+                T.address_of(B_shared_buf[tz * inst.warp_col_tiles + j * inst.micro_size_y,
+                                          ki * inst.micro_size_k,]),
+                get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, True),
+            )
+
+    # STS
+    # The MMA store must be simulated rather than done with the TVM intrinsic,
+    # since the TVM intrinsic is a hack that requires threadIdx.x to always
+    # equal the warp size.
+    @staticmethod
+    @T.macro
+    def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings):
+        tx = thread_bindings % inst.WARP_SIZE
+        ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
+        tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)
+        for i, j in T.grid(inst.warp_rows, inst.warp_cols):
+            for local_id in T.serial(inst.local_size_out):
+                row, col = T.meta_var(mma_store_index_map(tx, local_id))
+                C_shared_buf[ty * inst.warp_rows + i, tz * inst.warp_cols + j, row,
+                             col] = C_local_buf[i * (inst.warp_cols * inst.local_size_out) +
+                                                j * inst.local_size_out + local_id]
diff --git a/bitblas/tl/utils.py b/bitblas/tl/utils.py
new file mode 100644
index 000000000..d0df62cfa
--- /dev/null
+++ b/bitblas/tl/utils.py
@@ -0,0 +1,118 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from tvm import arith
+from tvm import DataType
+from typing import Union, Literal
+
+
+def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
+    ana = arith.Analyzer()
+    BANK_SIZE_BYTES = 128
+    if isinstance(dtype, str):
+        dtype = DataType(dtype)
+    col_idx_outer, col_idx_inner = col_idx // (BANK_SIZE_BYTES // dtype.bits), col_idx % (
+        BANK_SIZE_BYTES // dtype.bits)
+    # use transaction bits to support diverse dtypes.
+    # for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 16 bits = 512 bits
+    # for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits
+    coalescent_bits = dtype.bits * row_size
+    # permutation on 4 banks, each bank has 32 bits
+    bank_elems = BANK_SIZE_BYTES // dtype.bits
+    new_col_idx_outer = None
+    if coalescent_bits % 1024 == 0:
+        # Use 8 * 8 permuted layout
+        # Every number below corresponds to 8 consecutive fp16 numbers in shared mem, i.e. one read
+        # Every row below corresponds to 32 banks
+        # 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
+        # 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
+        # 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
+        # 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4
+        # 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3
+        # 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
+        # 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
+        # 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
+        row_idx_sub = row_idx % bank_elems
+        new_col_idx_outer = col_idx_outer ^ row_idx_sub
+    else:
+        assert coalescent_bits % 512 == 0
+        # Use 8 * 4 permuted layout
+        # Every number below corresponds to 8 consecutive fp16 numbers in shared mem, i.e. one read
+        # Every row below corresponds to 16 banks
+        # 0 1 2 3 ==> 0 1 2 3
+        # 0 1 2 3 ==> 0 1 2 3
+        # 0 1 2 3 ==> 1 0 3 2
+        # 0 1 2 3 ==> 1 0 3 2
+        # 0 1 2 3 ==> 2 3 0 1
+        # 0 1 2 3 ==> 2 3 0 1
+        # 0 1 2 3 ==> 3 2 1 0
+        # 0 1 2 3 ==> 3 2 1 0
+        # View with 8 elements per row:
+        # 0 1 2 3 0 1 2 3 ==> 0 1 2 3 0 1 2 3
+        # 0 1 2 3 0 1 2 3 ==> 1 0 3 2 1 0 3 2
+        # 0 1 2 3 0 1 2 3 ==> 2 3 0 1 2 3 0 1
+        # 0 1 2 3 0 1 2 3 ==> 3 2 1 0 3 2 1 0
+        row_idx_sub = row_idx % bank_elems
+        # Interleave elems per bank (each bank is 32 bits wide)
+        interleave_elems = 32 // dtype.bits
+        new_col_idx_outer = col_idx_outer ^ (row_idx_sub // interleave_elems)
+
+    assert (new_col_idx_outer is not None), f"Unsupported dtype {dtype} with {coalescent_bits} bits"
+    return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner)
+
+
+def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
+    row = thread_id % 16
+    col = 8 * (thread_id // 16) + local_id % 8
+    return row, col
+
+
+def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
+    row = 8 * (thread_id // 16) + (thread_id % 8)
+    col = 8 * ((thread_id % 16) // 8) + local_id % 8
+    return row, col
+
+
+def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id):
+    row = thread_id % 16
+    col = local_id + (thread_id // 16) * 16
+    return row, col
+
+
+def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id):
+    row = (thread_id // 16) * 8 + (thread_id % 8)
+    col = local_id + 16 * ((thread_id % 16) // 8)
+    return row, col
+
+
+def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
+    row = 8 * (local_id % 4 // 2) + (thread_id // 4)
+    col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2)
+    return row, col
+
+
+def get_ldmatrix_offset(
+    matrix: Literal["A", "B"],
+    row_idx,
+    col_idx,
+    stride,
+    dtype: Literal["float16", "int8"] = "float16",
+    transpose: bool = False,
+):
+    assert matrix in ["A", "B"], "matrix should be either A or B"
+    transform_func = (
+        ldmatrix_32x8_to_shared_16x16_layout
+        if dtype in ["float16", "bfloat16"] else ldmatrix_32x16_to_shared_16x32_layout_b)
+    transform_func_trans = (
+        ldmatrix_trans_32x8_to_shared_16x16_layout
+        if dtype in ["float16", "bfloat16"] else ldmatrix_32x16_to_shared_16x32_layout_a)
+    if matrix == "A":
+        assert not transpose, "A matrix should not be transposed"
+        new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
+        return new_row_idx * stride + new_col_idx
+    else:
+        new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx)
+        return new_row_idx * stride + new_col_idx
+
+
+def mma_store_index_map(*args, **kwargs):
+    return mma_store_32x8_to_shared_16x16_layout(*args, **kwargs)
diff --git a/integration/BitNet/utils_quant.py b/integration/BitNet/utils_quant.py
index 3da74c213..a1c0a8fc9 100644
--- a/integration/BitNet/utils_quant.py
+++ b/integration/BitNet/utils_quant.py
@@ -165,7 +165,7 @@ def activation_quant(self, x, num_bits=8):
         Qp = 2**(num_bits - 1) - 1
         s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
         result = (x * s).round().clamp(Qn, Qp)
-        return result.type(torch.int8)
+        return result.type(torch.int8), s
 
     @torch.compile
     def post_quant_process(self, input, si, sw):
@@ -186,7 +186,7 @@ def native_forward(self, input):
         return out
 
     def forward_fp32_simulated(self, input):
-        quant_input = self.activation_quant(input, self.input_bits).detach()
+        quant_input, si = self.activation_quant(input, self.input_bits)
         quant_weight = self.weight_quant(self.weight).detach()
 
         fp32_simulated_input = quant_input.float()
@@ -194,8 +194,6 @@ def forward_fp32_simulated(self, input):
         fp32_simulated_out = nn.functional.linear(fp32_simulated_input, fp32_simulated_weight)
 
         sw = 1 / self.weight.abs().mean().clamp(min=1e-5)
-        Qp = 2**(self.input_bits - 1) - 1
-        si = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
         # if / (si * sw) it will inf in some cases
         out = fp32_simulated_out / si
         out = out / sw
@@ -206,11 +204,9 @@ def forward(self, input):
 
     def forward(self, input):
         # return self.forward_fp32_simulated(input)
-        quant_input = self.activation_quant(input, self.input_bits).detach()
+        quant_input, si = self.activation_quant(input, self.input_bits)
         fp32_out = self.bitblas_matmul(quant_input, self.qweight)
         sw = self.sw
-        Qp = self.Qp
-        si = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
         # if / (si * sw) it will inf in some cases
         out = self.post_quant_process(fp32_out, si, sw)
diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py
index 9db978cd9..574bac15a 100644
--- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py
+++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py
@@ -113,9 +113,6 @@ def run_gemm(
 
     print(f"output is {out}")
 
-    with open("debug/kernel.cu", "w") as f:
-        f.write(mod.mod.imported_modules[0].get_source())
-
 
 def ref_program(A, qB):
     import torch
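
Usage sketch (illustrative only, not part of the patch): a minimal example of how the
helpers exported from bitblas.tl might be exercised, assuming a checkout where bitblas
and its TVM submodule are importable. For "float16" with 64-element rows, coalescent
bits = 64 * 16 = 1024, so get_swizzle_layout takes the 8 * 8 permuted path; row r
should map the eight 8-element column groups to r ^ [0..7], matching the table in the
comments of bitblas/tl/utils.py.

    # sketch.py -- assumes bitblas (with its TVM submodule) is importable
    from bitblas.tl import get_swizzle_layout, get_ldmatrix_offset

    # Print the permuted 8-element group index for each (row, column group).
    # With plain int inputs, the returned column is a constant TVM IntImm.
    for r in range(8):
        groups = []
        for c in range(8):
            _, new_col = get_swizzle_layout(r, c * 8, 64, "float16")
            groups.append(int(new_col) // 8)  # recover the group index
        print(r, groups)  # row r prints the XOR-permuted pattern r ^ [0..7]

    # Flat shared-memory offset that lane 0 would pass to ldmatrix for a
    # non-transposed fp16 A tile with a 16-element stride.
    print(get_ldmatrix_offset("A", 0, 0, 16, "float16", False))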