[TL] Enhance Layout Annotate Pass to handle PTX Inst (#170)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix for matrix support

* lint fix

* dispatch tensor core based on shapes

* update install commands

* import scripts

* remove shared mem hack

* revert change for swizzling

* bug fix

* tl examples

* Enhance Swizzle

* lint fix

* test fix

* lint fix
LeiWang1999 authored Sep 3, 2024
1 parent 9b3b73b commit ab97913
Showing 3 changed files with 42 additions and 89 deletions.
3rdparty/tvm (2 changes: 1 addition & 1 deletion)
testing/python/tilelang/test_tilelang_dequantize_gemm.py (62 changes: 19 additions & 43 deletions)
@@ -32,57 +32,37 @@ def matmul(

     @T.prim_func
     def main(
-        A: T.Buffer(A_shape, dtypeAB),
-        B: T.Buffer(B_shape, storage_dtype),
-        C: T.Buffer((M, N), dtypeC),
+            A: T.Buffer(A_shape, dtypeAB),
+            B: T.Buffer(B_shape, storage_dtype),
+            C: T.Buffer((M, N), dtypeC),
     ):
-        with T.Kernel(
-            T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
-        ) as (bx, by):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
             B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
             B_local = T.alloc_fragment([8], storage_dtype, "local")
             B_dequantize_local = T.alloc_fragment([16], dtypeAB, "local")
-            B_dequantize_shared = T.alloc_shared(
-                B_dequantize_shared_shape, dtypeAB
-            )
+            B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, dtypeAB)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             T.clear(C_local)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                 T.copy(A[by * block_M, k * block_K], A_shared)

-                for i in T.serial(
-                    block_N * block_K // num_elems_per_byte // (threads * 16)
-                ):
+                for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 16)):
                     for t in T.thread_binding(0, threads, thread="threadIdx.x"):
                         for v in T.vectorized(0, 16):
-                            vi = (i * threads * 16 + t * 16 + v) // (
-                                block_K // num_elems_per_byte
-                            )
-                            vj = (i * threads * 16 + t * 16 + v) % (
-                                block_K // num_elems_per_byte
-                            )
-                            B_shared[vi, vj] = B[
-                                bx * block_N + vi,
-                                k * block_K // num_elems_per_byte + vj,
-                            ]
-
-                for i in T.serial(
-                    block_N * block_K // num_elems_per_byte // (threads * 4)
-                ):
+                            vi = (i * threads * 16 + t * 16 + v) // (block_K // num_elems_per_byte)
+                            vj = (i * threads * 16 + t * 16 + v) % (block_K // num_elems_per_byte)
+                            B_shared[vi, vj] = B[bx * block_N + vi,
+                                                 k * block_K // num_elems_per_byte + vj,]
+
+                for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)):
                     for t in T.thread_binding(0, threads, thread="threadIdx.x"):
                         for v in T.vectorized(0, 4):
-                            vi = (i * threads * 4 + t * 4 + v) // (
-                                block_K // num_elems_per_byte
-                            )
-                            vj = (i * threads * 4 + t * 4 + v) % (
-                                block_K // num_elems_per_byte
-                            )
+                            vi = (i * threads * 4 + t * 4 + v) // (block_K // num_elems_per_byte)
+                            vj = (i * threads * 4 + t * 4 + v) % (block_K // num_elems_per_byte)
                             B_local[v] = B_shared[vi, vj]
                         for v in T.serial(0, 8):
-                            B_dequantize_local[
-                                v
-                            ] = _tir_packed_to_unsigned_convert("int", 8)(
+                            B_dequantize_local[v] = _tir_packed_to_unsigned_convert("int", 8)(
                                 num_bits,
                                 B_local[v // 2],
                                 v % 2,
@@ -140,15 +120,11 @@ def ref_program(A, qB):
     import torch

     B = (
-        torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half)
-        .to(torch.half)
-        .to(A.device)
-    )
+        torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
+                    dtype=torch.half).to(torch.half).to(A.device))
     for i in range(B.shape[0]):
         for j in range(B.shape[1]):
-            B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(
-                torch.half
-            )
+            B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C
@@ -157,7 +133,7 @@ def ref_program(A, qB):


 def test_run_dequantize_gemm():
-    run_gemm(16, 16, 16, "int8", "int32", "int32", 16, 16, 16, num_threads=128)
+    run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)


 if __name__ == "__main__":
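For reference, the element-wise loop in ref_program above unpacks two unsigned 4-bit values from each packed int8 byte of qB before the matmul. The snippet below is a vectorized PyTorch sketch of that same unpacking; the helper name dequantize_int4_reference is ours and not part of the commit, and it assumes the low nibble holds the even-indexed element, as the shift by 4 * (j % 2) in ref_program does.

import torch

def dequantize_int4_reference(qB: torch.Tensor) -> torch.Tensor:
    # Illustrative helper (not from this commit): vectorized form of the
    # per-element unpacking in ref_program, assuming two unsigned 4-bit
    # values per int8 byte, low nibble first.
    shifts = torch.tensor([0, 4], dtype=torch.int8, device=qB.device)
    nibbles = (qB.unsqueeze(-1) >> shifts) & 0xF             # (N, K // 2, 2)
    return nibbles.reshape(qB.shape[0], -1).to(torch.half)   # (N, K)

Under those assumptions, B = dequantize_int4_reference(qB).to(A.device) should match the B built by ref_program's double loop.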
testing/python/tilelang/test_tilelang_flash_atten.py (67 changes: 22 additions & 45 deletions)
@@ -1,5 +1,4 @@
 import argparse
-import torch
 from tvm import tl
 import tvm.tl.language as T
 from tvm.tl.autotuner import *
@@ -14,15 +13,12 @@ def get_configs():
     thread_num = [128, 256]
     _configs = list(itertools.product(block_M, block_N, num_stages, thread_num))

-    configs = [
-        {
-            "block_M": c[0],
-            "block_N": c[1],
-            "num_stages": c[2],
-            "thread_num": c[3],
-        }
-        for c in _configs
-    ]
+    configs = [{
+        "block_M": c[0],
+        "block_N": c[1],
+        "num_stages": c[2],
+        "thread_num": c[3],
+    } for c in _configs]
     return configs


@@ -48,21 +44,20 @@ def flashattn(batch, heads, seq_len, dim, is_casual):
         atol=0.01,
     )
     def kernel(block_M=None, block_N=None, num_stages=None, thread_num=None):
-        scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
+        scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
         shape = [batch, seq_len, heads, dim]
         dtype = "float16"
         accum_dtype = "float"

         @T.prim_func
         def main(
-            Q: T.Buffer(shape, dtype),  # type: ignore
-            K: T.Buffer(shape, dtype),  # type: ignore
-            V: T.Buffer(shape, dtype),  # type: ignore
-            Output: T.Buffer(shape, dtype),  # type: ignore
+                Q: T.Buffer(shape, dtype),  # type: ignore
+                K: T.Buffer(shape, dtype),  # type: ignore
+                V: T.Buffer(shape, dtype),  # type: ignore
+                Output: T.Buffer(shape, dtype),  # type: ignore
         ):
             with T.Kernel(
-                T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num
-            ) as (bx, by, bz):
+                    T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_M, dim], dtype)
                 Q_local = T.alloc_fragment([block_M, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
@@ -76,27 +71,19 @@ def main(
                 scores_sum = T.alloc_fragment([block_M], accum_dtype)
                 logsum = T.alloc_fragment([block_M], accum_dtype)

-                T.annotate_layout(
-                    {Q_shared: tl.layout.make_swizzled_layout(Q_shared)}
-                )
-                T.copy(
-                    Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared
-                )
+                T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})
+                T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.copy(Q_shared, Q_local)
                 for i, j in T.Parallel(block_M, dim):
                     Q_local[i, j] *= scale
                 loop_range = (
-                    T.ceildiv((bx + 1) * block_M, block_N)
-                    if is_casual
-                    else T.ceildiv(seq_len, block_N)
-                )
+                    T.ceildiv(
+                        (bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N))
                 for k in T.Pipelined(loop_range, num_stages=num_stages):
-                    T.copy(
-                        K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared
-                    )
+                    T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
                     if is_casual:
                         for i, j in T.Parallel(block_M, block_N):
                             acc_s[i, j] = T.if_then_else(
@@ -113,15 +100,11 @@ def main(
                         transpose_B=True,
                         policy=T.GemmWarpPolicy.FullRow,
                     )
-                    T.copy(
-                        V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared
-                    )
+                    T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                     T.copy(scores_max, scores_max_prev)
                     T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                     for i in T.Parallel(block_M):
-                        scores_scale[i] = T.exp2(
-                            scores_max_prev[i] - scores_max[i]
-                        )
+                        scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
                     for i, j in T.Parallel(block_M, dim):
                         acc_o[i, j] *= scores_scale[i]
                     for i, j in T.Parallel(block_M, block_N):
@@ -138,9 +121,7 @@ def main(
                         logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                 for i, j in T.Parallel(block_M, dim):
                     acc_o[i, j] /= logsum[i]
-                T.copy(
-                    acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]
-                )
+                T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

         return main

@@ -152,9 +133,7 @@ def main(
     parser.add_argument("--batch", type=int, default=64, help="Batch size")
     parser.add_argument("--h", type=int, default=12, help="Number of heads")
     parser.add_argument("--n_ctx", type=int, default=2048, help="Context size")
-    parser.add_argument(
-        "--d_head", type=int, default=256, help="Head dimension"
-    )
+    parser.add_argument("--d_head", type=int, default=256, help="Head dimension")
     parser.add_argument("--casual", type=bool, default=True, help="Casual flag")
     args = parser.parse_args()
     BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head
@@ -164,9 +143,7 @@ def main(
     if casual:
         total_flops *= 0.5

-    best_latency, best_config, ref_latency = flashattn(
-        BATCH, H, N_CTX, D_HEAD, casual
-    )
+    best_latency, best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual)
     print(f"Best latency: {best_latency}")
     print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
     print(f"Best config: {best_config}")
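As context for the exp2-based rescaling the kernel keeps through this reformat (scale folds log2(e) into 1/sqrt(dim), and scores_scale = T.exp2(scores_max_prev - scores_max)), the snippet below is a minimal PyTorch sketch of the same online-softmax update over a sequence of K/V blocks. The function name and the block-list interface are ours, purely illustrative; it ignores the causal mask, fp16 accumulation, and the shared-memory tiling the TL kernel handles.

import torch

def online_softmax_attention_reference(q, k_blocks, v_blocks, dim):
    # Illustrative sketch (not from this commit) of the base-2 online-softmax
    # update the kernel performs block by block.
    scale = (1.0 / dim) ** 0.5 * 1.44269504              # 1/sqrt(dim) * log2(e), as in the kernel
    row_max = torch.full((q.shape[0],), float("-inf"))   # running row max (scores_max)
    denom = torch.zeros(q.shape[0])                      # running denominator (logsum)
    acc = torch.zeros(q.shape[0], v_blocks[0].shape[1])  # output accumulator (acc_o)
    for k_blk, v_blk in zip(k_blocks, v_blocks):
        s = (q * scale) @ k_blk.T                        # acc_s for this block
        new_max = torch.maximum(row_max, s.max(dim=1).values)
        rescale = torch.exp2(row_max - new_max)          # scores_scale
        p = torch.exp2(s - new_max[:, None])
        denom = denom * rescale + p.sum(dim=1)
        acc = acc * rescale[:, None] + p @ v_blk
        row_max = new_max
    return acc / denom[:, None]                          # acc_o /= logsum

Feeding it one (block_M, dim) query tile and block_N-sized slices of K and V should reproduce, up to floating-point tolerance, what the pipelined loop accumulates for a non-causal tile.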
