
Commit ab97913

[TL] Enhance Layout Annotate Pass to handle PTX Inst (#170)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Disable failure email for CI
* Remove email notifications
* Move relax pass from testing to mlc_llm
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* Lint fix
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* Bug fix for matrix support
* Lint fix
* Dispatch tensor core based on shapes
* Update install commands
* Import scripts
* Remove shared memory hack
* Revert change for swizzling
* Bug fix
* TL examples
* Enhance Swizzle
* Lint fix
* Test fix
* Lint fix
1 parent 9b3b73b commit ab97913

File tree

3 files changed: +42 additions, -89 deletions


testing/python/tilelang/test_tilelang_dequantize_gemm.py

Lines changed: 19 additions & 43 deletions
@@ -32,57 +32,37 @@ def matmul(
 
     @T.prim_func
     def main(
-        A: T.Buffer(A_shape, dtypeAB),
-        B: T.Buffer(B_shape, storage_dtype),
-        C: T.Buffer((M, N), dtypeC),
+            A: T.Buffer(A_shape, dtypeAB),
+            B: T.Buffer(B_shape, storage_dtype),
+            C: T.Buffer((M, N), dtypeC),
     ):
-        with T.Kernel(
-            T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
-        ) as (bx, by):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
             B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
             B_local = T.alloc_fragment([8], storage_dtype, "local")
             B_dequantize_local = T.alloc_fragment([16], dtypeAB, "local")
-            B_dequantize_shared = T.alloc_shared(
-                B_dequantize_shared_shape, dtypeAB
-            )
+            B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, dtypeAB)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             T.clear(C_local)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                 T.copy(A[by * block_M, k * block_K], A_shared)
 
-                for i in T.serial(
-                    block_N * block_K // num_elems_per_byte // (threads * 16)
-                ):
+                for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 16)):
                     for t in T.thread_binding(0, threads, thread="threadIdx.x"):
                         for v in T.vectorized(0, 16):
-                            vi = (i * threads * 16 + t * 16 + v) // (
-                                block_K // num_elems_per_byte
-                            )
-                            vj = (i * threads * 16 + t * 16 + v) % (
-                                block_K // num_elems_per_byte
-                            )
-                            B_shared[vi, vj] = B[
-                                bx * block_N + vi,
-                                k * block_K // num_elems_per_byte + vj,
-                            ]
-
-                for i in T.serial(
-                    block_N * block_K // num_elems_per_byte // (threads * 4)
-                ):
+                            vi = (i * threads * 16 + t * 16 + v) // (block_K // num_elems_per_byte)
+                            vj = (i * threads * 16 + t * 16 + v) % (block_K // num_elems_per_byte)
+                            B_shared[vi, vj] = B[bx * block_N + vi,
+                                                 k * block_K // num_elems_per_byte + vj,]
+
+                for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)):
                     for t in T.thread_binding(0, threads, thread="threadIdx.x"):
                         for v in T.vectorized(0, 4):
-                            vi = (i * threads * 4 + t * 4 + v) // (
-                                block_K // num_elems_per_byte
-                            )
-                            vj = (i * threads * 4 + t * 4 + v) % (
-                                block_K // num_elems_per_byte
-                            )
+                            vi = (i * threads * 4 + t * 4 + v) // (block_K // num_elems_per_byte)
+                            vj = (i * threads * 4 + t * 4 + v) % (block_K // num_elems_per_byte)
                             B_local[v] = B_shared[vi, vj]
                         for v in T.serial(0, 8):
-                            B_dequantize_local[
-                                v
-                            ] = _tir_packed_to_unsigned_convert("int", 8)(
+                            B_dequantize_local[v] = _tir_packed_to_unsigned_convert("int", 8)(
                                 num_bits,
                                 B_local[v // 2],
                                 v % 2,

@@ -140,15 +120,11 @@ def ref_program(A, qB):
     import torch
 
     B = (
-        torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half)
-        .to(torch.half)
-        .to(A.device)
-    )
+        torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
+                    dtype=torch.half).to(torch.half).to(A.device))
     for i in range(B.shape[0]):
         for j in range(B.shape[1]):
-            B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(
-                torch.half
-            )
+            B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C

@@ -157,7 +133,7 @@ def ref_program(A, qB):
 
 
 def test_run_dequantize_gemm():
-    run_gemm(16, 16, 16, "int8", "int32", "int32", 16, 16, 16, num_threads=128)
+    run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)
 
 
 if __name__ == "__main__":
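
For reference, the shift-and-mask that ref_program applies element by element is the usual way to unpack two unsigned 4-bit values from one int8 byte. A minimal standalone sketch of the same arithmetic follows; the helper name unpack_int4 is illustrative only and is not part of this commit.

import torch

def unpack_int4(qB: torch.Tensor) -> torch.Tensor:
    """Unpack two unsigned 4-bit values per int8 byte (low nibble first)."""
    rows, cols = qB.shape
    B = torch.zeros(rows, cols * 2, dtype=torch.half)
    for j in range(cols * 2):
        # j // 2 selects the source byte; 4 * (j % 2) selects the low or high nibble
        B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
    return B

if __name__ == "__main__":
    packed = torch.tensor([[0x21, 0x43]], dtype=torch.int8)  # nibbles 1, 2, 3, 4
    print(unpack_int4(packed))  # tensor([[1., 2., 3., 4.]], dtype=torch.float16)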

testing/python/tilelang/test_tilelang_flash_atten.py

Lines changed: 22 additions & 45 deletions
@@ -1,5 +1,4 @@
 import argparse
-import torch
 from tvm import tl
 import tvm.tl.language as T
 from tvm.tl.autotuner import *

@@ -14,15 +13,12 @@ def get_configs():
     thread_num = [128, 256]
     _configs = list(itertools.product(block_M, block_N, num_stages, thread_num))
 
-    configs = [
-        {
-            "block_M": c[0],
-            "block_N": c[1],
-            "num_stages": c[2],
-            "thread_num": c[3],
-        }
-        for c in _configs
-    ]
+    configs = [{
+        "block_M": c[0],
+        "block_N": c[1],
+        "num_stages": c[2],
+        "thread_num": c[3],
+    } for c in _configs]
     return configs
 
 

@@ -48,21 +44,20 @@ def flashattn(batch, heads, seq_len, dim, is_casual):
         atol=0.01,
     )
     def kernel(block_M=None, block_N=None, num_stages=None, thread_num=None):
-        scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
+        scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
         shape = [batch, seq_len, heads, dim]
         dtype = "float16"
         accum_dtype = "float"
 
         @T.prim_func
         def main(
-            Q: T.Buffer(shape, dtype),  # type: ignore
-            K: T.Buffer(shape, dtype),  # type: ignore
-            V: T.Buffer(shape, dtype),  # type: ignore
-            Output: T.Buffer(shape, dtype),  # type: ignore
+                Q: T.Buffer(shape, dtype),  # type: ignore
+                K: T.Buffer(shape, dtype),  # type: ignore
+                V: T.Buffer(shape, dtype),  # type: ignore
+                Output: T.Buffer(shape, dtype),  # type: ignore
         ):
             with T.Kernel(
-                T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num
-            ) as (bx, by, bz):
+                    T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_M, dim], dtype)
                 Q_local = T.alloc_fragment([block_M, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)

@@ -76,27 +71,19 @@ def main(
                 scores_sum = T.alloc_fragment([block_M], accum_dtype)
                 logsum = T.alloc_fragment([block_M], accum_dtype)
 
-                T.annotate_layout(
-                    {Q_shared: tl.layout.make_swizzled_layout(Q_shared)}
-                )
-                T.copy(
-                    Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared
-                )
+                T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})
+                T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.copy(Q_shared, Q_local)
                 for i, j in T.Parallel(block_M, dim):
                     Q_local[i, j] *= scale
                 loop_range = (
-                    T.ceildiv((bx + 1) * block_M, block_N)
-                    if is_casual
-                    else T.ceildiv(seq_len, block_N)
-                )
+                    T.ceildiv(
+                        (bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N))
                 for k in T.Pipelined(loop_range, num_stages=num_stages):
-                    T.copy(
-                        K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared
-                    )
+                    T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
                     if is_casual:
                         for i, j in T.Parallel(block_M, block_N):
                             acc_s[i, j] = T.if_then_else(

@@ -113,15 +100,11 @@ def main(
                         transpose_B=True,
                         policy=T.GemmWarpPolicy.FullRow,
                     )
-                    T.copy(
-                        V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared
-                    )
+                    T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                     T.copy(scores_max, scores_max_prev)
                     T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                     for i in T.Parallel(block_M):
-                        scores_scale[i] = T.exp2(
-                            scores_max_prev[i] - scores_max[i]
-                        )
+                        scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
                     for i, j in T.Parallel(block_M, dim):
                         acc_o[i, j] *= scores_scale[i]
                     for i, j in T.Parallel(block_M, block_N):

@@ -138,9 +121,7 @@ def main(
                     logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                 for i, j in T.Parallel(block_M, dim):
                     acc_o[i, j] /= logsum[i]
-                T.copy(
-                    acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]
-                )
+                T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
 
         return main
 

@@ -152,9 +133,7 @@ def main(
     parser.add_argument("--batch", type=int, default=64, help="Batch size")
     parser.add_argument("--h", type=int, default=12, help="Number of heads")
     parser.add_argument("--n_ctx", type=int, default=2048, help="Context size")
-    parser.add_argument(
-        "--d_head", type=int, default=256, help="Head dimension"
-    )
+    parser.add_argument("--d_head", type=int, default=256, help="Head dimension")
     parser.add_argument("--casual", type=bool, default=True, help="Casual flag")
     args = parser.parse_args()
     BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head

@@ -164,9 +143,7 @@ def main(
     if casual:
         total_flops *= 0.5
 
-    best_latency, best_config, ref_latency = flashattn(
-        BATCH, H, N_CTX, D_HEAD, casual
-    )
+    best_latency, best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual)
     print(f"Best latency: {best_latency}")
     print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
     print(f"Best config: {best_config}")
