From 033495e5b00a9ebe304d5af33e2a44b4b6c9c13b Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Sat, 28 Sep 2024 12:56:41 +0800
Subject: [PATCH] [TL][BugFix] Disable Buffer Vectorization and Add OP Related
 TL Test Cases (#197)

* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix for matrix support

* lint fix

* dispatch tensor core based on shapes

* update install commands

* import scripts

* remove shared mem hack

* revert change for swizzling

* bug fix

* tl examples

* Enhance Swizzle

* lint fix

* test fix

* lint fix

* optimize layout

* update tl utils.

* macro optimization

* test fix

* gemm_ss

* doc fix

* lint fix

* lint fix

* remove debug print

* remove debug print

* vectorization init

* lint fix

* prelude update

* update tvm

* bug fix for reduce_k with shared memory

* bug fix

* bug fix

* Enhance Macro Generation

* Lift Layout to reduce load time

* lint fix

* test fix

* red fix

* tile lang macro example

* tile lang macro example

* optimize the macro generator related items

* lint fix

* Tile Lang Test with Dynamic Symbolic

* more test case with block level programming

* all dynamic test case

* simplify the test case for dequantize gemm.

* dequant gemm update.

* Tile Lang GEMM Implementation

* Tile Lang Gemm Fix

* Update subproject commit in 3rdparty/tvm

Refactor mma_layout.py and remove unused imports

Add matmul_macro_tensorcore and
matmul_macro_tensorcore_weight_propagation_level_ldmatrix to dense/__init__.py

Refactor test_general_matmul_tilelang_impl.py to include additional matmul
functions

* test fix
---
 3rdparty/tvm                                  |   2 +-
 .../general_matmul/tilelang/dense/__init__.py |   6 +-
 .../general_matmul/tilelang/dense/matmul.py   | 160 +++++++++++++--
 bitblas/tl/mma_layout.py                      |   3 -
 .../test_general_matmul_tilelang_impl.py      | 186 ++++++++++++++++--
 5 files changed, 318 insertions(+), 39 deletions(-)

diff --git a/3rdparty/tvm b/3rdparty/tvm
index 2852f55c2..c115bfd4c 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 2852f55c268db21dc4e9d3e18aae65823c1157e6
+Subproject commit c115bfd4cc9c5257b0b7b3046571d5ab60db39d3
diff --git a/bitblas/ops/general_matmul/tilelang/dense/__init__.py b/bitblas/ops/general_matmul/tilelang/dense/__init__.py
index 89a4aefbd..03b5a81f3 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/__init__.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/__init__.py
@@ -1,4 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from .matmul import matmul_blocked  # noqa: F401
+from .matmul import (
+    matmul_blocked,  # noqa: F401
+    matmul_macro_tensorcore,  # noqa: F401
+    matmul_macro_tensorcore_weight_propagation_level_ldmatrix  # noqa: F401
+)
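For orientation, the newly exported kernels are ordinary Python functions that build and return a T.prim_func, which the caller then lowers with tl.lower, as the tests later in this patch do. A minimal usage sketch (argument values are illustrative only; the trans_A/trans_B combination is the one the kernel asserts):

# Hypothetical driver for one of the newly exported tensor-core kernels.
# Mirrors the call pattern used in the tests further down this patch.
from bitblas.ops.general_matmul.tilelang.dense import matmul_macro_tensorcore
from tvm import tl

func = matmul_macro_tensorcore(
    M=1024, N=1024, K=1024,
    dtypeAB="float16", dtypeC="float16",
    trans_A=False, trans_B=True,  # the only combination the kernel accepts
    accum_dtype="float16",
    block_row_warps=1, block_col_warps=1,
    warp_row_tiles=16, warp_col_tiles=16,
    chunk=32,
)
mod, params = tl.lower(func)  # lower to a runnable (CUDA-backed) module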
diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py
index 14efbae07..49858bf2f 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 from bitblas import tvm as tvm
+from tvm import DataType
 import tvm.tl.language as T
 
 from bitblas.tl.utils import (
@@ -8,18 +9,12 @@
     make_swizzle_layout,
 )
 
-from bitblas.tl.macro_generator import (TensorCoreIntrinEmitter)
-
+from bitblas.tl.macro_generator import (
+    TensorCoreIntrinEmitter,
+    TensorCoreIntrinEmitterWithLadderTransform,
+)
 
-def maybe_pipeline(
-    iterable,
-    num_stages,
-):
-    enable_pipeline = num_stages > 1
-    if enable_pipeline:
-        return T.Pipelined(iterable, num_stages=num_stages)
-    else:
-        return T.serial(iterable)
+from bitblas.ops.operator import TransformKind
 
 
 def matmul_blocked(
@@ -59,7 +54,7 @@ def main(
             T.use_swizzle(10)
 
             T.clear(C_local)
-            for k in maybe_pipeline(T.ceildiv(K, block_K), num_stages):
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 if trans_A:
                     T.copy(A[k * block_K, by * block_M], A_shared)
                 else:
@@ -80,6 +75,8 @@ def matmul_macro_tensorcore(
     K,
     dtypeAB,
     dtypeC,
+    trans_A,
+    trans_B,
     accum_dtype,
     block_row_warps,
     block_col_warps,
@@ -89,6 +86,8 @@
     num_stages=2,
     enable_rasterization=False,
 ):
+    assert trans_A is False, "Currently only supports Matrix A not transposed"
+    assert trans_B is True, "Currently only supports Matrix B transposed"
     block_M = block_row_warps * warp_row_tiles
     block_N = block_col_warps * warp_col_tiles
@@ -129,9 +128,9 @@ def main(
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 
-            A_shared = T.alloc_shared(A_shared_shape, dtypeAB, shared_scope=shared_scope)
-            B_shared = T.alloc_shared(B_shared_shape, dtypeAB, shared_scope=shared_scope)
-            C_shared = T.alloc_shared(C_shared_shape, dtypeC, shared_scope=shared_scope)
+            A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope)
+            B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope)
+            C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope)
             A_local = T.alloc_local((warp_rows * local_size), dtypeAB)
             B_local = T.alloc_local((warp_cols * local_size), dtypeAB)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
@@ -147,7 +146,7 @@ def main(
 
             T.clear(C_local)
 
-            for ko in maybe_pipeline(T.ceildiv(K, block_K), num_stages):
+            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
 
                 for i, k in T.Parallel(block_M, block_K):
                     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
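As a quick sanity check on the tiling arithmetic in the kernel above, here is a worked example for the default warp/tile parameters the tests use. This is a sketch only; the 16x16x16 fp16 MMA micro shape is an assumption based on the m16n16k16 tensor-core intrinsic, not something this patch states:

# Worked example: launch configuration matmul_macro_tensorcore derives
# from the warp/tile parameters used by the tests below.
block_row_warps, block_col_warps = 1, 1
warp_row_tiles, warp_col_tiles, chunk = 16, 16, 32
micro_size_x = micro_size_y = micro_size_k = 16  # assumed fp16 MMA micro shape

block_M = block_row_warps * warp_row_tiles                 # 16
block_N = block_col_warps * warp_col_tiles                 # 16
block_K = chunk                                            # 32
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)  # 32 threads per CTA
local_size = (micro_size_x * micro_size_y) // warp_size    # 8 fragment elements per thread
warp_rows = warp_row_tiles // micro_size_x                 # 1 micro tile per warp row
warp_cols = warp_col_tiles // micro_size_y                 # 1 micro tile per warp col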
@@ -187,3 +186,132 @@ def main(
                         i % micro_size_x, j % micro_size_y]
 
     return main
+
+
+def matmul_macro_tensorcore_weight_propagation_level_ldmatrix(
+    M,
+    N,
+    K,
+    dtypeAB,
+    dtypeC,
+    trans_A,
+    trans_B,
+    accum_dtype,
+    block_row_warps,
+    block_col_warps,
+    warp_row_tiles,
+    warp_col_tiles,
+    chunk,
+    num_stages=2,
+    enable_rasterization=False,
+):
+    assert trans_A is False, "Currently only supports Matrix A not transposed"
+    assert trans_B is True, "Currently only supports Matrix B transposed"
+
+    block_M = block_row_warps * warp_row_tiles
+    block_N = block_col_warps * warp_col_tiles
+    block_K = chunk
+
+    # TODO(lei): Can be generalized to be derived from the bank size
+    pad_factor = 8 if dtypeAB == "float16" else 16
+
+    can_swizzle_a = block_K * DataType(dtypeAB).bits == 512
+    apply_pad_a = not can_swizzle_a
+
+    micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtypeAB)
+
+    A_shape = (M, K)
+    B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k)
+    A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
+    B_shared_shape = (block_N // micro_size_y, block_K // micro_size_k, micro_size_y,
+                      micro_size_k)
+    C_shared_shape = (block_M // micro_size_x, block_N // micro_size_y, micro_size_x, micro_size_y)
+
+    warp_size = 32  # NVIDIA GPU warp size is 32
+    threads = warp_size * (block_row_warps * block_col_warps)
+    local_size = (micro_size_x * micro_size_y) // warp_size
+    warp_rows = warp_row_tiles // micro_size_x
+    warp_cols = warp_col_tiles // micro_size_y
+
+    shared_scope = "shared.dyn"  # Literal["shared", "shared.dyn"]; "shared" is for static shared memory
+
+    mma_emitter = TensorCoreIntrinEmitterWithLadderTransform(
+        a_dtype=dtypeAB,
+        b_dtype=dtypeAB,
+        accum_dtype=accum_dtype,
+        a_transposed=trans_A,
+        b_transposed=trans_B,
+        block_row_warps=block_row_warps,
+        block_col_warps=block_col_warps,
+        warp_row_tiles=warp_row_tiles,
+        warp_col_tiles=warp_col_tiles,
+        chunk=chunk,
+        transform_kind_b=TransformKind.LDMatrixTransform,
+    )
+
+    @T.prim_func
+    def main(
+            A: T.Buffer(A_shape, dtypeAB),
+            B: T.Buffer(B_shape, dtypeAB),
+            C: T.Buffer((M, N), dtypeC),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+
+            A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope)
+            B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope)
+            C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope)
+            A_local = T.alloc_local((warp_rows * local_size), dtypeAB)
+            B_local = T.alloc_local((warp_cols * local_size), dtypeAB)
+            C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
+            thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+
+            T.annotate_layout({
+                A_shared: make_swizzle_layout(A_shared),
+                B_shared: make_swizzle_layout(B_shared),
+            })
+
+            if enable_rasterization:
+                T.use_swizzle(panel_size=10)
+
+            T.clear(C_local)
+
+            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+
+                for i, k in T.Parallel(block_M, block_K):
+                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
+
+                for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k,
+                                               micro_size_y, micro_size_k):
+                    B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j,
+                                               ko * (block_K // micro_size_k) + k, jj, kk]
+
+                for ki in T.serial(0, (block_K // micro_size_k)):
+
+                    # Load A into fragment
+                    mma_emitter.ldmatrix_a(
+                        A_local,
+                        A_shared,
+                        ki,
+                        thread_bindings=thread_bindings,
+                    )
+
+                    # Load B into fragment
+                    mma_emitter.ldmatrix_b(
+                        B_local,
+                        B_shared,
+                        ki,
+                        thread_bindings=thread_bindings,
+                    )
+
+                    mma_emitter.mma(A_local, B_local, C_local)
+
+            mma_emitter.stmatrix(
+                C_local,
+                C_shared,
+                thread_bindings=thread_bindings,
+            )
+
+            for i, j in T.Parallel(block_M, block_N):
+                C[by * block_M + i,
+                  bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y,
+                                               i % micro_size_x, j % micro_size_y]
+
+    return main
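The 4-D B layout in the kernel above stores the (N, K) weight as micro tiles so that a whole fragment is contiguous for ldmatrix. The following sketch covers the outer tiling only; the intra-tile element permutation applied by the ladder LDMatrixTransform is handled separately (by LadderPermutate, see the tests below), and the 16x16 micro shape for fp16 is an assumption:

# Outer tiling of the weight: element (n, k) of the plain (N, K) matrix
# lands in tile (n // 16, k // 16) at intra-tile offset (n % 16, k % 16),
# matching B_shape = (N // micro_y, K // micro_k, micro_y, micro_k) above.
def tiled_weight_index(n: int, k: int, micro_y: int = 16, micro_k: int = 16):
    return (n // micro_y, k // micro_k, n % micro_y, k % micro_k)

assert tiled_weight_index(17, 40) == (1, 2, 1, 8)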
diff --git a/bitblas/tl/mma_layout.py b/bitblas/tl/mma_layout.py
index 01a729e9c..8be21a1d1 100644
--- a/bitblas/tl/mma_layout.py
+++ b/bitblas/tl/mma_layout.py
@@ -1,8 +1,5 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from tvm import arith
-from tvm import DataType
-from typing import Union, Literal
 
 
 def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
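Before the ladder test in the diff below can run, the plain (N, K) weight must be permuted offline into the ldmatrix-friendly layout the kernel consumes. A standalone sketch of that preparation step, mirroring the calls made inside the test (transform_kind=3 is assumed here to correspond to TransformKind.LDMatrixTransform):

# Hypothetical offline weight preparation, mirroring the ladder test below.
import bitblas
import torch

cfg = bitblas.ops.LadderPermutateConfig(
    M=1024,  # weight rows (the GEMM's N)
    N=1024,  # weight cols (the GEMM's K)
    transform_kind=3,  # assumed to match TransformKind.LDMatrixTransform
    transpose_matrix=True,
)
ladder_permutate = bitblas.ops.LadderPermutate(cfg)

B = torch.rand(1024, 1024, dtype=torch.float16)  # plain weight, on CPU
LB = ladder_permutate(B).cuda()  # layout-transformed weight fed to the kernel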
diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py
index 314c85100..45558ba69 100644
--- a/testing/python/operators/test_general_matmul_tilelang_impl.py
+++ b/testing/python/operators/test_general_matmul_tilelang_impl.py
@@ -4,27 +4,32 @@
 from bitblas import tvm as tvm
 import bitblas.testing
 from tvm import tl
-from bitblas.ops.general_matmul.tilelang.dense import matmul_blocked
+from bitblas.ops.general_matmul.tilelang.dense import (
+    matmul_blocked,
+    matmul_macro_tensorcore,
+    matmul_macro_tensorcore_weight_propagation_level_ldmatrix,
+)
+
 import torch
 import torch.backends
 
 torch.manual_seed(0)
 
 
-def assert_tl_matmul_correctness(M,
-                                 N,
-                                 K,
-                                 block_M=64,
-                                 block_N=64,
-                                 block_K=32,
-                                 trans_A=False,
-                                 trans_B=True,
-                                 dtypeAB="float16",
-                                 dtypeC="float16",
-                                 accum_dtype="float16",
-                                 num_stages=2,
-                                 threads=128,
-                                 enable_rasterization=False):
+def assert_matmul_blocked_correctness(M,
+                                      N,
+                                      K,
+                                      block_M=64,
+                                      block_N=64,
+                                      block_K=32,
+                                      trans_A=False,
+                                      trans_B=True,
+                                      dtypeAB="float16",
+                                      dtypeC="float16",
+                                      accum_dtype="float16",
+                                      num_stages=2,
+                                      threads=128,
+                                      enable_rasterization=False):
     matmul = matmul_blocked(
         M,
         N,
         K,
@@ -66,12 +71,157 @@ def assert_tl_matmul_correctness(M,
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
 
 
+def assert_matmul_macro_tensorcore_correctness(
+    M,
+    N,
+    K,
+    dtypeAB="float16",
+    dtypeC="float16",
+    trans_A=False,
+    trans_B=True,
+    accum_dtype="float16",
+    block_row_warps=1,
+    block_col_warps=1,
+    warp_row_tiles=16,
+    warp_col_tiles=16,
+    chunk=32,
+    num_stages=2,
+    enable_rasterization=False,
+):
+    matmul = matmul_macro_tensorcore(
+        M=M,
+        N=N,
+        K=K,
+        dtypeAB=dtypeAB,
+        dtypeC=dtypeC,
+        trans_A=trans_A,
+        trans_B=trans_B,
+        accum_dtype=accum_dtype,
+        block_row_warps=block_row_warps,
+        block_col_warps=block_col_warps,
+        warp_row_tiles=warp_row_tiles,
+        warp_col_tiles=warp_col_tiles,
+        chunk=chunk,
+        num_stages=num_stages,
+        enable_rasterization=enable_rasterization,
+    )
+    mod, params = tl.lower(matmul)
+    src_code = mod.imported_modules[0].get_source()
+
+    # src_code represents generated cuda source
+    assert src_code is not None
+
+    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB))
+    B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB))
+    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
+
+    mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer)
+
+    mod(A, B, C)
+
+    latency = mod.do_bench(mod.func, warmup=25)
+
+    # Ensure that the latency is not None
+    assert latency is not None
+
+    # Get Reference Result
+    ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
+    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
+
+
+def assert_tl_matmul_with_ladder_weight_only_transform_correctness(
+    M,
+    N,
+    K,
+    dtypeAB="float16",
+    dtypeC="float16",
+    trans_A=False,
+    trans_B=True,
+    accum_dtype="float16",
+    block_row_warps=1,
+    block_col_warps=1,
+    warp_row_tiles=16,
+    warp_col_tiles=16,
+    chunk=32,
+    num_stages=2,
+    enable_rasterization=False,
+):
+    matmul = matmul_macro_tensorcore_weight_propagation_level_ldmatrix(
+        M=M,
+        N=N,
+        K=K,
+        dtypeAB=dtypeAB,
+        dtypeC=dtypeC,
+        trans_A=trans_A,
+        trans_B=trans_B,
+        accum_dtype=accum_dtype,
+        block_row_warps=block_row_warps,
+        block_col_warps=block_col_warps,
+        warp_row_tiles=warp_row_tiles,
+        warp_col_tiles=warp_col_tiles,
+        chunk=chunk,
+        num_stages=num_stages,
+        enable_rasterization=enable_rasterization,
+    )
+
+    mod, params = tl.lower(matmul)
+    src_code = mod.imported_modules[0].get_source()
+
+    # src_code is the generated cuda source
+    assert src_code is not None
+
+    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB))
+    B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB))
+    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
+
+    ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
+        M=N,
+        N=K,
+        transform_kind=3,
+        transpose_matrix=True,
+    )
+
+    ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config)
+
+    LB = ladder_permutate(B.cpu()).cuda()
+
+    mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer)
+
+    mod(A, LB, C)
+
+    latency = mod.do_bench(mod.func, warmup=25)
+    assert latency is not None
+
+    # Get Reference Result
+    ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
+    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
+
+
 def test_matmul_blocked():
     # Pipeline
-    assert_tl_matmul_correctness(1024, 1024, 1024, num_stages=2)
-    assert_tl_matmul_correctness(1024, 1024, 1024, num_stages=1)
+    assert_matmul_blocked_correctness(1024, 1024, 1024, num_stages=2)
+    assert_matmul_blocked_correctness(1024, 1024, 1024, num_stages=1)
+    # L2 Cache
+    assert_matmul_blocked_correctness(1024, 1024, 1024, enable_rasterization=True)
+
+
+def test_matmul_macro_tensorcore():
+    # Pipeline
+    assert_matmul_macro_tensorcore_correctness(1024, 1024, 1024, num_stages=2)
+    assert_matmul_macro_tensorcore_correctness(1024, 1024, 1024, num_stages=1)
+    assert_matmul_macro_tensorcore_correctness(1024, 1024, 1024, num_stages=0)
+    # L2 Cache
+    assert_matmul_macro_tensorcore_correctness(1024, 1024, 1024, enable_rasterization=True)
+
+
+def test_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
+    # Pipeline
+    assert_tl_matmul_with_ladder_weight_only_transform_correctness(1024, 1024, 1024, num_stages=2)
+    assert_tl_matmul_with_ladder_weight_only_transform_correctness(1024, 1024, 1024, num_stages=1)
+    assert_tl_matmul_with_ladder_weight_only_transform_correctness(1024, 1024, 1024, num_stages=0)
     # L2 Cache
-    assert_tl_matmul_correctness(1024, 1024, 1024, enable_rasterization=True)
+    assert_tl_matmul_with_ladder_weight_only_transform_correctness(
+        1024, 1024, 1024, enable_rasterization=True)
 
 
 if __name__ == "__main__":