Skip to content

Commit

Permalink
[TL] test flashattention script (#196)
Browse files Browse the repository at this point in the history
* [TL] test flashattention script

* [TL] test flashattention script, remove commented code

* [TL] test flashattention script, remove commented code

* [TL] test flashattention script, remove commented code

* [TL] test flashattention script, remove commented code, modify format
  • Loading branch information
tzj-fxz authored Sep 27, 2024
1 parent 70e5214 commit 9f5f0ea
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 27 deletions.
2 changes: 1 addition & 1 deletion format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ else
# Check spelling only of the files that changed in last commit.
spell_check_changed
fi
echo 'BitBLAS codespell: Done'
echo 'bitblas codespell: Done'

echo 'bitblas ruff: Check Start'
# Lint specified files
Expand Down
161 changes: 135 additions & 26 deletions testing/python/tilelang/test_tilelang_flash_atten.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import argparse
from tvm import tl
import tvm.tl.language as T
from tvm.tl.autotuner import *
from functools import partial
import itertools
import torch
import bitblas
import logging
from bitblas import set_log_level

set_log_level(logging.DEBUG)


def get_configs():
Expand All @@ -22,13 +27,28 @@ def get_configs():
return configs


def ref_program(Q, K, V, casual):
def ref_program(Q, K, V, causal):
from flash_attn.flash_attn_interface import flash_attn_func

return flash_attn_func(Q, K, V, causal=casual)
return flash_attn_func(Q, K, V, causal=causal)


def ref_flashattn_result(batch, heads, seq_len, dim, is_casual, dtype="float16"):
q_shape = (batch, seq_len, heads, dim)
k_shape = (batch, seq_len, heads, dim)
v_shape = (batch, seq_len, heads, dim)
typemap = {"float16": torch.float16}
Q = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(q_shape).type(
typemap[dtype]).cuda()
K = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(k_shape).type(
typemap[dtype]).cuda()
V = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(v_shape).type(
typemap[dtype]).cuda()
res = ref_program(Q, K, V, is_casual)
return res


def flashattn(batch, heads, seq_len, dim, is_casual):
def flashattn_autotune(batch, heads, seq_len, dim, is_causal):

@autotune(
configs=get_configs(),
Expand All @@ -39,7 +59,7 @@ def flashattn(batch, heads, seq_len, dim, is_casual):
@jit(
out_idx=[3],
supply_type=tl.TensorSupplyType.Normal,
ref_prog=partial(ref_program, casual=is_casual),
ref_prog=partial(ref_program, causal=is_causal),
rtol=0.01,
atol=0.01,
)
Expand Down Expand Up @@ -81,10 +101,10 @@ def main(
Q_local[i, j] *= scale
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N))
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_casual:
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
bx * block_M + i >= k * block_N + j,
Expand Down Expand Up @@ -128,23 +148,112 @@ def main(
return kernel()


@bitblas.testing.requires_cuda_compute_version(8, 9)
def test_flashattn_autotune():
flashattn_autotune(1, 4, 256, 256, True)
flashattn_autotune(1, 8, 256, 256, True)
flashattn_autotune(4, 4, 256, 256, True)
flashattn_autotune(4, 8, 256, 256, True)


def flashattn(batch, heads, seq_len, dim, is_causal):

def kernel(block_M=64, block_N=64, num_stages=1, thread_num=128):
scale = (1.0 / dim)**0.5 * 1.44269504
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"

@T.prim_func
def main(
Q: T.Buffer(shape, dtype),
K: T.Buffer(shape, dtype),
V: T.Buffer(shape, dtype),
Output: T.Buffer(shape, dtype),
):
print(type(seq_len), seq_len)
print(type(block_M), block_M)
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
Q_local = T.alloc_fragment([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)

T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.copy(Q_shared, Q_local)
for i, j in T.Parallel(block_M, dim):
Q_local[i, j] *= scale
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
bx * block_M + i >= k * block_N + j,
0,
-T.infinity(acc_s.dtype),
)
else:
T.clear(acc_s)
T.gemm(
Q_local,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(acc_s, acc_s_cast)
T.gemm(
acc_s_cast,
V_shared,
acc_o,
policy=T.GemmWarpPolicy.FullRow,
)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

return main

mod, params = tl.lower(kernel())
mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal)
mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01)


@bitblas.testing.requires_cuda_compute_version(8, 9)
def test_flashattn():
flashattn(1, 4, 256, 256, True)
flashattn(1, 8, 256, 256, True)
flashattn(4, 4, 256, 256, True)
flashattn(4, 8, 256, 256, True)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=64, help="Batch size")
parser.add_argument("--h", type=int, default=12, help="Number of heads")
parser.add_argument("--n_ctx", type=int, default=2048, help="Context size")
parser.add_argument("--d_head", type=int, default=256, help="Head dimension")
parser.add_argument("--casual", type=bool, default=True, help="Casual flag")
args = parser.parse_args()
BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head
casual = args.casual
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 2 * flops_per_matmul
if casual:
total_flops *= 0.5

best_latency, best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual)
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}")
bitblas.testing.main()

0 comments on commit 9f5f0ea

Please sign in to comment.