more test cases with block-level programming
LeiWang1999 committed Sep 26, 2024
1 parent 5cfce84 commit 9bafdef
Showing 2 changed files with 117 additions and 13 deletions.
3rdparty/tvm (2 changes: 1 addition & 1 deletion)
Submodule tvm updated 1 file
+4 −1 python/tvm/tl/utils.py
testing/python/tilelang/test_tilelang_dyanmic_symbolic.py (128 changes: 116 additions & 12 deletions)
@@ -9,11 +9,7 @@
from tvm import tl as TL
import tvm.tl.language as T
from bitblas.tl.utils import get_swizzle_layout
from bitblas.tl.macro_generator import (
    TensorCoreIntrinEmitter,
    TensorCoreIntrinEmitterWithLadderTransform,
)
from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
from bitblas.tl.macro_generator import (TensorCoreIntrinEmitter)

torch.manual_seed(0)

@@ -33,7 +29,7 @@ def transform_func(i, j):
    return T.Layout(shape, transform_func)


def tl_matmul(
def tl_matmul_macro(
    N,
    K,
    dtypeAB,
@@ -176,8 +172,8 @@ def main(
    return main


def assert_tl_matmul_correctness(M, N, K, in_dtype, dtypeC, accum_dtype):
    matmul = tl_matmul(N, K, in_dtype, dtypeC, accum_dtype)
def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, dtypeC, accum_dtype):
    matmul = tl_matmul_macro(N, K, in_dtype, dtypeC, accum_dtype)

    mod, params = TL.lower(matmul)
    src_code = mod.imported_modules[0].get_source()
@@ -198,11 +194,119 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, dtypeC, accum_dtype):
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def tl_matmul_block(
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    dtypeAB,
    dtypeC,
    accum_dtype,
    num_stages,
    threads,
):
    # M is left symbolic so the kernel can serve arbitrary row counts at runtime.
    M = tvm.te.var("m")
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    import tvm.tl.language as T

    @T.prim_func
    def main(A: T.Buffer(A_shape, dtypeAB), B: T.Buffer(B_shape, dtypeAB), C: T.Buffer((M, N),
                                                                                       dtypeC)):
        # One program instance (bx, by) per (block_M, block_N) output tile.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
            B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # Software-pipelined loop over K tiles: stage tiles into shared memory,
            # then accumulate into the local fragment.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
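
For readers new to the tile-language primitives, each (bx, by) program instance above owns one (block_M, block_N) output tile and accumulates it over a pipelined loop of K tiles. The following is a rough PyTorch sketch (not part of the commit) of what a single block computes; block_reference is a hypothetical helper, and it assumes trans_A = trans_B = False, float32 accumulation, and dimensions divisible by the block sizes.

import torch

def block_reference(A, B, by, bx, block_M, block_N, block_K):
    # One (block_M, block_N) output tile, accumulated over T.ceildiv(K, block_K)
    # K-tiles, mirroring T.clear + the T.Pipelined loop + T.gemm above.
    K = A.shape[1]
    acc = torch.zeros(block_M, block_N, dtype=torch.float32)
    for k in range((K + block_K - 1) // block_K):  # T.ceildiv(K, block_K)
        a_tile = A[by * block_M:(by + 1) * block_M, k * block_K:(k + 1) * block_K]
        b_tile = B[k * block_K:(k + 1) * block_K, bx * block_N:(bx + 1) * block_N]
        acc += a_tile.float() @ b_tile.float()     # the T.gemm accumulation step
    return acc  # what T.copy(C_local, C[...]) writes back

Summing block_reference over all (bx, by) tiles reproduces the full matmul that ref_program checks below.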


def assert_tl_matmul_block_correctness(
    M,
    N,
    K,
    trans_A,
    trans_B,
    dtypeAB,
    dtypeC,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = tl_matmul_block(
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        dtypeAB,
        dtypeC,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    mod, params = TL.lower(program)

    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, dtypeAB))
    B = torch.rand(N, K, device="cuda", dtype=getattr(torch, dtypeAB))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, dtypeC))

    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
    mod(A, B, C)

    def ref_program(A, B):
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(dtypeC))
        return C

    # Get Reference Result
    ref_c = ref_program(A, B)

    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
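
A note on the dynamic dimension: M is declared as tvm.te.var("m") inside tl_matmul_block, so the lowered kernel is parameterized over the row count rather than specialized to a fixed M. The sketch below (not part of the commit) shows how one compiled module could be reused across the M values the tests exercise, assuming the Profiler wrapper accepts inputs whose leading dimension varies, which is what the symbolic shape is intended to allow.

program = tl_matmul_block(128, 128, 64, 64, 32, False, False,
                          "float16", "float16", "float16", 3, 128)
mod, params = TL.lower(program)
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
for m in (128, 67, 36):
    A = torch.rand(m, 128, device="cuda", dtype=torch.float16)
    B = torch.rand(128, 128, device="cuda", dtype=torch.float16)
    C = torch.zeros(m, 128, device="cuda", dtype=torch.float16)
    mod(A, B, C)  # same compiled kernel, three different runtime M values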


def test_assert_tl_matmul_macro():
    assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
    assert_tl_matmul_macro_correctness(66, 128, 128, "float16", "float16", "float16")
    assert_tl_matmul_macro_correctness(32, 128, 128, "float16", "float16", "float16")


def test_assert_tl_matmul():
    assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
    assert_tl_matmul_correctness(66, 128, 128, "float16", "float16", "float16")
    assert_tl_matmul_correctness(32, 128, 128, "float16", "float16", "float16")
def test_assert_tl_matmul_block():
    assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16", "float16",
                                       64, 64, 32)
    assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16", "float16",
                                       64, 64, 32)
    assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16", "float16",
                                       64, 64, 32)


if __name__ == "__main__":
