diff --git a/3rdparty/tvm b/3rdparty/tvm
index a29c8ad7e..a1d78ebc6 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit a29c8ad7e78f61e0658946bd494f45cc9bebd36e
+Subproject commit a1d78ebc682dbaec70e792470c6842b9ec3342c6
diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py
index a0d2feae3..9db978cd9 100644
--- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py
+++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py
@@ -32,57 +32,37 @@ def matmul(
 
     @T.prim_func
     def main(
-        A: T.Buffer(A_shape, dtypeAB),
-        B: T.Buffer(B_shape, storage_dtype),
-        C: T.Buffer((M, N), dtypeC),
+            A: T.Buffer(A_shape, dtypeAB),
+            B: T.Buffer(B_shape, storage_dtype),
+            C: T.Buffer((M, N), dtypeC),
     ):
-        with T.Kernel(
-            T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
-        ) as (bx, by):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
             B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
             B_local = T.alloc_fragment([8], storage_dtype, "local")
             B_dequantize_local = T.alloc_fragment([16], dtypeAB, "local")
-            B_dequantize_shared = T.alloc_shared(
-                B_dequantize_shared_shape, dtypeAB
-            )
+            B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, dtypeAB)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             T.clear(C_local)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                 T.copy(A[by * block_M, k * block_K], A_shared)
 
-                for i in T.serial(
-                    block_N * block_K // num_elems_per_byte // (threads * 16)
-                ):
+                for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 16)):
                     for t in T.thread_binding(0, threads, thread="threadIdx.x"):
                         for v in T.vectorized(0, 16):
-                            vi = (i * threads * 16 + t * 16 + v) // (
-                                block_K // num_elems_per_byte
-                            )
-                            vj = (i * threads * 16 + t * 16 + v) % (
-                                block_K // num_elems_per_byte
-                            )
-                            B_shared[vi, vj] = B[
-                                bx * block_N + vi,
-                                k * block_K // num_elems_per_byte + vj,
-                            ]
-
-                for i in T.serial(
-                    block_N * block_K // num_elems_per_byte // (threads * 4)
-                ):
+                            vi = (i * threads * 16 + t * 16 + v) // (block_K // num_elems_per_byte)
+                            vj = (i * threads * 16 + t * 16 + v) % (block_K // num_elems_per_byte)
+                            B_shared[vi, vj] = B[bx * block_N + vi,
+                                                 k * block_K // num_elems_per_byte + vj,]
+
+                for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)):
                     for t in T.thread_binding(0, threads, thread="threadIdx.x"):
                         for v in T.vectorized(0, 4):
-                            vi = (i * threads * 4 + t * 4 + v) // (
-                                block_K // num_elems_per_byte
-                            )
-                            vj = (i * threads * 4 + t * 4 + v) % (
-                                block_K // num_elems_per_byte
-                            )
+                            vi = (i * threads * 4 + t * 4 + v) // (block_K // num_elems_per_byte)
+                            vj = (i * threads * 4 + t * 4 + v) % (block_K // num_elems_per_byte)
                             B_local[v] = B_shared[vi, vj]
                         for v in T.serial(0, 8):
-                            B_dequantize_local[
-                                v
-                            ] = _tir_packed_to_unsigned_convert("int", 8)(
+                            B_dequantize_local[v] = _tir_packed_to_unsigned_convert("int", 8)(
                                 num_bits,
                                 B_local[v // 2],
                                 v % 2,
@@ -140,15 +120,11 @@ def ref_program(A, qB):
         import torch
 
         B = (
-            torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half)
-            .to(torch.half)
-            .to(A.device)
-        )
+            torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
+                        dtype=torch.half).to(torch.half).to(A.device))
         for i in range(B.shape[0]):
             for j in range(B.shape[1]):
-                B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(
-                    torch.half
-                )
+                B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
         C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
         C = C.to(torch.__getattribute__(dtypeC))
         return C
@@ -157,7 +133,7 @@ def ref_program(A, qB):
 
 
 def test_run_dequantize_gemm():
-    run_gemm(16, 16, 16, "int8", "int32", "int32", 16, 16, 16, num_threads=128)
+    run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)
 
 
 if __name__ == "__main__":
diff --git a/testing/python/tilelang/test_tilelang_flash_atten.py b/testing/python/tilelang/test_tilelang_flash_atten.py
index 63a52bba8..a8b8c4986 100644
--- a/testing/python/tilelang/test_tilelang_flash_atten.py
+++ b/testing/python/tilelang/test_tilelang_flash_atten.py
@@ -1,5 +1,4 @@
 import argparse
-import torch
 from tvm import tl
 import tvm.tl.language as T
 from tvm.tl.autotuner import *
@@ -14,15 +13,12 @@ def get_configs():
     thread_num = [128, 256]
     _configs = list(itertools.product(block_M, block_N, num_stages, thread_num))
 
-    configs = [
-        {
-            "block_M": c[0],
-            "block_N": c[1],
-            "num_stages": c[2],
-            "thread_num": c[3],
-        }
-        for c in _configs
-    ]
+    configs = [{
+        "block_M": c[0],
+        "block_N": c[1],
+        "num_stages": c[2],
+        "thread_num": c[3],
+    } for c in _configs]
     return configs
 
 
@@ -48,21 +44,20 @@ def flashattn(batch, heads, seq_len, dim, is_casual):
         atol=0.01,
     )
     def kernel(block_M=None, block_N=None, num_stages=None, thread_num=None):
-        scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
+        scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
         shape = [batch, seq_len, heads, dim]
         dtype = "float16"
         accum_dtype = "float"
 
         @T.prim_func
         def main(
-            Q: T.Buffer(shape, dtype),  # type: ignore
-            K: T.Buffer(shape, dtype),  # type: ignore
-            V: T.Buffer(shape, dtype),  # type: ignore
-            Output: T.Buffer(shape, dtype),  # type: ignore
+                Q: T.Buffer(shape, dtype),  # type: ignore
+                K: T.Buffer(shape, dtype),  # type: ignore
+                V: T.Buffer(shape, dtype),  # type: ignore
+                Output: T.Buffer(shape, dtype),  # type: ignore
         ):
             with T.Kernel(
-                T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num
-            ) as (bx, by, bz):
+                    T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_M, dim], dtype)
                 Q_local = T.alloc_fragment([block_M, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
@@ -76,12 +71,8 @@ def main(
                 scores_sum = T.alloc_fragment([block_M], accum_dtype)
                 logsum = T.alloc_fragment([block_M], accum_dtype)
 
-                T.annotate_layout(
-                    {Q_shared: tl.layout.make_swizzled_layout(Q_shared)}
-                )
-                T.copy(
-                    Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared
-                )
+                T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})
+                T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))
@@ -89,14 +80,10 @@ def main(
                 for i, j in T.Parallel(block_M, dim):
                     Q_local[i, j] *= scale
                 loop_range = (
-                    T.ceildiv((bx + 1) * block_M, block_N)
-                    if is_casual
-                    else T.ceildiv(seq_len, block_N)
-                )
+                    T.ceildiv(
+                        (bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N))
                 for k in T.Pipelined(loop_range, num_stages=num_stages):
-                    T.copy(
-                        K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared
-                    )
+                    T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
                     if is_casual:
                         for i, j in T.Parallel(block_M, block_N):
                             acc_s[i, j] = T.if_then_else(
@@ -113,15 +100,11 @@ def main(
                         transpose_B=True,
                         policy=T.GemmWarpPolicy.FullRow,
                     )
-                    T.copy(
-                        V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared
-                    )
+                    T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                     T.copy(scores_max, scores_max_prev)
                     T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                     for i in T.Parallel(block_M):
-                        scores_scale[i] = T.exp2(
-                            scores_max_prev[i] - scores_max[i]
-                        )
+                        scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
                     for i, j in T.Parallel(block_M, dim):
                         acc_o[i, j] *= scores_scale[i]
                     for i, j in T.Parallel(block_M, block_N):
@@ -138,9 +121,7 @@ def main(
                     logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                 for i, j in T.Parallel(block_M, dim):
                     acc_o[i, j] /= logsum[i]
-                T.copy(
-                    acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]
-                )
+                T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
 
         return main
 
@@ -152,9 +133,7 @@ def main(
     parser.add_argument("--batch", type=int, default=64, help="Batch size")
     parser.add_argument("--h", type=int, default=12, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=2048, help="Context size")
-    parser.add_argument(
-        "--d_head", type=int, default=256, help="Head dimension"
-    )
+    parser.add_argument("--d_head", type=int, default=256, help="Head dimension")
     parser.add_argument("--casual", type=bool, default=True, help="Casual flag")
     args = parser.parse_args()
     BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head
@@ -164,9 +143,7 @@ def main(
     if casual:
         total_flops *= 0.5
 
-    best_latency, best_config, ref_latency = flashattn(
-        BATCH, H, N_CTX, D_HEAD, casual
-    )
+    best_latency, best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual)
     print(f"Best latency: {best_latency}")
     print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
    print(f"Best config: {best_config}")
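
Side note on the reference path touched above: ref_program unpacks two 4-bit values from each int8 storage byte with (qB[i][j // 2] >> (4 * (j % 2))) & 0xF. A minimal vectorized sketch of that same unpacking is shown below; the helper name unpack_int4_to_half is hypothetical and not part of the test files.

    import torch


    def unpack_int4_to_half(qB: torch.Tensor) -> torch.Tensor:
        # Vectorized form of the element-wise loop in ref_program:
        # output column j comes from storage byte j // 2, using the low
        # nibble when j is even and the high nibble when j is odd.
        low = qB & 0xF           # even output columns
        high = (qB >> 4) & 0xF   # odd output columns (mask drops sign-extension bits)
        return torch.stack((low, high), dim=-1).reshape(qB.shape[0], -1).to(torch.half)

Stacking the low and high nibbles on a trailing axis and flattening reproduces the j // 2 / j % 2 interleaving without the per-element Python loop.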