- import argparse
from tvm import tl
import tvm.tl.language as T
from tvm.tl.autotuner import *
from functools import partial
import itertools
+ import torch
+ import bitblas
+ import logging
+ from bitblas import set_log_level
+
+ set_log_level(logging.DEBUG)


def get_configs():
@@ -22,13 +27,28 @@ def get_configs():
    return configs


- def ref_program(Q, K, V, casual):
+ def ref_program(Q, K, V, causal):
    from flash_attn.flash_attn_interface import flash_attn_func

-     return flash_attn_func(Q, K, V, causal=casual)
+     return flash_attn_func(Q, K, V, causal=causal)
+
+
+ def ref_flashattn_result(batch, heads, seq_len, dim, is_causal, dtype="float16"):
+     q_shape = (batch, seq_len, heads, dim)
+     k_shape = (batch, seq_len, heads, dim)
+     v_shape = (batch, seq_len, heads, dim)
+     typemap = {"float16": torch.float16}
+     Q = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(q_shape).type(
+         typemap[dtype]).cuda()
+     K = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(k_shape).type(
+         typemap[dtype]).cuda()
+     V = torch.rand(batch * seq_len * heads * dim).uniform_(-1, 1).reshape(v_shape).type(
+         typemap[dtype]).cuda()
+     res = ref_program(Q, K, V, is_causal)
+     return res


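A minimal usage sketch of the new ref_flashattn_result helper (illustrative only, not part of the commit; it assumes flash-attn is installed and a CUDA device is available):

# Hypothetical call; positional args are batch, heads, seq_len, dim, is_causal.
# Returns the flash-attn output in (batch, seq_len, heads, dim) layout.
out = ref_flashattn_result(1, 4, 256, 64, True)
print(out.shape)  # torch.Size([1, 256, 4, 64])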
- def flashattn(batch, heads, seq_len, dim, is_casual):
+ def flashattn_autotune(batch, heads, seq_len, dim, is_causal):

    @autotune(
        configs=get_configs(),
@@ -39,7 +59,7 @@ def flashattn(batch, heads, seq_len, dim, is_casual):
    @jit(
        out_idx=[3],
        supply_type=tl.TensorSupplyType.Normal,
-         ref_prog=partial(ref_program, casual=is_casual),
+         ref_prog=partial(ref_program, causal=is_causal),
        rtol=0.01,
        atol=0.01,
    )
@@ -81,10 +101,10 @@ def main(
                    Q_local[i, j] *= scale
                loop_range = (
                    T.ceildiv(
-                         (bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N))
+                         (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
-                     if is_casual:
+                     if is_causal:
                        for i, j in T.Parallel(block_M, block_N):
                            acc_s[i, j] = T.if_then_else(
                                bx * block_M + i >= k * block_N + j,
@@ -128,23 +148,112 @@ def main(
    return kernel()


+ @bitblas.testing.requires_cuda_compute_version(8, 9)
+ def test_flashattn_autotune():
+     flashattn_autotune(1, 4, 256, 256, True)
+     flashattn_autotune(1, 8, 256, 256, True)
+     flashattn_autotune(4, 4, 256, 256, True)
+     flashattn_autotune(4, 8, 256, 256, True)
+
+
+ def flashattn(batch, heads, seq_len, dim, is_causal):
+
+     def kernel(block_M=64, block_N=64, num_stages=1, thread_num=128):
+         scale = (1.0 / dim)**0.5 * 1.44269504
+         shape = [batch, seq_len, heads, dim]
+         dtype = "float16"
+         accum_dtype = "float"
+
+         @T.prim_func
+         def main(
+                 Q: T.Buffer(shape, dtype),
+                 K: T.Buffer(shape, dtype),
+                 V: T.Buffer(shape, dtype),
+                 Output: T.Buffer(shape, dtype),
+         ):
+             print(type(seq_len), seq_len)
+             print(type(block_M), block_M)
+             with T.Kernel(
+                     T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
+                 Q_shared = T.alloc_shared([block_M, dim], dtype)
+                 Q_local = T.alloc_fragment([block_M, dim], dtype)
+                 K_shared = T.alloc_shared([block_N, dim], dtype)
+                 V_shared = T.alloc_shared([block_N, dim], dtype)
+                 acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
+                 acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
+                 acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
+                 scores_max = T.alloc_fragment([block_M], accum_dtype)
+                 scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
+                 scores_scale = T.alloc_fragment([block_M], accum_dtype)
+                 scores_sum = T.alloc_fragment([block_M], accum_dtype)
+                 logsum = T.alloc_fragment([block_M], accum_dtype)
+
+                 T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})
+                 T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
+                 T.fill(acc_o, 0)
+                 T.fill(logsum, 0)
+                 T.fill(scores_max, -T.infinity(accum_dtype))
+                 T.copy(Q_shared, Q_local)
+                 for i, j in T.Parallel(block_M, dim):
+                     Q_local[i, j] *= scale
+                 loop_range = (
+                     T.ceildiv(
+                         (bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N))
+                 for k in T.Pipelined(loop_range, num_stages=num_stages):
+                     T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
+                     if is_causal:
+                         for i, j in T.Parallel(block_M, block_N):
+                             acc_s[i, j] = T.if_then_else(
+                                 bx * block_M + i >= k * block_N + j,
+                                 0,
+                                 -T.infinity(acc_s.dtype),
+                             )
+                     else:
+                         T.clear(acc_s)
+                     T.gemm(
+                         Q_local,
+                         K_shared,
+                         acc_s,
+                         transpose_B=True,
+                         policy=T.GemmWarpPolicy.FullRow,
+                     )
+                     T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
+                     T.copy(scores_max, scores_max_prev)
+                     T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                     for i, j in T.Parallel(block_M, block_N):
+                         acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])
+                     for i in T.Parallel(block_M):
+                         scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
+                     for i, j in T.Parallel(block_M, dim):
+                         acc_o[i, j] *= scores_scale[i]
+                     T.copy(acc_s, acc_s_cast)
+                     T.gemm(
+                         acc_s_cast,
+                         V_shared,
+                         acc_o,
+                         policy=T.GemmWarpPolicy.FullRow,
+                     )
+                     T.reduce_sum(acc_s, scores_sum, dim=1)
+                     for i in T.Parallel(block_M):
+                         logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
+                 for i, j in T.Parallel(block_M, dim):
+                     acc_o[i, j] /= logsum[i]
+                 T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
+
+         return main
+
+     mod, params = tl.lower(kernel())
+     mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal)
+     mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01)
+
+
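The rescaling inside the pipelined loop above is the standard online-softmax (FlashAttention) update, written with base-2 exponentials because scale folds in log2(e) = 1.44269504. A minimal PyTorch sketch of one K/V-tile update under that convention (illustrative only, not part of the commit):

import torch

def online_softmax_step(acc_o, logsum, scores_max, s_tile, v_tile):
    # s_tile: [M, N] scores for the current tile, already multiplied by scale * log2(e)
    # v_tile: [N, D]; acc_o: [M, D]; logsum, scores_max: [M]
    new_max = torch.maximum(scores_max, s_tile.max(dim=1).values)  # running row maximum
    p = torch.exp2(s_tile - new_max[:, None])                      # this tile's softmax numerator
    rescale = torch.exp2(scores_max - new_max)                     # correct rows accumulated so far
    acc_o = acc_o * rescale[:, None] + p @ v_tile
    logsum = logsum * rescale + p.sum(dim=1)
    return acc_o, logsum, new_max

After the last tile the kernel divides acc_o row-wise by logsum, which is exactly the epilogue before the final T.copy to Output.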
+ @bitblas.testing.requires_cuda_compute_version(8, 9)
+ def test_flashattn():
+     flashattn(1, 4, 256, 256, True)
+     flashattn(1, 8, 256, 256, True)
+     flashattn(4, 4, 256, 256, True)
+     flashattn(4, 8, 256, 256, True)
+
+
if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--batch", type=int, default=64, help="Batch size")
-     parser.add_argument("--h", type=int, default=12, help="Number of heads")
-     parser.add_argument("--n_ctx", type=int, default=2048, help="Context size")
-     parser.add_argument("--d_head", type=int, default=256, help="Head dimension")
-     parser.add_argument("--casual", type=bool, default=True, help="Casual flag")
-     args = parser.parse_args()
-     BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head
-     casual = args.casual
-     flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
-     total_flops = 2 * flops_per_matmul
-     if casual:
-         total_flops *= 0.5
-
-     best_latency, best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual)
-     print(f"Best latency: {best_latency}")
-     print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
-     print(f"Best config: {best_config}")
-     print(f"Ref TFlops: {total_flops / ref_latency * 1e-9}")
+     bitblas.testing.main()
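With the argparse entry point replaced by bitblas.testing.main() (which, to the best of my understanding, collects and runs this file's tests via pytest), the two test functions can also be called directly on a suitable GPU, for example:

# Assumes a CUDA device with compute capability >= 8.9 (see the decorators above)
# and the flash-attn package installed for the reference check.
test_flashattn()            # fixed-config kernel, verified against flash-attn
test_flashattn_autotune()   # same check across the autotuned config space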