[TL] test flashattention script #196

Merged: 5 commits, Sep 27, 2024
format.sh: 2 changes (1 addition & 1 deletion)
@@ -143,7 +143,7 @@ else
# Check spelling only of the files that changed in last commit.
spell_check_changed
fi
-echo 'BitBLAS codespell: Done'
+echo 'bitblas codespell: Done'

echo 'bitblas ruff: Check Start'
# Lint specified files
testing/python/tilelang/test_tilelang_flash_atten.py: 161 changes (135 additions & 26 deletions)
@@ -1,9 +1,14 @@
import argparse
from tvm import tl
import tvm.tl.language as T
from tvm.tl.autotuner import *
from functools import partial
import itertools
import torch
import bitblas
import logging
from bitblas import set_log_level

set_log_level(logging.DEBUG)


def get_configs():
@@ -22,13 +27,28 @@ def get_configs():
return configs


-def ref_program(Q, K, V, casual):
+def ref_program(Q, K, V, causal):
from flash_attn.flash_attn_interface import flash_attn_func

-return flash_attn_func(Q, K, V, causal=casual)
+return flash_attn_func(Q, K, V, causal=causal)


def ref_flashattn_result(batch, heads, seq_len, dim, is_casual, dtype="float16"):
q_shape = (batch, seq_len, heads, dim)
k_shape = (batch, seq_len, heads, dim)
v_shape = (batch, seq_len, heads, dim)
typemap = {"float16": torch.float16}
Q = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(q_shape).type(
typemap[dtype]).cuda()
K = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(k_shape).type(
typemap[dtype]).cuda()
V = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(v_shape).type(
typemap[dtype]).cuda()
res = ref_program(Q, K, V, is_casual)
return res


-def flashattn(batch, heads, seq_len, dim, is_casual):
+def flashattn_autotune(batch, heads, seq_len, dim, is_causal):

@autotune(
configs=get_configs(),
@@ -39,7 +59,7 @@ def flashattn(batch, heads, seq_len, dim, is_casual):
@jit(
out_idx=[3],
supply_type=tl.TensorSupplyType.Normal,
-ref_prog=partial(ref_program, casual=is_casual),
+ref_prog=partial(ref_program, causal=is_causal),
rtol=0.01,
atol=0.01,
)
@@ -81,10 +101,10 @@ def main(
Q_local[i, j] *= scale
loop_range = (
T.ceildiv(
-(bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N))
+(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
-if is_casual:
+if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
bx * block_M + i >= k * block_N + j,
@@ -128,23 +148,112 @@ def main(
return kernel()


@bitblas.testing.requires_cuda_compute_version(8, 9)
def test_flashattn_autotune():
flashattn_autotune(1, 4, 256, 256, True)
flashattn_autotune(1, 8, 256, 256, True)
flashattn_autotune(4, 4, 256, 256, True)
flashattn_autotune(4, 8, 256, 256, True)


def flashattn(batch, heads, seq_len, dim, is_causal):

def kernel(block_M=64, block_N=64, num_stages=1, thread_num=128):
scale = (1.0 / dim)**0.5 * 1.44269504
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"

@T.prim_func
def main(
Q: T.Buffer(shape, dtype),
K: T.Buffer(shape, dtype),
V: T.Buffer(shape, dtype),
Output: T.Buffer(shape, dtype),
):
print(type(seq_len), seq_len)
print(type(block_M), block_M)
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
Q_local = T.alloc_fragment([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)

T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.copy(Q_shared, Q_local)
for i, j in T.Parallel(block_M, dim):
Q_local[i, j] *= scale
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(
bx * block_M + i >= k * block_N + j,
0,
-T.infinity(acc_s.dtype),
)
else:
T.clear(acc_s)
T.gemm(
Q_local,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
T.copy(acc_s, acc_s_cast)
T.gemm(
acc_s_cast,
V_shared,
acc_o,
policy=T.GemmWarpPolicy.FullRow,
)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

return main

mod, params = tl.lower(kernel())
mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal)
mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01)


@bitblas.testing.requires_cuda_compute_version(8, 9)
def test_flashattn():
flashattn(1, 4, 256, 256, True)
flashattn(1, 8, 256, 256, True)
flashattn(4, 4, 256, 256, True)
flashattn(4, 8, 256, 256, True)


if __name__ == "__main__":
-parser = argparse.ArgumentParser()
-parser.add_argument("--batch", type=int, default=64, help="Batch size")
-parser.add_argument("--h", type=int, default=12, help="Number of heads")
-parser.add_argument("--n_ctx", type=int, default=2048, help="Context size")
-parser.add_argument("--d_head", type=int, default=256, help="Head dimension")
-parser.add_argument("--casual", type=bool, default=True, help="Casual flag")
-args = parser.parse_args()
-BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head
-casual = args.casual
-flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
-total_flops = 2 * flops_per_matmul
-if casual:
-total_flops *= 0.5
-
-best_latency, best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual)
-print(f"Best latency: {best_latency}")
-print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
-print(f"Best config: {best_config}")
-print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}")
+bitblas.testing.main()
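
For context (not part of the diff): a minimal, self-contained PyTorch sketch of the base-2 softmax identity behind the kernel's scale = (1.0 / dim)**0.5 * 1.44269504. It assumes only torch is installed; the shapes and variable names below are illustrative, not taken from the PR. Pre-scaling Q by log2(e) / sqrt(dim) lets the kernel use exp2 in place of exp while still matching standard softmax attention.

import torch

torch.manual_seed(0)
batch, heads, seq_len, dim = 1, 4, 256, 256  # illustrative sizes, not from the PR
q = torch.randn(batch, heads, seq_len, dim)
k = torch.randn(batch, heads, seq_len, dim)
v = torch.randn(batch, heads, seq_len, dim)

# Reference: standard scaled-dot-product attention.
scores = q @ k.transpose(-1, -2) / dim**0.5
ref = torch.softmax(scores, dim=-1) @ v

# Base-2 variant mirroring the TL kernel: fold log2(e) into the Q scale and use exp2.
scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) * log2(e)
scores2 = (q * scale) @ k.transpose(-1, -2)
w = torch.exp2(scores2 - scores2.amax(dim=-1, keepdim=True))  # subtract row max, as the kernel does
out = (w @ v) / w.sum(dim=-1, keepdim=True)

print(torch.allclose(ref, out, rtol=1e-4, atol=1e-4))  # expected: True

As for running the new tests: test_flashattn and test_flashattn_autotune are collected by bitblas.testing.main() (pytest-based, as far as I can tell), so executing the file directly or via pytest should exercise both, given a GPU that passes the requires_cuda_compute_version(8, 9) check and an installed flash_attn for the reference program.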