From 150815bd17d5905c4f72369626ec890fba83bca4 Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Thu, 26 Sep 2024 20:31:27 +0800
Subject: [PATCH] [TL] Append Macro Test Case for GEMM and Dequant GEMM (#190)

* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* disable failure email for ci
* remove email notifications.
* move relax pass from testing to mlc_llm
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* Lint Fix
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* bug fix for matrix support
* lint fix
* dispatch tensor core based on shapes
* update install commands
* import scripts
* remove shared mem hack
* revert change for swizzling
* bug fix
* tl examples
* Enhance Swizzle
* lint fix
* test fix
* lint fix
* optimize layout
* update tl utils.
* macro optimization
* test fix
* gemm_ss
* doc fix
* lint fix
* lint fix
* remove debug print
* remove debug print
* vectorization init
* lint fix
* prelude update
* update tvm
* bug fix for reduce_k with shared memory
* bug fix
* bug fix
* Enhance Macro Generation
* Lift Layout to reduce load time
* lint fix
* test fix
* red fix
* tile lang macro example
* tile lang macro example
* optimize the macro generator related items
* lint fix
---
 3rdparty/tvm                                  |   2 +-
 bitblas/tl/__init__.py                        |   5 +-
 bitblas/tl/macro_generator.py                 | 208 +---
 .../tilelang/test_tilelang_macro_gemm.py      | 893 ++++++++++++++++++
 4 files changed, 937 insertions(+), 171 deletions(-)
 create mode 100644 testing/python/tilelang/test_tilelang_macro_gemm.py

diff --git a/3rdparty/tvm b/3rdparty/tvm
index 68969a60..c7a8c4ee 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 68969a6008a639ce937075e6ad75cb417a7c3ed6
+Subproject commit c7a8c4eef38315f9e1796818b1c3d4c68f85a8af
diff --git a/bitblas/tl/__init__.py b/bitblas/tl/__init__.py
index 69e20496..919b7066 100644
--- a/bitblas/tl/__init__.py
+++ b/bitblas/tl/__init__.py
@@ -7,4 +7,7 @@
     get_ldmatrix_offset,  # noqa: F401
 )

-from .macro_generator import TensorCorePTXMacroGenerator  # noqa: F401
+from .macro_generator import (
+    TensorCoreIntrinEmitter,  # noqa: F401
+    TensorCoreIntrinEmitterWithLadderTransform,  # noqa: F401
+)
diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/macro_generator.py
index f863fa62..0f0b361c 100644
--- a/bitblas/tl/macro_generator.py
+++ b/bitblas/tl/macro_generator.py
@@ -15,7 +15,7 @@
 lift = convert


-class TensorCorePTXMacroGenerator(object):
+class TensorCoreIntrinEmitter(object):
     """
     To eliminate Python syntax within TIR Macro.
""" @@ -45,8 +45,6 @@ def __init__(self, warp_col_tiles=8, chunk=16, reduce_k=1, - transform_kind_a: Union[int, TransformKind] = 0, - transform_kind_b: Union[int, TransformKind] = 0, num_elems_per_byte=1): self.a_dtype = a_dtype self.b_dtype = b_dtype @@ -68,7 +66,6 @@ def __init__(self, self.warp_cols = warp_col_tiles // self.micro_size_y self.reduce_k = reduce_k self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k - self._initialize_transform_kind(transform_kind_a, transform_kind_b) self.num_elems_per_byte = num_elems_per_byte def _initialize_k_dim(self, a_dtype="float16"): @@ -99,26 +96,8 @@ def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16): self.micro_size_y = n_dim self.micro_size_k = k_dim - def _initialize_transform_kind(self, transform_kind_a, transform_kind_b): - if isinstance(transform_kind_a, int): - self.transform_kind_a = TransformKind(transform_kind_a) - elif isinstance(transform_kind_a, TransformKind): - self.transform_kind_a = transform_kind_a - else: - raise ValueError("Unsupported transform_kind_a") - - if isinstance(transform_kind_b, int): - self.transform_kind_b = TransformKind(transform_kind_b) - elif isinstance(transform_kind_b, TransformKind): - self.transform_kind_b = transform_kind_b - else: - raise ValueError("Unsupported transform_kind_b") - - assert transform_kind_b in [0, 3], "Currently only support 0 and 3" - - @staticmethod @T.macro - def LDMATRIX_A( + def _warp_ldmatrix_a( inst, A_local_buf, A_shared_buf, @@ -143,9 +122,8 @@ def LDMATRIX_A( get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, inst.a_transposed), ) - @staticmethod @T.macro - def LDMATRIX_B( + def _warp_ldmatrix_b( inst, B_local_buf, B_shared_buf, @@ -173,9 +151,8 @@ def LDMATRIX_B( get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, inst.b_transposed), ) - @staticmethod @T.macro - def MMA(inst, A_local_buf, B_local_buf, C_local_buf): + def _warp_mma(inst, A_local_buf, B_local_buf, C_local_buf): for i, j in T.grid(inst.warp_rows, inst.warp_cols): T.ptx_mma( inst.accum_dtype, @@ -216,9 +193,8 @@ def MMA(inst, A_local_buf, B_local_buf, C_local_buf): # MMA Store must be in simulated instead of TVM Intrins # As TVM Intrins is like a hack that the threadIdx.x should be always # equal to the warp_size - @staticmethod @T.macro - def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings): + def _warp_stmatrix(inst, C_local_buf, C_shared_buf, thread_bindings): tx = thread_bindings % inst.WARP_SIZE ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps tz = (thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)) % inst.block_col_warps @@ -231,55 +207,25 @@ def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings): col] = C_local_buf[i * (inst.warp_cols * inst.local_size_out) + j * inst.local_size_out + local_id] - # Allow GEMM from shared memory to local memory - @staticmethod - @T.macro - def GEMM_SS(inst, A_shared_buf, B_shared_buf, C_local_buf, thread_bindings): - # TODO(lei): alloc_buffer within the macro is not supported yet. 
- A_local_buf = T.alloc_fragment((inst.warp_rows * inst.local_size_a), - inst.a_dtype, - scope="local") - B_local_buf = T.alloc_fragment((inst.warp_cols * inst.local_size_b), - inst.b_dtype, - scope="local") - for ki in T.serial(0, (inst.chunk // inst.micro_size_k)): - inst.LDMATRIX_A( - inst, - A_local_buf, - A_shared_buf, - ki, - thread_bindings=thread_bindings, - ) + def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0): + return self._warp_ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk) - inst.LDMATRIX_B( - inst, - B_local_buf, - B_shared_buf, - ki, - thread_bindings=thread_bindings, - ) + def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0): + return self._warp_ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk) + + def mma(self, A_local_buf, B_local_buf, C_local_buf): + return self._warp_mma(self, A_local_buf, B_local_buf, C_local_buf) - inst.MMA(inst, A_local_buf, B_local_buf, C_local_buf) + def stmatrix(self, C_local_buf, C_shared_buf, thread_bindings): + return self._warp_stmatrix(self, C_local_buf, C_shared_buf, thread_bindings) -class TensorCorePTXMacroGeneratorWithLadderTransform(object): +class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): """ To eliminate Python syntax within TIR Macro. + With Ladder Transform Plugin. """ - M_DIM = 16 - N_DIM = 16 - WARP_SIZE = 32 - dtype_abbrv = { - "float16": "fp16", - "bfloat16": "bf16", - "float32": "fp32", - "int8": "int8", - "int32": "int32", - "e4m3_float8": "e4m3", - "e5m2_float8": "e5m2", - } - def __init__( self, a_dtype="float16", @@ -297,28 +243,21 @@ def __init__( transform_kind_b: Union[int, TransformKind] = 0, num_elems_per_byte=1, ): - self.a_dtype = a_dtype - self.b_dtype = b_dtype - self.accum_dtype = accum_dtype - self.a_transposed = a_transposed - self.b_transposed = b_transposed - # Hint Information - self.block_row_warps = block_row_warps - self.block_col_warps = block_col_warps - self.warp_row_tiles = warp_row_tiles - self.warp_col_tiles = warp_col_tiles - self.chunk = chunk - self._initialize_k_dim(a_dtype) - self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) - self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) - self._initialize_mma_prefix(self.k_dim) - self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim) - self.warp_rows = warp_row_tiles // self.micro_size_x - self.warp_cols = warp_col_tiles // self.micro_size_y - self.reduce_k = reduce_k - self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k + super().__init__( + a_dtype=a_dtype, + b_dtype=b_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + reduce_k=reduce_k, + num_elems_per_byte=num_elems_per_byte, + ) self._initialize_transform_kind(transform_kind_a, transform_kind_b) - self.num_elems_per_byte = num_elems_per_byte def _initialize_k_dim(self, a_dtype="float16"): self.k_dim = 256 // DataType(a_dtype).bits @@ -361,38 +300,13 @@ def _initialize_transform_kind(self, transform_kind_a, transform_kind_b): else: raise ValueError("Unsupported transform_kind_b") - assert transform_kind_b in [0, 3], "Currently only support 0 and 3" - - @staticmethod - @T.macro - def LDMATRIX_A( - inst, - A_local_buf, - A_shared_buf, - ki, - thread_bindings, - rk=0, - ): - stride = A_shared_buf.shape[-1] - tx = thread_bindings % 
inst.WARP_SIZE - ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps + if self.transform_kind_a != TransformKind.NonTransform: + raise ValueError("TransformKind A is not supported yet") - for i in T.serial(inst.warp_rows): - T.ptx_ldmatrix( - inst.a_dtype, - T.bool(False), - 4, - ".b16", - A_local_buf.data, - i * inst.local_size_a, - T.address_of(A_shared_buf[ty * inst.warp_row_tiles + i * inst.micro_size_x, - rk * inst.chunk + ki * inst.micro_size_k,]), - get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, inst.a_transposed), - ) + assert transform_kind_b in [0, 3], "Currently only support 0 and 3" - @staticmethod @T.macro - def LDMATRIX_B( + def _warp_ldmatrix_b( inst, B_local_buf, B_shared_buf, @@ -436,9 +350,8 @@ def LDMATRIX_B( B_local_buf[j * local_size_dequantize + local_id] = B_shared_buf[ri, rj, rii, rjj] - @staticmethod @T.macro - def MMA(inst, A_local_buf, B_local_buf, C_local_buf): + def _warp_mma(inst, A_local_buf, B_local_buf, C_local_buf): for i, j in T.grid(inst.warp_rows, inst.warp_cols): T.ptx_mma( inst.accum_dtype, @@ -475,51 +388,8 @@ def MMA(inst, A_local_buf, B_local_buf, C_local_buf): T.bool(False), ) - # STS - # MMA Store must be in simulated instead of TVM Intrins - # As TVM Intrins is like a hack that the threadIdx.x should be always - # equal to the warp_size - @staticmethod - @T.macro - def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings): - tx = thread_bindings % inst.WARP_SIZE - ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps - tz = (thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)) % inst.block_col_warps - for i, j in T.grid(inst.warp_rows, inst.warp_cols): - for local_id_o in T.serial(inst.local_size_out // 2): - for local_id_i in T.vectorized(2): - local_id = local_id_o * 2 + local_id_i - row, col = T.meta_var(mma_store_index_map(tx, local_id)) - C_shared_buf[ty * inst.warp_rows + i, tz * inst.warp_cols + j, row, - col] = C_local_buf[i * (inst.warp_cols * inst.local_size_out) + - j * inst.local_size_out + local_id] - - # Allow GEMM from shared memory to local memory - @staticmethod - @T.macro - def GEMM_SS(inst, A_shared_buf, B_shared_buf, C_local_buf, thread_bindings): - # TODO(lei): alloc_buffer within the macro is not supported yet. - A_local_buf = T.alloc_fragment((inst.warp_rows * inst.local_size_a), - inst.a_dtype, - scope="local") - B_local_buf = T.alloc_fragment((inst.warp_cols * inst.local_size_b), - inst.b_dtype, - scope="local") - for ki in T.serial(0, (inst.chunk // inst.micro_size_k)): - inst.LDMATRIX_A( - inst, - A_local_buf, - A_shared_buf, - ki, - thread_bindings=thread_bindings, - ) - - inst.LDMATRIX_B( - inst, - B_local_buf, - B_shared_buf, - ki, - thread_bindings=thread_bindings, - ) + def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0): + return self._warp_ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk) - inst.MMA(inst, A_local_buf, B_local_buf, C_local_buf) + def mma(self, A_local_buf, B_local_buf, C_local_buf): + return self._warp_mma(self, A_local_buf, B_local_buf, C_local_buf) diff --git a/testing/python/tilelang/test_tilelang_macro_gemm.py b/testing/python/tilelang/test_tilelang_macro_gemm.py new file mode 100644 index 00000000..9d797ff6 --- /dev/null +++ b/testing/python/tilelang/test_tilelang_macro_gemm.py @@ -0,0 +1,893 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import get_swizzle_layout +from bitblas.tl.macro_generator import ( + TensorCoreIntrinEmitter, + TensorCoreIntrinEmitterWithLadderTransform, +) +from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 + +torch.manual_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +def tl_matmul( + M, + N, + K, + dtypeAB, + dtypeC, + accum_dtype, +): + assert dtypeAB in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert dtypeC in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if dtypeC == "int32": + micro_size_k = 32 + + # This is a debug config + block_row_warps = 1 + block_col_warps = 1 + warp_row_tiles = 16 + warp_col_tiles = 16 + chunk = 32 if dtypeAB == "float16" else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=dtypeAB, + b_dtype=dtypeAB, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, dtypeAB), + C: T.Buffer((M, N), dtypeC), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), dtypeAB) + B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N 
+ j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, + bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, + i % micro_size_x, j % micro_size_y,] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, dtypeC, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, dtypeC, accum_dtype) + + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + + mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def tl_matmul_with_block_reduce( + M, + N, + K, + dtypeAB, + dtypeC, + accum_dtype, +): + assert dtypeAB in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert dtypeC in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if dtypeC == "int32": + micro_size_k = 32 + + # This is a debug config + block_row_warps = 1 + block_col_warps = 1 + warp_row_tiles = 16 + warp_col_tiles = 16 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = 32 if dtypeAB == "float16" else 64 + reduce_k = 2 + chunk = block_K // reduce_k + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=dtypeAB, + b_dtype=dtypeAB, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + reduce_k=reduce_k) + + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, dtypeAB), + C: T.Buffer((M, N), dtypeC), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope) + C_shared = 
T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), dtypeAB) + B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + reduced_accum_res = T.alloc_local(0, accum_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + rk = T.thread_binding(0, reduce_k, "threadIdx.y") + + if block_K == 32: # Swizzling only works for chunk size 32 + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, (block_K // reduce_k)): + vk = rk * (block_K // reduce_k) + k + A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk] + + # Load B into shared memory + for j, k in T.Parallel(block_N, (block_K // reduce_k)): + vk = rk * (block_K // reduce_k) + k + B_shared[j, vk] = B[bx * block_N + j, ko * block_K + vk] + + for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + rk=rk, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + rk=rk, + ) + + mma_emitter.mma(A_local, B_local, C_local) + + for n in T.serial(warp_rows * warp_cols * local_size): + init_value = getattr(T, accum_dtype)(0) + T.attr( + T.comm_reducer(lambda x, y: x + y, [init_value]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + C_local[n], + True, + reduced_accum_res[0], + rk, + dtype="handle", + )) + if rk == 0: + C_local[n] = reduced_accum_res[0] + + if rk == 0: + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + for i, j in T.Parallel(block_M, (block_N // reduce_k)): + vj = rk * (block_N // reduce_k) + j + C[by * block_M + i, + bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y, + i % micro_size_x, vj % micro_size_y] + + return main + + +def assert_tl_matmul_with_block_reduce_correctness(M, N, K, in_dtype, dtypeC, accum_dtype): + matmul = tl_matmul_with_block_reduce(M, N, K, in_dtype, dtypeC, accum_dtype) + + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + + mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def tl_matmul_with_ladder_weight_only_transform( + M, + N, + K, + dtypeAB, + dtypeC, + accum_dtype, + transform_b, +): + assert dtypeAB in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert dtypeC in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if dtypeC == "int32": 
+ micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + + warp_rows = 2 + warp_cols = 2 + warp_row_tiles = micro_size_x * warp_rows + warp_col_tiles = micro_size_y * warp_cols + + chunk = 64 if dtypeAB == "float16" else 128 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + is_smooth_a = False + can_swizzle = block_K * DataType(dtypeAB).bits == 512 + apply_pad_a = not (is_smooth_a or can_swizzle) + pad_factor = 8 + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) + B_shared_shape = (block_N // micro_size_y, block_K // micro_size_k, micro_size_y, micro_size_k) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=dtypeAB, + b_dtype=dtypeAB, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=transform_b) + + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, dtypeAB), + C: T.Buffer((M, N), dtypeC), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), dtypeAB) + B_local = T.alloc_local((warp_cols * local_size), dtypeAB) + C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + }) + + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, + micro_size_y, micro_size_k): + B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, jj, kk] + + for ki in T.serial(0, (block_K // micro_size_k)): + + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + mma_emitter.mma(A_local, B_local, C_local) + + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, + bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, + i % micro_size_x, j % micro_size_y] + + return main + + +def assert_tl_matmul_with_ladder_weight_only_transform_correctness(M, N, K, in_dtype, dtypeC, + 
accum_dtype, transform_b): + matmul = tl_matmul_with_ladder_weight_only_transform(M, N, K, in_dtype, dtypeC, accum_dtype, + transform_b) + + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + transform_kind=transform_b, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + + LB = ladder_permutate(B.cpu()).cuda() + + mod(A, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( + M, + N, + K, + dtypeAB, + dtypeC, + accum_dtype, + transform_b, +): + assert dtypeAB in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert dtypeC in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + num_bits = 4 + num_elems_per_byte = 8 // num_bits + storage_dtype = "int8" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if dtypeC == "int32": + micro_size_k = 32 + + # This is a debug config + block_row_warps = 1 + block_col_warps = 4 + + warp_rows = 1 + warp_cols = 2 + warp_row_tiles = micro_size_x * warp_rows + warp_col_tiles = micro_size_y * warp_cols + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + reduce_k = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = 32 if dtypeAB == "float16" else 64 + chunk = block_K // reduce_k + + is_smooth_a = False + can_swizzle = block_K * DataType(dtypeAB).bits == 512 + apply_pad_a = not (is_smooth_a or can_swizzle) + pad_factor = 8 + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, + micro_size_k // num_elems_per_byte) + A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k // num_elems_per_byte, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=dtypeAB, + b_dtype=dtypeAB, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + reduce_k=reduce_k, + transform_kind_b=transform_b, + num_elems_per_byte=num_elems_per_byte) + + vec_load_qb = 16 + if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb: + vec_load_qb = block_N * (block_K // reduce_k) // 
num_elems_per_byte // threads + + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), dtypeC), + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, + prelude=decode_i4_to_f16) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size), dtypeAB) + B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype) + B_dequantize_local = T.alloc_local((warp_cols * local_size), dtypeAB) + C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype) + reduced_accum_res = T.alloc_local(0, accum_dtype) + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + rk = T.thread_binding(0, reduce_k, "threadIdx.y") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + }) + + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, (block_K // reduce_k)): + vk = rk * (block_K // reduce_k) + k + A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk] + + # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load + for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // + (threads * vec_load_qb)): + for v in T.vectorized(0, vec_load_qb): + t = thread_bindings + idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v + vkk = idx % (micro_size_k // num_elems_per_byte) + vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y + vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % ( + block_K // micro_size_k) + vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // + (block_K // micro_size_k)) % ( + block_N // micro_size_y) + B_shared[vj, vk, vjj, + vkk] = B[bx * (block_N // micro_size_y) + vj, + ko * (block_K // micro_size_k) + vk, vjj, vkk] + + for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + rk=rk, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + rk=rk, + ) + + for j in T.serial(warp_cols): + local_size_b = mma_emitter.local_size_b + T.call_extern('handle', 'decode_i4u_to_f16', + T.address_of(B_local[j * local_size_b // num_elems_per_byte]), + T.address_of(B_dequantize_local[j * local_size_b]), 8) + + mma_emitter.mma(A_local, B_dequantize_local, C_local) + + if reduce_k > 1: + for n in T.serial(warp_rows * warp_cols * local_size): + T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float16(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + C_local[n], + True, + reduced_accum_res[0], + rk, + dtype="handle", + )) + if rk == 0: + C_local[n] = reduced_accum_res[0] + + if rk == 0: + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + for i, j in T.Parallel(block_M, (block_N // reduce_k)): + vj = rk * (block_N // reduce_k) + j + C[by * block_M + i, + bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y, + i % micro_size_x, vj % micro_size_y] + + return main + + 
+def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness( + M, + N, + K, + in_dtype, + dtypeC, + accum_dtype, + transform_b, +): + matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( + M, N, K, in_dtype, dtypeC, accum_dtype, transform_b) + + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + + # src_code is the generated cuda source + assert src_code is not None + num_bits = 4 + num_elems_per_byte = 8 // num_bits + storage_dtype = "int8" + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + qB = torch.randint( + 0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + transform_kind=transform_b, + transpose_matrix=True, + dequantize_bits=num_bits, + storage_dtype=storage_dtype, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( + M=N, + N=K, + datatype=in_dtype, + dequantize_bits=num_bits, + storage_dtype=storage_dtype, + ) + lop3_permutate = bitblas.ops.LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + QLB = ladder_permutate(qB.cpu()).cuda() + QLB = lop3_permutate(QLB.cpu()).cuda() + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + + mod(A, QLB, C) + + latency = mod.do_bench(mod.func, warmup=25) + + # Ensure that the latency is not None + assert latency is not None + + B = ( + torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, + dtype=torch.half).to(torch.half).to(A.device)) + for i in range(B.shape[0]): + for j in range(B.shape[1]): + B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) + # Get Reference Result + ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +def test_assert_tl_matmul_with_block_reduce(): + assert_tl_matmul_with_block_reduce_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_with_block_reduce_correctness(128, 256, 256, "float16", "float32", "float32") + + +def test_assert_assert_tl_matmul_with_ladder_weight_only_transform(): + assert_tl_matmul_with_ladder_weight_only_transform_correctness(256, 256, 256, "float16", + "float16", "float16", 3) + assert_tl_matmul_with_ladder_weight_only_transform_correctness(256, 256, 256, "float16", + "float32", "float32", 3) + + +def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): + assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness( + 256, 1024, 512, "float16", "float16", "float16", 3) + + +if __name__ == "__main__": + bitblas.testing.main()
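
For orientation, the refactored emitter exposes four instance methods (`ldmatrix_a`, `ldmatrix_b`, `mma`, `stmatrix`) that every kernel in the new test file calls in the same order. The sketch below is condensed from `tl_matmul` above, reusing its debug tiling constants; the swizzle-layout annotation, the `T.use_swizzle` hint, and the correctness checks are dropped for brevity, so treat it as a minimal sketch of the call sequence rather than an additional test.

# Minimal sketch, condensed from tl_matmul above (block_row_warps = block_col_warps = 1,
# warp_row_tiles = warp_col_tiles = 16, chunk = 32, all dtypes float16).
import tvm.tl.language as T
from bitblas.tl.macro_generator import TensorCoreIntrinEmitter

M = N = K = 128
block_M, block_N, block_K = 16, 16, 32
micro_size_x = micro_size_y = micro_size_k = 16
warp_size = 32
threads = warp_size                                      # one warp per block in this config
local_size = (micro_size_x * micro_size_y) // warp_size  # 8 accumulator elements per thread
warp_rows = warp_cols = 1

mma_emitter = TensorCoreIntrinEmitter(
    a_dtype="float16", b_dtype="float16", accum_dtype="float16",
    a_transposed=False, b_transposed=True,
    block_row_warps=1, block_col_warps=1,
    warp_row_tiles=16, warp_col_tiles=16, chunk=block_K)

@T.prim_func
def gemm(A: T.Buffer((M, K), "float16"),
         B: T.Buffer((N, K), "float16"),
         C: T.Buffer((M, N), "float16")):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), "float16", scope="shared.dyn")
        B_shared = T.alloc_shared((block_N, block_K), "float16", scope="shared.dyn")
        C_shared = T.alloc_shared((block_M // micro_size_x, block_N // micro_size_y,
                                   micro_size_x, micro_size_y), "float16", scope="shared.dyn")
        A_local = T.alloc_local((warp_rows * local_size), "float16")
        B_local = T.alloc_local((warp_cols * local_size), "float16")
        C_local = T.alloc_local((warp_rows * warp_cols * local_size), "float16")
        tid = T.thread_binding(0, threads, "threadIdx.x")

        T.clear(C_local)
        for ko in T.Pipelined(K // block_K, num_stages=2):
            # Stage one (block_M, block_K) tile of A and one (block_N, block_K) tile of B.
            for i, k in T.Parallel(block_M, block_K):
                A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
            for j, k in T.Parallel(block_N, block_K):
                B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
            # One ldmatrix pair plus one MMA per 16-wide slice of the k tile.
            for ki in T.serial(0, block_K // micro_size_k):
                mma_emitter.ldmatrix_a(A_local, A_shared, ki, thread_bindings=tid)
                mma_emitter.ldmatrix_b(B_local, B_shared, ki, thread_bindings=tid)
                mma_emitter.mma(A_local, B_local, C_local)
        # Accumulators go to shared memory once, then to global memory.
        mma_emitter.stmatrix(C_local, C_shared, thread_bindings=tid)
        for i, j in T.Parallel(block_M, block_N):
            C[by * block_M + i, bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y,
                                                             i % micro_size_x, j % micro_size_y]

# Lowering and benchmarking then proceed exactly as in assert_tl_matmul_correctness above
# (tl.lower followed by tl.Profiler).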