From a912e02c4030d3485f73f9cf6b618f0fd19c7102 Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Fri, 27 Sep 2024 16:27:02 +0800
Subject: [PATCH] [Test] Add Thread Level Macro Dequantize Gemm Test Cases
 (#194)

* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* disable failure email for ci
* remove email notifications.
* move relax pass from testing to mlc_llm
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* Lint Fix
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* bug fix for matrix support
* lint fix
* dispatch tensor core based on shapes
* update install commands
* import scripts
* remove shared mem hack
* revert change for swizzling
* bug fix
* tl examples
* Enhance Swizzle
* lint fix
* test fix
* lint fix
* optimize layout
* update tl utils.
* macro optimization
* test fix
* gemm_ss
* doc fix
* lint fix
* lint fix
* remove debug print
* remove debug print
* vectorization init
* lint fix
* prelude update
* update tvm
* bug fix for reduce_k with shared memory
* bug fix
* bug fix
* Enhance Macro Generation
* Lift Layout to reduce load time
* lint fix
* test fix
* red fix
* tile lang macro example
* tile lang macro example
* optimize the macro generator related items
* lint fix
* Tile Lang Test with Dynamic Symbolic
* more test case with block level programming
* all dynamic test case
* simplify the test case for dequantize gemm.
* dequant gemm update.
---
 .../tilelang/test_tilelang_dequantize_gemm.py | 328 +++++++++++++++++-
 1 file changed, 318 insertions(+), 10 deletions(-)

diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py
index 574bac15a..f8217157a 100644
--- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py
+++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py
@@ -1,9 +1,35 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+import torch
+import torch.backends
 import bitblas
 from bitblas import tvm as tvm
-from tvm import tl
+from tvm import DataType
+from tvm import tl as TL
+import tvm.tl.language as T
 from bitblas.quantization import _tir_packed_to_unsigned_convert
+from bitblas.tl.utils import get_swizzle_layout
+from bitblas.tl.macro_generator import (
+    TensorCoreIntrinEmitterWithLadderTransform,)
+
+from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
+
+torch.manual_seed(0)
+
+
+def make_swizzle_layout(shared_buf):
+    dtype = shared_buf.dtype
+    shape = shared_buf.shape
+
+    can_swizzle = shape[-1] * DataType(dtype).bits == 512
+    if not can_swizzle:
+        return T.Layout(shape, lambda *args: args)
+
+    def transform_func(i, j):
+        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
+        return [new_warp_i, new_warp_j]
+
+    return T.Layout(shape, transform_func)
 
 
 def matmul(
@@ -47,13 +73,8 @@ def main(
 
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                 T.copy(A[by * block_M, k * block_K], A_shared)
-                for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 16)):
-                    for t in T.thread_binding(0, threads, thread="threadIdx.x"):
-                        for v in T.vectorized(0, 16):
-                            vi = (i * threads * 16 + t * 16 + v) // (block_K // num_elems_per_byte)
-                            vj = (i * threads * 16 + t * 16 + v) % (block_K // num_elems_per_byte)
-                            B_shared[vi, vj] = B[bx * block_N + vi,
-                                                 k * block_K // num_elems_per_byte + vj,]
+                for i, j in T.Parallel(block_N, block_K // num_elems_per_byte):
+                    B_shared[i, j] = B[bx * block_N + i, k * block_K // num_elems_per_byte + j]
 
                 for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)):
                     for t in T.thread_binding(0, threads, thread="threadIdx.x"):
@@ -106,8 +127,8 @@ def run_gemm(
     )
     print(program)
 
-    mod, params = tl.lower(program)
-    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+    mod, params = TL.lower(program)
+    mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer)
 
     out = mod.run_once()
 
@@ -129,9 +150,296 @@ def ref_program(A, qB):
     mod.assert_allclose(ref_program)
 
 
+def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
+        M,
+        N,
+        K,
+        dtypeAB,
+        dtypeC,
+        accum_dtype,
+        transform_b,
+):
+    assert dtypeAB in [
+        "float16",
+        "int8",
+    ], "Currently only float16 and int8 are supported"
+    assert dtypeC in [
+        "float16",
+        "float32",
+        "int32",
+    ], "Currently only float16, float32 and int32 are supported"
+    num_bits = 4
+    num_elems_per_byte = 8 // num_bits
+    storage_dtype = "int8"
+
+    micro_size_x = micro_size_y = micro_size_k = 16
+
+    if dtypeC == "int32":
+        micro_size_k = 32
+
+    # This is a debug config
+    block_row_warps = 1
+    block_col_warps = 4
+
+    warp_rows = 1
+    warp_cols = 2
+    warp_row_tiles = micro_size_x * warp_rows
+    warp_col_tiles = micro_size_y * warp_cols
+    shared_scope = "shared.dyn"
+
+    # Pipeline Stage
+    stage = 2
+    reduce_k = 2
+
+    block_M = block_row_warps * warp_row_tiles
+    block_N = block_col_warps * warp_col_tiles
+    block_K = 32 if dtypeAB == "float16" else 64
+    chunk = block_K // reduce_k
+
+    is_smooth_a = False
+    can_swizzle = block_K * DataType(dtypeAB).bits == 512
+    apply_pad_a = not (is_smooth_a or can_swizzle)
+    pad_factor = 8
+
+    A_shape = (M, K)
+    B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y,
+               micro_size_k // num_elems_per_byte)
+    A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
+    B_shared_shape = (
+        block_N // micro_size_y,
+        block_K // micro_size_k,
+        micro_size_y,
+        micro_size_k // num_elems_per_byte,
+    )
+    C_shared_shape = (
+        block_M // micro_size_x,
+        block_N // micro_size_y,
+        micro_size_x,
+        micro_size_y,
+    )
+
+    warp_size = 32
+    threads = warp_size * (block_row_warps * block_col_warps)
+    local_size = (micro_size_x * micro_size_y) // warp_size
+    warp_rows = warp_row_tiles // micro_size_x
+    warp_cols = warp_col_tiles // micro_size_y
+
+    # MMA Wrapper to Auto Generate Code for MMA
+    mma_emitter = TensorCoreIntrinEmitterWithLadderTransform(
+        a_dtype=dtypeAB,
+        b_dtype=dtypeAB,
+        accum_dtype=accum_dtype,
+        a_transposed=False,
+        b_transposed=True,
+        block_row_warps=block_row_warps,
+        block_col_warps=block_col_warps,
+        warp_row_tiles=warp_row_tiles,
+        warp_col_tiles=warp_col_tiles,
+        chunk=chunk,
+        reduce_k=reduce_k,
+        transform_kind_b=transform_b,
+        num_elems_per_byte=num_elems_per_byte)
+
+    vec_load_qb = 16
+    if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb:
+        vec_load_qb = block_N * (block_K // reduce_k) // num_elems_per_byte // threads
+
+    @T.prim_func
+    def main(
+            A: T.Buffer(A_shape, dtypeAB),
+            B: T.Buffer(B_shape, storage_dtype),
+            C: T.Buffer((M, N), dtypeC),
+    ):
+        with T.Kernel(
+                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads,
+                prelude=decode_i4_to_f16) as (bx, by):
+
+            A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope)
+            B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
+            C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope)
+            A_local = T.alloc_local((warp_rows * local_size), dtypeAB)
+            B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype)
+            B_dequantize_local = T.alloc_local((warp_cols * local_size), dtypeAB)
+            C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
+            reduced_accum_res = T.alloc_local(0, accum_dtype)
+            thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+            rk = T.thread_binding(0, reduce_k, "threadIdx.y")
+
+            T.annotate_layout({
+                A_shared: make_swizzle_layout(A_shared),
+            })
+
+            T.use_swizzle(panel_size=10)
+
+            T.clear(C_local)
+
+            for ko in T.Pipelined((K // block_K), num_stages=stage):
+
+                # Load A into shared memory
+                for i, k in T.Parallel(block_M, (block_K // reduce_k)):
+                    vk = rk * (block_K // reduce_k) + k
+                    A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk]
+
+                # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load
+                for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
+                                  (threads * vec_load_qb)):
+                    for v in T.vectorized(0, vec_load_qb):
+                        t = thread_bindings
+                        idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
+                        vkk = idx % (micro_size_k // num_elems_per_byte)
+                        vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
+                        vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (
+                            block_K // micro_size_k)
+                        vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y //
+                              (block_K // micro_size_k)) % (
+                                  block_N // micro_size_y)
+                        B_shared[vj, vk, vjj,
+                                 vkk] = B[bx * (block_N // micro_size_y) + vj,
+                                          ko * (block_K // micro_size_k) + vk, vjj, vkk]
+
+                for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))):
+
+                    # Load A into fragment
+                    mma_emitter.ldmatrix_a(
+                        A_local,
+                        A_shared,
+                        ki,
+                        thread_bindings=thread_bindings,
+                        rk=rk,
+                    )
+
+                    # Load B into fragment
+                    mma_emitter.ldmatrix_b(
+                        B_local,
+                        B_shared,
+                        ki,
+                        thread_bindings=thread_bindings,
+                        rk=rk,
+                    )
+
+                    for j in T.serial(warp_cols):
+                        local_size_b = mma_emitter.local_size_b
+                        T.call_extern('handle', 'decode_i4u_to_f16',
+                                      T.address_of(B_local[j * local_size_b // num_elems_per_byte]),
+                                      T.address_of(B_dequantize_local[j * local_size_b]), 8)
+
+                    mma_emitter.mma(A_local, B_dequantize_local, C_local)
+
+            if reduce_k > 1:
+                for n in T.serial(warp_rows * warp_cols * local_size):
+                    T.attr(
+                        T.comm_reducer(lambda x, y: x + y, [T.float16(0)]),
+                        "reduce_scope",
+                        T.reinterpret(T.uint64(0), dtype="handle"),
+                    )
+                    T.evaluate(
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            C_local[n],
+                            True,
+                            reduced_accum_res[0],
+                            rk,
+                            dtype="handle",
+                        ))
+                    if rk == 0:
+                        C_local[n] = reduced_accum_res[0]
+
+            if rk == 0:
+                mma_emitter.stmatrix(
+                    C_local,
+                    C_shared,
+                    thread_bindings=thread_bindings,
+                )
+
+            for i, j in T.Parallel(block_M, (block_N // reduce_k)):
+                vj = rk * (block_N // reduce_k) + j
+                C[by * block_M + i,
+                  bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y,
+                                                i % micro_size_x, vj % micro_size_y]
+
+    return main
+
+
+def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
+        M,
+        N,
+        K,
+        in_dtype,
+        dtypeC,
+        accum_dtype,
+        transform_b,
+):
+    matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
+        M, N, K, in_dtype, dtypeC, accum_dtype, transform_b)
+
+    mod, params = TL.lower(matmul)
+    src_code = mod.imported_modules[0].get_source()
+
+    # src_code is the generated cuda source
+    assert src_code is not None
+    num_bits = 4
+    num_elems_per_byte = 8 // num_bits
+    storage_dtype = "int8"
+
+    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
+    qB = torch.randint(
+        0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
+    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
+
+    ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
+        M=N,
+        N=K,
+        transform_kind=transform_b,
+        transpose_matrix=True,
+        dequantize_bits=num_bits,
+        storage_dtype=storage_dtype,
+    )
+
+    ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config)
+
+    lop3_permutate_config = bitblas.ops.LOP3PermutateConfig(
+        M=N,
+        N=K,
+        datatype=in_dtype,
+        dequantize_bits=num_bits,
+        storage_dtype=storage_dtype,
+    )
+    lop3_permutate = bitblas.ops.LOP3Permutate(
+        config=lop3_permutate_config,
+        target=tvm.target.Target("llvm"),
+    )
+    QLB = ladder_permutate(qB.cpu()).cuda()
+    QLB = lop3_permutate(QLB.cpu()).cuda()
+
+    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
+
+    mod(A, QLB, C)
+
+    latency = mod.do_bench(mod.func, warmup=25)
+
+    # Ensure that the latency is not None
+    assert latency is not None
+
+    B = (
+        torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
+                    dtype=torch.half).to(torch.half).to(A.device))
+    for i in range(B.shape[0]):
+        for j in range(B.shape[1]):
+            B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
+
+    # Get Reference Result
+    ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
+    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
+
+
 def test_run_dequantize_gemm():
     run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)
 
 
+def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
+    assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
+        256, 1024, 512, "float16", "float16", "float16", 3)
+
+
 if __name__ == "__main__":
     bitblas.testing.main()
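
Beyond the single problem size committed above, the new correctness check can also be driven directly for local verification. The following is a minimal sketch, not part of the patch: it assumes the test file is importable as test_tilelang_dequantize_gemm (for example when run from its own directory) and that a CUDA-capable GPU is present, since the check allocates its tensors on "cuda"; the 512x1024x512 shape is purely illustrative.

    # Hedged sketch: exercise the new block-reduce int4 correctness check with an
    # extra, illustrative problem size (module path and shape are assumptions).
    from test_tilelang_dequantize_gemm import (
        assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness,)

    # Same dtypes and transform_b value as the committed test case.
    assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
        512, 1024, 512, "float16", "float16", "float16", 3)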