[TL] Update several TL Examples (#168)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Disable failure email for CI
* Remove email notifications
* Move relax pass from testing to mlc_llm
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* Lint fix
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* Bug fix for matrix support
* Lint fix
* Dispatch tensor core based on shapes
* Update install commands
* Import scripts
* Remove shared memory hack
* Revert change for swizzling
* Bug fix
* TL examples
1 parent c55600f · commit 9b3b73b
Showing 2 changed files with 337 additions and 0 deletions.
testing/python/tilelang/test_tilelang_dequantize_gemm.py (164 additions, 0 deletions)
@@ -0,0 +1,164 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import bitblas
from bitblas import tvm as tvm
from tvm import tl
from bitblas.quantization import _tir_packed_to_unsigned_convert


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    dtypeAB,
    dtypeC,
    accum_dtype,
    num_stages,
    threads,
    num_bits=4,
):
    num_elems_per_byte = 8 // num_bits
    storage_dtype = "int8"
    A_shape = (M, K)
    B_shape = (N, K // num_elems_per_byte)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K // num_elems_per_byte)
    B_dequantize_shared_shape = (block_N, block_K)

    import tvm.tl.language as T

    @T.prim_func
    def main(
        A: T.Buffer(A_shape, dtypeAB),
        B: T.Buffer(B_shape, storage_dtype),
        C: T.Buffer((M, N), dtypeC),
    ):
        with T.Kernel(
            T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
        ) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
            B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
            B_local = T.alloc_fragment([8], storage_dtype, "local")
            B_dequantize_local = T.alloc_fragment([16], dtypeAB, "local")
            B_dequantize_shared = T.alloc_shared(
                B_dequantize_shared_shape, dtypeAB
            )
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)

                for i in T.serial(
                    block_N * block_K // num_elems_per_byte // (threads * 16)
                ):
                    for t in T.thread_binding(0, threads, thread="threadIdx.x"):
                        for v in T.vectorized(0, 16):
                            vi = (i * threads * 16 + t * 16 + v) // (
                                block_K // num_elems_per_byte
                            )
                            vj = (i * threads * 16 + t * 16 + v) % (
                                block_K // num_elems_per_byte
                            )
                            B_shared[vi, vj] = B[
                                bx * block_N + vi,
                                k * block_K // num_elems_per_byte + vj,
                            ]

                for i in T.serial(
                    block_N * block_K // num_elems_per_byte // (threads * 4)
                ):
                    for t in T.thread_binding(0, threads, thread="threadIdx.x"):
                        for v in T.vectorized(0, 4):
                            vi = (i * threads * 4 + t * 4 + v) // (
                                block_K // num_elems_per_byte
                            )
                            vj = (i * threads * 4 + t * 4 + v) % (
                                block_K // num_elems_per_byte
                            )
                            B_local[v] = B_shared[vi, vj]
                        for v in T.serial(0, 8):
                            B_dequantize_local[v] = _tir_packed_to_unsigned_convert(
                                "int", 8
                            )(
                                num_bits,
                                B_local[v // 2],
                                v % 2,
                                dtype=dtypeAB,
                            )
                        for v in T.vectorized(0, 8):
                            vi = (i * threads * 8 + t * 8 + v) // (block_K)
                            vj = (i * threads * 8 + t * 8 + v) % (block_K)
                            B_dequantize_shared[vi, vj] = B_dequantize_local[v]
                T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm(
    M,
    N,
    K,
    dtypeAB,
    dtypeC,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        dtypeAB,
        dtypeC,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    print(program)

    mod, params = tl.lower(program)
    mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)

    out = mod.run_once()

    print(f"output is {out}")

    with open("debug/kernel.cu", "w") as f:
        f.write(mod.mod.imported_modules[0].get_source())

    def ref_program(A, qB):
        import torch

        B = (
            torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half)
            .to(torch.half)
            .to(A.device)
        )
        for i in range(B.shape[0]):
            for j in range(B.shape[1]):
                B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(
                    torch.half
                )
        C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
        C = C.to(torch.__getattribute__(dtypeC))
        return C

    mod.assert_allclose(ref_program)


def test_run_dequantize_gemm():
    run_gemm(16, 16, 16, "int8", "int32", "int32", 16, 16, 16, num_threads=128)


if __name__ == "__main__":
    bitblas.testing.main()
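The reference check in run_gemm assumes each int8 element of qB packs two unsigned 4-bit weights, low nibble first, which is exactly what `(qB[i][j // 2] >> (4 * (j % 2))) & 0xF` unpacks. Below is a minimal host-side sketch of that packing convention in PyTorch; the pack_int4_rows helper and the 16x16 shapes are illustrative only and are not part of this commit.

    import torch

    def pack_int4_rows(B_unpacked: torch.Tensor) -> torch.Tensor:
        # B_unpacked: (N, K) with values in [0, 15]. Column j lands in byte j // 2,
        # low nibble when j is even, high nibble when j is odd, matching the
        # unpacking rule used by ref_program above.
        assert B_unpacked.shape[1] % 2 == 0
        lo = B_unpacked[:, 0::2].to(torch.uint8)
        hi = B_unpacked[:, 1::2].to(torch.uint8)
        return (lo | (hi << 4)).to(torch.int8)

    # Round-trip check against the same unpacking rule.
    B = torch.randint(0, 16, (16, 16))
    qB = pack_int4_rows(B)
    j = torch.arange(16)
    unpacked = (qB[:, j // 2] >> (4 * (j % 2))) & 0xF
    assert torch.equal(unpacked.to(B.dtype), B)

Weights packed this way have shape (N, K // 2) and dtype int8, which is what the B_shape = (N, K // num_elems_per_byte) buffer in the kernel expects.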
@@ -0,0 +1,173 @@
import argparse
import torch
from tvm import tl
import tvm.tl.language as T
from tvm.tl.autotuner import *
from functools import partial
import itertools


def get_configs():
    block_M = [32, 64, 128]
    block_N = [32, 64, 128]
    num_stages = [1, 2]
    thread_num = [128, 256]
    _configs = list(itertools.product(block_M, block_N, num_stages, thread_num))

    configs = [
        {
            "block_M": c[0],
            "block_N": c[1],
            "num_stages": c[2],
            "thread_num": c[3],
        }
        for c in _configs
    ]
    return configs


def ref_program(Q, K, V, casual):
    from flash_attn.flash_attn_interface import flash_attn_func

    return flash_attn_func(Q, K, V, causal=casual)


def flashattn(batch, heads, seq_len, dim, is_casual):

    @autotune(
        configs=get_configs(),
        keys=["block_M", "block_N", "num_stages", "thread_num"],
        warmup=10,
        rep=5,
    )
    @jit(
        out_idx=[3],
        supply_type=tl.TensorSupplyType.Normal,
        ref_prog=partial(ref_program, casual=is_casual),
        rtol=0.01,
        atol=0.01,
    )
    def kernel(block_M=None, block_N=None, num_stages=None, thread_num=None):
        scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
        shape = [batch, seq_len, heads, dim]
        dtype = "float16"
        accum_dtype = "float"

        @T.prim_func
        def main(
            Q: T.Buffer(shape, dtype),  # type: ignore
            K: T.Buffer(shape, dtype),  # type: ignore
            V: T.Buffer(shape, dtype),  # type: ignore
            Output: T.Buffer(shape, dtype),  # type: ignore
        ):
            with T.Kernel(
                T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num
            ) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                Q_local = T.alloc_fragment([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_M], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
                scores_scale = T.alloc_fragment([block_M], accum_dtype)
                scores_sum = T.alloc_fragment([block_M], accum_dtype)
                logsum = T.alloc_fragment([block_M], accum_dtype)

                T.annotate_layout(
                    {Q_shared: tl.layout.make_swizzled_layout(Q_shared)}
                )
                T.copy(
                    Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared
                )
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.copy(Q_shared, Q_local)
                for i, j in T.Parallel(block_M, dim):
                    Q_local[i, j] *= scale
                loop_range = (
                    T.ceildiv((bx + 1) * block_M, block_N)
                    if is_casual
                    else T.ceildiv(seq_len, block_N)
                )
                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    T.copy(
                        K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared
                    )
                    if is_casual:
                        for i, j in T.Parallel(block_M, block_N):
                            acc_s[i, j] = T.if_then_else(
                                bx * block_M + i >= k * block_N + j,
                                0,
                                -T.infinity(acc_s.dtype),
                            )
                    else:
                        T.clear(acc_s)
                    T.gemm(
                        Q_local,
                        K_shared,
                        acc_s,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullRow,
                    )
                    T.copy(
                        V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared
                    )
                    T.copy(scores_max, scores_max_prev)
                    T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                    for i in T.Parallel(block_M):
                        scores_scale[i] = T.exp2(
                            scores_max_prev[i] - scores_max[i]
                        )
                    for i, j in T.Parallel(block_M, dim):
                        acc_o[i, j] *= scores_scale[i]
                    for i, j in T.Parallel(block_M, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
                    T.copy(acc_s, acc_s_cast)
                    T.gemm(
                        acc_s_cast,
                        V_shared,
                        acc_o,
                        policy=T.GemmWarpPolicy.FullRow,
                    )
                    T.reduce_sum(acc_s, scores_sum, dim=1)
                    for i in T.Parallel(block_M):
                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(
                    acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]
                )

        return main

    return kernel()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=64, help="Batch size")
    parser.add_argument("--h", type=int, default=12, help="Number of heads")
    parser.add_argument("--n_ctx", type=int, default=2048, help="Context size")
    parser.add_argument(
        "--d_head", type=int, default=256, help="Head dimension"
    )
    parser.add_argument("--casual", type=bool, default=True, help="Casual flag")
    args = parser.parse_args()
    BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head
    casual = args.casual
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 2 * flops_per_matmul
    if casual:
        total_flops *= 0.5

    best_latency, best_config, ref_latency = flashattn(
        BATCH, H, N_CTX, D_HEAD, casual
    )
    print(f"Best latency: {best_latency}")
    print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
    print(f"Best config: {best_config}")
    print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}")