Commit 15f64c1

all dynamic test case

1 parent 9bafdef commit 15f64c1

File tree

1 file changed: +109 -0 lines changed

testing/python/tilelang/test_tilelang_dyanmic_symbolic.py

Lines changed: 109 additions & 0 deletions
@@ -294,6 +294,106 @@ def ref_program(A, B):
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


+def tl_matmul_block_all_dynamic(
+    block_M,
+    block_N,
+    block_K,
+    trans_A,
+    trans_B,
+    dtypeAB,
+    dtypeC,
+    accum_dtype,
+    num_stages,
+    threads,
+):
+    # All three GEMM dimensions are symbolic, so one lowered kernel
+    # serves arbitrary runtime (M, N, K).
+    M = tvm.te.var("m")
+    N = tvm.te.var("n")
+    K = tvm.te.var("k")
+
+    A_shape = (K, M) if trans_A else (M, K)
+    B_shape = (N, K) if trans_B else (K, N)
+    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
+    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+
+    import tvm.tl.language as T
+
+    @T.prim_func
+    def main(A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB),
+             C: T.Buffer((M, N), dtypeC)):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+            A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
+            B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.clear(C_local)
+            # Software-pipelined loop over the K dimension.
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                if trans_A:
+                    T.copy(A[k * block_K, by * block_M], A_shared)
+                else:
+                    T.copy(A[by * block_M, k * block_K], A_shared)
+                if trans_B:
+                    T.copy(B[bx * block_N, k * block_K], B_shared)
+                else:
+                    T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return main
+
+
+def assert_tl_matmul_block_all_dynamic_correctness(
+    M,
+    N,
+    K,
+    trans_A,
+    trans_B,
+    dtypeAB,
+    dtypeC,
+    dtypeAccum,
+    block_M,
+    block_N,
+    block_K,
+    num_stages=3,
+    num_threads=128,
+):
+    program = tl_matmul_block_all_dynamic(
+        block_M,
+        block_N,
+        block_K,
+        trans_A,
+        trans_B,
+        dtypeAB,
+        dtypeC,
+        dtypeAccum,
+        num_stages,
+        num_threads,
+    )
+    mod, params = TL.lower(program)
+
+    # Allocate inputs in the layouts the kernel expects for the given
+    # transpose flags.
+    A_dims = (K, M) if trans_A else (M, K)
+    B_dims = (N, K) if trans_B else (K, N)
+    A = torch.rand(*A_dims, device="cuda", dtype=getattr(torch, dtypeAB))
+    B = torch.rand(*B_dims, device="cuda", dtype=getattr(torch, dtypeAB))
+    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, dtypeC))
+
+    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
+    mod(A, B, C)
+    print(mod.mod.imported_modules[0].get_source())
+
+    def ref_program(A, B):
+        if trans_A:
+            A = A.T
+        if trans_B:
+            B = B.T
+        C = torch.matmul(A.to(torch.float), B.to(torch.float))
+        C = C.to(getattr(torch, dtypeC))
+        return C
+
+    # Get Reference Result
+    ref_c = ref_program(A, B)
+
+    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
+
+
 def test_assert_tl_matmul_macro():
     assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
     assert_tl_matmul_macro_correctness(66, 128, 128, "float16", "float16", "float16")
@@ -309,5 +409,14 @@ def test_assert_tl_matmul_block():
         64, 64, 32)


+def test_assert_tl_matmul_block_all_dynamic():
+    # M values of 67 and 36 are deliberately not multiples of block_M, to
+    # exercise the dynamic-shape path with ragged tail tiles.
+    assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16",
+                                                   "float16", "float16", 64, 64, 32)
+    assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16",
+                                                   "float16", "float16", 64, 64, 32)
+    assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16",
+                                                   "float16", "float16", 64, 64, 32)
+
+
 if __name__ == "__main__":
     bitblas.testing.main()
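
Because M, N, and K are lowered as tvm.te.var symbols, one compiled module can be driven at several runtime shapes. The following is a minimal sketch, not part of the commit, assuming the same imports as the test file (torch, tvm, TL) and a CUDA device:

# Sketch only: reuse a single lowered module across several dynamic M values.
program = tl_matmul_block_all_dynamic(64, 64, 32, False, False,
                                      "float16", "float16", "float16", 3, 128)
mod, params = TL.lower(program)
profiler = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
for m in (36, 67, 128):
    A = torch.rand(m, 128, device="cuda", dtype=torch.float16)
    B = torch.rand(128, 128, device="cuda", dtype=torch.float16)
    C = torch.zeros(m, 128, device="cuda", dtype=torch.float16)
    profiler(A, B, C)  # one kernel, three different runtime M values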
