diff --git a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py
index 8af25704..9af34e03 100644
--- a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py
+++ b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py
@@ -294,6 +294,106 @@ def ref_program(A, B):
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
 
 
+def tl_matmul_block_all_dynamic(
+    block_M,
+    block_N,
+    block_K,
+    trans_A,
+    trans_B,
+    dtypeAB,
+    dtypeC,
+    accum_dtype,
+    num_stages,
+    threads,
+):
+    M = tvm.te.var("m")
+    N = tvm.te.var("n")
+    K = tvm.te.var("k")
+
+    A_shape = (K, M) if trans_A else (M, K)
+    B_shape = (N, K) if trans_B else (K, N)
+    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
+    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+
+    import tvm.tl.language as T
+
+    @T.prim_func
+    def main(A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB), C: T.Buffer((M, N),
+                                                                                        dtypeC)):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+            A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
+            B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.clear(C_local)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                if trans_A:
+                    T.copy(A[k * block_K, by * block_M], A_shared)
+                else:
+                    T.copy(A[by * block_M, k * block_K], A_shared)
+                if trans_B:
+                    T.copy(B[bx * block_N, k * block_K], B_shared)
+                else:
+                    T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return main
+
+
+def assert_tl_matmul_block_all_dynamic_correctness(
+    M,
+    N,
+    K,
+    trans_A,
+    trans_B,
+    dtypeAB,
+    dtypeC,
+    dtypeAccum,
+    block_M,
+    block_N,
+    block_K,
+    num_stages=3,
+    num_threads=128,
+):
+    program = tl_matmul_block_all_dynamic(
+        block_M,
+        block_N,
+        block_K,
+        trans_A,
+        trans_B,
+        dtypeAB,
+        dtypeC,
+        dtypeAccum,
+        num_stages,
+        num_threads,
+    )
+    mod, params = TL.lower(program)
+
+    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB))
+    B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB))
+    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, dtypeC))
+
+    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
+    mod(A, B, C)
+    print(mod.mod.imported_modules[0].get_source())
+
+    def ref_program(A, B):
+        import torch
+
+        if trans_A:
+            A = A.T
+        if trans_B:
+            B = B.T
+        C = torch.matmul(A.to(torch.float), B.to(torch.float))
+        C = C.to(torch.__getattribute__(dtypeC))
+        return C
+
+    # Get Reference Result
+    ref_c = ref_program(A, B)
+
+    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
+
+
 def test_assert_tl_matmul_macro():
     assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
     assert_tl_matmul_macro_correctness(66, 128, 128, "float16", "float16", "float16")
@@ -309,5 +409,14 @@ def test_assert_tl_matmul_block():
                                        64, 64, 32)
 
 
+def test_assert_tl_matmul_block_all_dynamic():
+    assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16",
+                                                   "float16", "float16", 64, 64, 32)
+    assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16",
+                                                   "float16", 64, 64, 32)
+    assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16",
+                                                   "float16", 64, 64, 32)
+
+
 if __name__ == "__main__":
     bitblas.testing.main()
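
Note (reviewer sketch, not part of the diff): the "all dynamic" variant above differs from the existing tl_matmul_block path mainly in that M, N and K are declared as TVM symbolic variables rather than Python integers, so the lowered kernel receives the extents at launch time. A minimal illustration of that mechanism, assuming only the tvm import this test file already uses:

    import tvm

    # Symbolic extents: the problem size is not baked into the generated kernel;
    # it is supplied by the tensors passed in at call time.
    M = tvm.te.var("m")
    N = tvm.te.var("n")
    K = tvm.te.var("k")

    # Inside the TL kernel, expressions such as T.ceildiv(M, block_M) therefore
    # stay symbolic until launch, which is presumably what the 67x128x128 and
    # 36x128x128 cases (partial tiles with block_M = 64) are exercising.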