From 9b3b73b2c4ce0447aff909b1dc40fdbd86247e8d Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Mon, 2 Sep 2024 17:48:19 +0800
Subject: [PATCH] [TL] Update several TL Examples (#168)

* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* disable failure email for ci
* remove email notifications.
* move relax pass from testing to mlc_llm
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* Lint Fix
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* bug fix for matrix support
* lint fix
* dispatch tensor core based on shapes
* update install commands
* import scripts
* remove shared mem hack
* revert change for swizzling
* bug fix
* tl examples
---
 .../tilelang/test_tilelang_dequantize_gemm.py | 164 +++++++++++++++++
 .../tilelang/test_tilelang_flash_atten.py     | 173 ++++++++++++++++++
 2 files changed, 337 insertions(+)
 create mode 100644 testing/python/tilelang/test_tilelang_dequantize_gemm.py
 create mode 100644 testing/python/tilelang/test_tilelang_flash_atten.py

diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py
new file mode 100644
index 000000000..a0d2feae3
--- /dev/null
+++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py
@@ -0,0 +1,164 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import bitblas
+from bitblas import tvm as tvm
+from tvm import tl
+from bitblas.quantization import _tir_packed_to_unsigned_convert
+
+
+def matmul(
+    M,
+    N,
+    K,
+    block_M,
+    block_N,
+    block_K,
+    dtypeAB,
+    dtypeC,
+    accum_dtype,
+    num_stages,
+    threads,
+    num_bits=4,
+):
+    num_elems_per_byte = 8 // num_bits
+    storage_dtype = "int8"
+    A_shape = (M, K)
+    B_shape = (N, K // num_elems_per_byte)
+    A_shared_shape = (block_M, block_K)
+    B_shared_shape = (block_N, block_K // num_elems_per_byte)
+    B_dequantize_shared_shape = (block_N, block_K)
+
+    import tvm.tl.language as T
+
+    @T.prim_func
+    def main(
+        A: T.Buffer(A_shape, dtypeAB),
+        B: T.Buffer(B_shape, storage_dtype),
+        C: T.Buffer((M, N), dtypeC),
+    ):
+        with T.Kernel(
+            T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
+        ) as (bx, by):
+            A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
+            B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
+            B_local = T.alloc_fragment([8], storage_dtype, "local")
+            B_dequantize_local = T.alloc_fragment([16], dtypeAB, "local")
+            B_dequantize_shared = T.alloc_shared(
+                B_dequantize_shared_shape, dtypeAB
+            )
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.clear(C_local)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                T.copy(A[by * block_M, k * block_K], A_shared)
+
+                for i in T.serial(
+                    block_N * block_K // num_elems_per_byte // (threads * 16)
+                ):
+                    for t in T.thread_binding(0, threads, thread="threadIdx.x"):
+                        for v in T.vectorized(0, 16):
+                            vi = (i * threads * 16 + t * 16 + v) // (
+                                block_K // num_elems_per_byte
+                            )
+                            vj = (i * threads * 16 + t * 16 + v) % (
+                                block_K // num_elems_per_byte
+                            )
+                            B_shared[vi, vj] = B[
+                                bx * block_N + vi,
+                                k * block_K // num_elems_per_byte + vj,
+                            ]
+
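+                # Dequantize pass: each int8 byte of B_shared packs two 4-bit
+                # values. Per iteration, a thread loads 4 bytes into B_local,
+                # expands them to 8 dtypeAB elements with
+                # _tir_packed_to_unsigned_convert("int", 8) -- equivalent to
+                # (byte >> (4 * (v % 2))) & 0xF, the same formula used by
+                # ref_program below -- and stores them to B_dequantize_shared
+                # for the GEMM.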
+                for i in T.serial(
+                    block_N * block_K // num_elems_per_byte // (threads * 4)
+                ):
+                    for t in T.thread_binding(0, threads, thread="threadIdx.x"):
+                        for v in T.vectorized(0, 4):
+                            vi = (i * threads * 4 + t * 4 + v) // (
+                                block_K // num_elems_per_byte
+                            )
+                            vj = (i * threads * 4 + t * 4 + v) % (
+                                block_K // num_elems_per_byte
+                            )
+                            B_local[v] = B_shared[vi, vj]
+                        for v in T.serial(0, 8):
+                            B_dequantize_local[
+                                v
+                            ] = _tir_packed_to_unsigned_convert("int", 8)(
+                                num_bits,
+                                B_local[v // 2],
+                                v % 2,
+                                dtype=dtypeAB,
+                            )
+                        for v in T.vectorized(0, 8):
+                            vi = (i * threads * 8 + t * 8 + v) // (block_K)
+                            vj = (i * threads * 8 + t * 8 + v) % (block_K)
+                            B_dequantize_shared[vi, vj] = B_dequantize_local[v]
+                T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return main
+
+
+def run_gemm(
+    M,
+    N,
+    K,
+    dtypeAB,
+    dtypeC,
+    dtypeAccum,
+    block_M,
+    block_N,
+    block_K,
+    num_stages=3,
+    num_threads=128,
+):
+    program = matmul(
+        M,
+        N,
+        K,
+        block_M,
+        block_N,
+        block_K,
+        dtypeAB,
+        dtypeC,
+        dtypeAccum,
+        num_stages,
+        num_threads,
+    )
+    print(program)
+
+    mod, params = tl.lower(program)
+    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)
+
+    out = mod.run_once()
+
+    print(f"output is {out}")
+
+    with open("debug/kernel.cu", "w") as f:
+        f.write(mod.mod.imported_modules[0].get_source())
+
+    def ref_program(A, qB):
+        import torch
+
+        B = (
+            torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half)
+            .to(torch.half)
+            .to(A.device)
+        )
+        for i in range(B.shape[0]):
+            for j in range(B.shape[1]):
+                B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(
+                    torch.half
+                )
+        C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
+        C = C.to(getattr(torch, dtypeC))
+        return C
+
+    mod.assert_allclose(ref_program)
+
+
+def test_run_dequantize_gemm():
+    run_gemm(16, 16, 16, "int8", "int32", "int32", 16, 16, 16, num_threads=128)
+
+
+if __name__ == "__main__":
+    bitblas.testing.main()
diff --git a/testing/python/tilelang/test_tilelang_flash_atten.py b/testing/python/tilelang/test_tilelang_flash_atten.py
new file mode 100644
index 000000000..63a52bba8
--- /dev/null
+++ b/testing/python/tilelang/test_tilelang_flash_atten.py
@@ -0,0 +1,173 @@
+import argparse
+import torch
+from tvm import tl
+import tvm.tl.language as T
+from tvm.tl.autotuner import autotune, jit
+from functools import partial
+import itertools
+
+
+def get_configs():
+    block_M = [32, 64, 128]
+    block_N = [32, 64, 128]
+    num_stages = [1, 2]
+    thread_num = [128, 256]
+    _configs = list(itertools.product(block_M, block_N, num_stages, thread_num))
+
+    configs = [
+        {
+            "block_M": c[0],
+            "block_N": c[1],
+            "num_stages": c[2],
+            "thread_num": c[3],
+        }
+        for c in _configs
+    ]
+    return configs
+
+
+def ref_program(Q, K, V, causal):
+    from flash_attn.flash_attn_interface import flash_attn_func
+
+    return flash_attn_func(Q, K, V, causal=causal)
+
+
+def flashattn(batch, heads, seq_len, dim, is_causal):
+
+    @autotune(
+        configs=get_configs(),
+        keys=["block_M", "block_N", "num_stages", "thread_num"],
+        warmup=10,
+        rep=5,
+    )
+    @jit(
+        out_idx=[3],
+        supply_type=tl.TensorSupplyType.Normal,
+        ref_prog=partial(ref_program, causal=is_causal),
+        rtol=0.01,
+        atol=0.01,
+    )
+    def kernel(block_M=None, block_N=None, num_stages=None, thread_num=None):
+        scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) * log2(e)
+        shape = [batch, seq_len, heads, dim]
+        dtype = "float16"
+        accum_dtype = "float"
+
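+        # Online-softmax bookkeeping: Q is pre-scaled by 1/sqrt(dim) * log2(e)
+        # so the softmax can be evaluated with exp2. For every K/V block the
+        # kernel tracks the running row maximum (scores_max) and running
+        # denominator (logsum), rescales the partial output acc_o by
+        # exp2(max_prev - max_new), and never materializes the full
+        # seq_len x seq_len score matrix.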
+        @T.prim_func
+        def main(
+            Q: T.Buffer(shape, dtype),  # type: ignore
+            K: T.Buffer(shape, dtype),  # type: ignore
+            V: T.Buffer(shape, dtype),  # type: ignore
+            Output: T.Buffer(shape, dtype),  # type: ignore
+        ):
+            with T.Kernel(
+                T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num
+            ) as (bx, by, bz):
+                Q_shared = T.alloc_shared([block_M, dim], dtype)
+                Q_local = T.alloc_fragment([block_M, dim], dtype)
+                K_shared = T.alloc_shared([block_N, dim], dtype)
+                V_shared = T.alloc_shared([block_N, dim], dtype)
+                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
+                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
+                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
+                scores_max = T.alloc_fragment([block_M], accum_dtype)
+                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
+                scores_scale = T.alloc_fragment([block_M], accum_dtype)
+                scores_sum = T.alloc_fragment([block_M], accum_dtype)
+                logsum = T.alloc_fragment([block_M], accum_dtype)
+
+                T.annotate_layout(
+                    {Q_shared: tl.layout.make_swizzled_layout(Q_shared)}
+                )
+                T.copy(
+                    Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared
+                )
+                T.fill(acc_o, 0)
+                T.fill(logsum, 0)
+                T.fill(scores_max, -T.infinity(accum_dtype))
+                T.copy(Q_shared, Q_local)
+                for i, j in T.Parallel(block_M, dim):
+                    Q_local[i, j] *= scale
+                loop_range = (
+                    T.ceildiv((bx + 1) * block_M, block_N)
+                    if is_causal
+                    else T.ceildiv(seq_len, block_N)
+                )
+                for k in T.Pipelined(loop_range, num_stages=num_stages):
+                    T.copy(
+                        K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared
+                    )
+                    if is_causal:
+                        for i, j in T.Parallel(block_M, block_N):
+                            acc_s[i, j] = T.if_then_else(
+                                bx * block_M + i >= k * block_N + j,
+                                0,
+                                -T.infinity(acc_s.dtype),
+                            )
+                    else:
+                        T.clear(acc_s)
+                    T.gemm(
+                        Q_local,
+                        K_shared,
+                        acc_s,
+                        transpose_B=True,
+                        policy=T.GemmWarpPolicy.FullRow,
+                    )
+                    T.copy(
+                        V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared
+                    )
+                    T.copy(scores_max, scores_max_prev)
+                    T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                    for i in T.Parallel(block_M):
+                        scores_scale[i] = T.exp2(
+                            scores_max_prev[i] - scores_max[i]
+                        )
+                    for i, j in T.Parallel(block_M, dim):
+                        acc_o[i, j] *= scores_scale[i]
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
+                    T.copy(acc_s, acc_s_cast)
+                    T.gemm(
+                        acc_s_cast,
+                        V_shared,
+                        acc_o,
+                        policy=T.GemmWarpPolicy.FullRow,
+                    )
+                    T.reduce_sum(acc_s, scores_sum, dim=1)
+                    for i in T.Parallel(block_M):
+                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
+                for i, j in T.Parallel(block_M, dim):
+                    acc_o[i, j] /= logsum[i]
+                T.copy(
+                    acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]
+                )
+
+        return main
+
+    return kernel()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch", type=int, default=64, help="Batch size")
+    parser.add_argument("--h", type=int, default=12, help="Number of heads")
+    parser.add_argument("--n_ctx", type=int, default=2048, help="Context size")
+    parser.add_argument(
+        "--d_head", type=int, default=256, help="Head dimension"
+    )
+    parser.add_argument("--causal", type=bool, default=True, help="Causal flag")
+    args = parser.parse_args()
+    BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head
+    causal = args.causal
+    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
+    total_flops = 2 * flops_per_matmul
+    if causal:
+        total_flops *= 0.5
+
+    best_latency, best_config, ref_latency = flashattn(
+        BATCH, H, N_CTX, D_HEAD, causal
+    )
+    print(f"Best latency: {best_latency}")
+    print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
+    print(f"Best config: {best_config}")
+    print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}")
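+    # FLOPs accounting: attention performs two GEMMs per tile (Q @ K^T and
+    # P @ V), hence total_flops = 2 * flops_per_matmul, and the causal mask
+    # skips roughly half of the score matrix, hence the 0.5 factor. The TFlops
+    # prints assume the autotuner reports latency in milliseconds, so
+    # TFLOPS = flops / latency_ms * 1e-9.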