[TL] Update several TL Examples (#168)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix for matrix support

* lint fix

* dispatch tensor core based on shapes

* update install commands

* import scripts

* remove shared mem hack

* revert change for swizzling

* bug fix

* tl examples
LeiWang1999 authored Sep 2, 2024
1 parent c55600f commit 9b3b73b
Showing 2 changed files with 337 additions and 0 deletions.
164 changes: 164 additions & 0 deletions testing/python/tilelang/test_tilelang_dequantize_gemm.py
@@ -0,0 +1,164 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import bitblas
from bitblas import tvm as tvm
from tvm import tl
from bitblas.quantization import _tir_packed_to_unsigned_convert


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    dtypeAB,
    dtypeC,
    accum_dtype,
    num_stages,
    threads,
    num_bits=4,
):
    num_elems_per_byte = 8 // num_bits
    storage_dtype = "int8"
    A_shape = (M, K)
    B_shape = (N, K // num_elems_per_byte)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K // num_elems_per_byte)
    B_dequantize_shared_shape = (block_N, block_K)

    import tvm.tl.language as T

    @T.prim_func
    def main(
        A: T.Buffer(A_shape, dtypeAB),
        B: T.Buffer(B_shape, storage_dtype),
        C: T.Buffer((M, N), dtypeC),
    ):
        with T.Kernel(
            T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
        ) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
            B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
            B_local = T.alloc_fragment([8], storage_dtype, "local")
            B_dequantize_local = T.alloc_fragment([16], dtypeAB, "local")
            B_dequantize_shared = T.alloc_shared(
                B_dequantize_shared_shape, dtypeAB
            )
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)

                # Cooperative, vectorized copy of the packed 4-bit weights into shared memory.
                for i in T.serial(
                    block_N * block_K // num_elems_per_byte // (threads * 16)
                ):
                    for t in T.thread_binding(0, threads, thread="threadIdx.x"):
                        for v in T.vectorized(0, 16):
                            vi = (i * threads * 16 + t * 16 + v) // (
                                block_K // num_elems_per_byte
                            )
                            vj = (i * threads * 16 + t * 16 + v) % (
                                block_K // num_elems_per_byte
                            )
                            B_shared[vi, vj] = B[
                                bx * block_N + vi,
                                k * block_K // num_elems_per_byte + vj,
                            ]

                # Load packed bytes into registers, unpack each nibble to dtypeAB,
                # and stage the dequantized tile in shared memory for the GEMM.
                for i in T.serial(
                    block_N * block_K // num_elems_per_byte // (threads * 4)
                ):
                    for t in T.thread_binding(0, threads, thread="threadIdx.x"):
                        for v in T.vectorized(0, 4):
                            vi = (i * threads * 4 + t * 4 + v) // (
                                block_K // num_elems_per_byte
                            )
                            vj = (i * threads * 4 + t * 4 + v) % (
                                block_K // num_elems_per_byte
                            )
                            B_local[v] = B_shared[vi, vj]
                        for v in T.serial(0, 8):
                            B_dequantize_local[v] = _tir_packed_to_unsigned_convert(
                                "int", 8
                            )(
                                num_bits,
                                B_local[v // 2],
                                v % 2,
                                dtype=dtypeAB,
                            )
                        for v in T.vectorized(0, 8):
                            vi = (i * threads * 8 + t * 8 + v) // block_K
                            vj = (i * threads * 8 + t * 8 + v) % block_K
                            B_dequantize_shared[vi, vj] = B_dequantize_local[v]
                T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm(
    M,
    N,
    K,
    dtypeAB,
    dtypeC,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        dtypeAB,
        dtypeC,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    print(program)

    mod, params = tl.lower(program)
    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)

    out = mod.run_once()

    print(f"output is {out}")

    # Dump the generated CUDA source for inspection (expects a debug/ directory).
    with open("debug/kernel.cu", "w") as f:
        f.write(mod.mod.imported_modules[0].get_source())

    def ref_program(A, qB):
        import torch

        # Unpack two 4-bit values from each int8 byte of qB, then compute the
        # reference GEMM in fp32 and cast to the requested output dtype.
        B = (
            torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half)
            .to(torch.half)
            .to(A.device)
        )
        for i in range(B.shape[0]):
            for j in range(B.shape[1]):
                B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
        C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
        C = C.to(torch.__getattribute__(dtypeC))
        return C

    mod.assert_allclose(ref_program)


def test_run_dequantize_gemm():
    run_gemm(16, 16, 16, "int8", "int32", "int32", 16, 16, 16, num_threads=128)


if __name__ == "__main__":
    bitblas.testing.main()
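
Note: the reference check above assumes the low nibble of each int8 byte holds the
even-indexed element and the high nibble the odd-indexed one (hence the shift by
4 * (j % 2)). A minimal standalone PyTorch sketch of that packing convention,
independent of tl/bitblas (the helpers pack_int4/unpack_int4 are illustrative,
not library APIs):

import torch

def pack_int4(B_unpacked):
    # Values are 0..15; pack two per byte: even index -> low nibble, odd -> high.
    b = B_unpacked.to(torch.uint8)
    packed = (b[:, 0::2] & 0xF) | ((b[:, 1::2] & 0xF) << 4)
    # Reinterpret the bytes as int8, matching the kernel's storage_dtype.
    return packed.view(torch.int8)

def unpack_int4(qB):
    # Mirror of ref_program: element j comes from byte j // 2, nibble j % 2.
    b = qB.view(torch.uint8)
    out = torch.empty(qB.shape[0], qB.shape[1] * 2, dtype=torch.uint8)
    out[:, 0::2] = b & 0xF
    out[:, 1::2] = (b >> 4) & 0xF
    return out

if __name__ == "__main__":
    B = torch.randint(0, 16, (16, 16), dtype=torch.uint8)
    assert torch.equal(unpack_int4(pack_int4(B)), B)
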
173 changes: 173 additions & 0 deletions testing/python/tilelang/test_tilelang_flash_atten.py
@@ -0,0 +1,173 @@
import argparse
import torch
from tvm import tl
import tvm.tl.language as T
from tvm.tl.autotuner import *
from functools import partial
import itertools


def get_configs():
    block_M = [32, 64, 128]
    block_N = [32, 64, 128]
    num_stages = [1, 2]
    thread_num = [128, 256]
    _configs = list(itertools.product(block_M, block_N, num_stages, thread_num))

    configs = [
        {
            "block_M": c[0],
            "block_N": c[1],
            "num_stages": c[2],
            "thread_num": c[3],
        }
        for c in _configs
    ]
    return configs


def ref_program(Q, K, V, casual):
    from flash_attn.flash_attn_interface import flash_attn_func

    return flash_attn_func(Q, K, V, causal=casual)


def flashattn(batch, heads, seq_len, dim, is_casual):

    @autotune(
        configs=get_configs(),
        keys=["block_M", "block_N", "num_stages", "thread_num"],
        warmup=10,
        rep=5,
    )
    @jit(
        out_idx=[3],
        supply_type=tl.TensorSupplyType.Normal,
        ref_prog=partial(ref_program, casual=is_casual),
        rtol=0.01,
        atol=0.01,
    )
    def kernel(block_M=None, block_N=None, num_stages=None, thread_num=None):
        scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
        shape = [batch, seq_len, heads, dim]
        dtype = "float16"
        accum_dtype = "float"

        @T.prim_func
        def main(
            Q: T.Buffer(shape, dtype),  # type: ignore
            K: T.Buffer(shape, dtype),  # type: ignore
            V: T.Buffer(shape, dtype),  # type: ignore
            Output: T.Buffer(shape, dtype),  # type: ignore
        ):
            with T.Kernel(
                T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num
            ) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                Q_local = T.alloc_fragment([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_M], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
                scores_scale = T.alloc_fragment([block_M], accum_dtype)
                scores_sum = T.alloc_fragment([block_M], accum_dtype)
                logsum = T.alloc_fragment([block_M], accum_dtype)

                T.annotate_layout(
                    {Q_shared: tl.layout.make_swizzled_layout(Q_shared)}
                )
                T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.copy(Q_shared, Q_local)
                # Pre-scale Q; exp2 is used below, so log2(e) is folded into scale.
                for i, j in T.Parallel(block_M, dim):
                    Q_local[i, j] *= scale
                # Causal blocks only visit K/V tiles up to the query block.
                loop_range = (
                    T.ceildiv((bx + 1) * block_M, block_N)
                    if is_casual
                    else T.ceildiv(seq_len, block_N)
                )
                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)
                    if is_casual:
                        # Masked positions start at -inf so they vanish after exp2.
                        for i, j in T.Parallel(block_M, block_N):
                            acc_s[i, j] = T.if_then_else(
                                bx * block_M + i >= k * block_N + j,
                                0,
                                -T.infinity(acc_s.dtype),
                            )
                    else:
                        T.clear(acc_s)
                    T.gemm(
                        Q_local,
                        K_shared,
                        acc_s,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow,
                    )
                    T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)
                    # Online softmax: rescale the running output and denominator
                    # by the change in the per-row maximum before accumulating.
                    T.copy(scores_max, scores_max_prev)
                    T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                    for i in T.Parallel(block_M):
                        scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
                    for i, j in T.Parallel(block_M, dim):
                        acc_o[i, j] *= scores_scale[i]
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
                    T.copy(acc_s, acc_s_cast)
                    T.gemm(
                        acc_s_cast,
                        V_shared,
                        acc_o,
                        policy=T.GemmWarpPolicy.FullRow,
                    )
                    T.reduce_sum(acc_s, scores_sum, dim=1)
                    for i in T.Parallel(block_M):
                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])

        return main

    return kernel()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=64, help="Batch size")
    parser.add_argument("--h", type=int, default=12, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=2048, help="Context size")
    parser.add_argument("--d_head", type=int, default=256, help="Head dimension")
    parser.add_argument("--casual", type=bool, default=True, help="Causal flag")
    args = parser.parse_args()
    BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head
    casual = args.casual
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 2 * flops_per_matmul
    if casual:
        total_flops *= 0.5

    best_latency, best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual)
    print(f"Best latency: {best_latency}")
    print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
    print(f"Best config: {best_config}")
    print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}")
