
Commit 9f5f0ea

[TL] test flashattention script (#196)
* [TL] test flashattention script
* [TL] test flashattention script, remove commented code
* [TL] test flashattention script, remove commented code
* [TL] test flashattention script, remove commented code
* [TL] test flashattention script, remove commented code, modify format
1 parent 70e5214 commit 9f5f0ea

File tree

2 files changed: +136 -27 lines changed

format.sh

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@ else
    # Check spelling only of the files that changed in last commit.
    spell_check_changed
 fi
-echo 'BitBLAS codespell: Done'
+echo 'bitblas codespell: Done'
 
 echo 'bitblas ruff: Check Start'
 # Lint specified files

testing/python/tilelang/test_tilelang_flash_atten.py

Lines changed: 135 additions & 26 deletions
@@ -1,9 +1,14 @@
-import argparse
 from tvm import tl
 import tvm.tl.language as T
 from tvm.tl.autotuner import *
 from functools import partial
 import itertools
+import torch
+import bitblas
+import logging
+from bitblas import set_log_level
+
+set_log_level(logging.DEBUG)
 
 
 def get_configs():
@@ -22,13 +27,28 @@ def get_configs():
     return configs
 
 
-def ref_program(Q, K, V, casual):
+def ref_program(Q, K, V, causal):
     from flash_attn.flash_attn_interface import flash_attn_func
 
-    return flash_attn_func(Q, K, V, causal=casual)
+    return flash_attn_func(Q, K, V, causal=causal)
+
+
+def ref_flashattn_result(batch, heads, seq_len, dim, is_casual, dtype="float16"):
+    q_shape = (batch, seq_len, heads, dim)
+    k_shape = (batch, seq_len, heads, dim)
+    v_shape = (batch, seq_len, heads, dim)
+    typemap = {"float16": torch.float16}
+    Q = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(q_shape).type(
+        typemap[dtype]).cuda()
+    K = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(k_shape).type(
+        typemap[dtype]).cuda()
+    V = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(v_shape).type(
+        typemap[dtype]).cuda()
+    res = ref_program(Q, K, V, is_casual)
+    return res
 
 
-def flashattn(batch, heads, seq_len, dim, is_casual):
+def flashattn_autotune(batch, heads, seq_len, dim, is_causal):
 
     @autotune(
         configs=get_configs(),
@@ -39,7 +59,7 @@ def flashattn(batch, heads, seq_len, dim, is_casual):
     @jit(
         out_idx=[3],
         supply_type=tl.TensorSupplyType.Normal,
-        ref_prog=partial(ref_program, casual=is_casual),
+        ref_prog=partial(ref_program, causal=is_causal),
         rtol=0.01,
         atol=0.01,
     )
@@ -81,10 +101,10 @@ def main(
                     Q_local[i, j] *= scale
                 loop_range = (
                     T.ceildiv(
-                        (bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N))
+                        (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
                 for k in T.Pipelined(loop_range, num_stages=num_stages):
                     T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
-                    if is_casual:
+                    if is_causal:
                         for i, j in T.Parallel(block_M, block_N):
                             acc_s[i, j] = T.if_then_else(
                                 bx * block_M + i >= k * block_N + j,
@@ -128,23 +148,112 @@ def main(
     return kernel()
 
 
+@bitblas.testing.requires_cuda_compute_version(8, 9)
+def test_flashattn_autotune():
+    flashattn_autotune(1, 4, 256, 256, True)
+    flashattn_autotune(1, 8, 256, 256, True)
+    flashattn_autotune(4, 4, 256, 256, True)
+    flashattn_autotune(4, 8, 256, 256, True)
+
+
+def flashattn(batch, heads, seq_len, dim, is_causal):
+
+    def kernel(block_M=64, block_N=64, num_stages=1, thread_num=128):
+        scale = (1.0 / dim)**0.5 * 1.44269504
+        shape = [batch, seq_len, heads, dim]
+        dtype = "float16"
+        accum_dtype = "float"
+
+        @T.prim_func
+        def main(
+                Q: T.Buffer(shape, dtype),
+                K: T.Buffer(shape, dtype),
+                V: T.Buffer(shape, dtype),
+                Output: T.Buffer(shape, dtype),
+        ):
+            print(type(seq_len), seq_len)
+            print(type(block_M), block_M)
+            with T.Kernel(
+                    T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
+                Q_shared = T.alloc_shared([block_M, dim], dtype)
+                Q_local = T.alloc_fragment([block_M, dim], dtype)
+                K_shared = T.alloc_shared([block_N, dim], dtype)
+                V_shared = T.alloc_shared([block_N, dim], dtype)
+                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
+                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
+                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
+                scores_max = T.alloc_fragment([block_M], accum_dtype)
+                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
+                scores_scale = T.alloc_fragment([block_M], accum_dtype)
+                scores_sum = T.alloc_fragment([block_M], accum_dtype)
+                logsum = T.alloc_fragment([block_M], accum_dtype)
+
+                T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})
+                T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
+                T.fill(acc_o, 0)
+                T.fill(logsum, 0)
+                T.fill(scores_max, -T.infinity(accum_dtype))
+                T.copy(Q_shared, Q_local)
+                for i, j in T.Parallel(block_M, dim):
+                    Q_local[i, j] *= scale
+                loop_range = (
+                    T.ceildiv(
+                        (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
+                for k in T.Pipelined(loop_range, num_stages=num_stages):
+                    T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
+                    if is_causal:
+                        for i, j in T.Parallel(block_M, block_N):
+                            acc_s[i, j] = T.if_then_else(
+                                bx * block_M + i >= k * block_N + j,
+                                0,
+                                -T.infinity(acc_s.dtype),
+                            )
+                    else:
+                        T.clear(acc_s)
+                    T.gemm(
+                        Q_local,
+                        K_shared,
+                        acc_s,
+                        transpose_B=True,
+                        policy=T.GemmWarpPolicy.FullRow,
+                    )
+                    T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
+                    T.copy(scores_max, scores_max_prev)
+                    T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
+                    for i in T.Parallel(block_M):
+                        scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
+                    for i, j in T.Parallel(block_M, dim):
+                        acc_o[i, j] *= scores_scale[i]
+                    T.copy(acc_s, acc_s_cast)
+                    T.gemm(
+                        acc_s_cast,
+                        V_shared,
+                        acc_o,
+                        policy=T.GemmWarpPolicy.FullRow,
+                    )
+                    T.reduce_sum(acc_s, scores_sum, dim=1)
+                    for i in T.Parallel(block_M):
+                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
+                for i, j in T.Parallel(block_M, dim):
+                    acc_o[i, j] /= logsum[i]
+                T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
+
+        return main
+
+    mod, params = tl.lower(kernel())
+    mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal)
+    mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01)
+
+
+@bitblas.testing.requires_cuda_compute_version(8, 9)
+def test_flashattn():
+    flashattn(1, 4, 256, 256, True)
+    flashattn(1, 8, 256, 256, True)
+    flashattn(4, 4, 256, 256, True)
+    flashattn(4, 8, 256, 256, True)
+
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--batch", type=int, default=64, help="Batch size")
-    parser.add_argument("--h", type=int, default=12, help="Number of heads")
-    parser.add_argument("--n_ctx", type=int, default=2048, help="Context size")
-    parser.add_argument("--d_head", type=int, default=256, help="Head dimension")
-    parser.add_argument("--casual", type=bool, default=True, help="Casual flag")
-    args = parser.parse_args()
-    BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head
-    casual = args.casual
-    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
-    total_flops = 2 * flops_per_matmul
-    if casual:
-        total_flops *= 0.5
-
-    best_latency, best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual)
-    print(f"Best latency: {best_latency}")
-    print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
-    print(f"Best config: {best_config}")
-    print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}")
+    bitblas.testing.main()
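
For reference, `ref_program` above delegates to flash-attn's `flash_attn_func`. The snippet below is a minimal PyTorch sketch of the same computation, assuming standard scaled dot-product attention over (batch, seq_len, heads, dim) inputs with an optional causal mask that keeps a score only when the query index is >= the key index, the same condition the kernel evaluates as `bx * block_M + i >= k * block_N + j`. The helper name `naive_causal_attention` is illustrative and not part of this commit.

import torch


def naive_causal_attention(Q, K, V, causal=True):
    # Q, K, V: (batch, seq_len, heads, dim) tensors, laid out as in ref_flashattn_result.
    batch, seq_len, heads, dim = Q.shape
    scale = dim ** -0.5
    # Attention scores with shape (batch, heads, seq_len_q, seq_len_k), computed in fp32.
    scores = torch.einsum("bqhd,bkhd->bhqk", Q.float(), K.float()) * scale
    if causal:
        # Keep position (q, k) only when q >= k, mirroring the kernel's masking predicate.
        mask = torch.tril(torch.ones(seq_len, seq_len, device=Q.device)).bool()
        scores = scores.masked_fill(~mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    out = torch.einsum("bhqk,bkhd->bqhd", probs, V.float())
    return out.to(Q.dtype)

With inputs built the way `ref_flashattn_result` builds them, comparing this sketch against `ref_program(Q, K, V, True)` at the same fp16 tolerances used above (rtol=0.01, atol=0.01) would be the analogous check. Since the new functions follow the `test_*` naming convention and `__main__` calls `bitblas.testing.main()`, the file can presumably be exercised either directly with Python or through pytest.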
