Skip to content

Commit

Permalink
all dynamic test case
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Sep 26, 2024
1 parent 9bafdef commit 15f64c1
Showing 1 changed file with 109 additions and 0 deletions.
109 changes: 109 additions & 0 deletions testing/python/tilelang/test_tilelang_dyanmic_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,106 @@ def ref_program(A, B):
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def tl_matmul_block_all_dynamic(
block_M,
block_N,
block_K,
trans_A,
trans_B,
dtypeAB,
dtypeC,
accum_dtype,
num_stages,
threads,
):
M = tvm.te.var("m")
N = tvm.te.var("n")
K = tvm.te.var("k")

A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

import tvm.tl.language as T

@T.prim_func
def main(A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB), C: T.Buffer((M, N),
dtypeC)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])

return main


def assert_tl_matmul_block_all_dynamic_correctness(
M,
N,
K,
trans_A,
trans_B,
dtypeAB,
dtypeC,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = tl_matmul_block_all_dynamic(
block_M,
block_N,
block_K,
trans_A,
trans_B,
dtypeAB,
dtypeC,
dtypeAccum,
num_stages,
num_threads,
)
mod, params = TL.lower(program)

A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB))
B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, dtypeC))

mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, B, C)
print(mod.mod.imported_modules[0].get_source())

def ref_program(A, B):
import torch

if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C

# Get Reference Result
ref_c = ref_program(A, B)

torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def test_assert_tl_matmul_macro():
assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
assert_tl_matmul_macro_correctness(66, 128, 128, "float16", "float16", "float16")
Expand All @@ -309,5 +409,14 @@ def test_assert_tl_matmul_block():
64, 64, 32)


def test_assert_tl_matmul_block_all_dynamic():
assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16",
"float16", "float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16",
"float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16",
"float16", 64, 64, 32)


if __name__ == "__main__":
bitblas.testing.main()

0 comments on commit 15f64c1

Please sign in to comment.