 import argparse
-import torch
 from tvm import tl
 import tvm.tl.language as T
 from tvm.tl.autotuner import *
@@ -14,15 +13,12 @@ def get_configs():
     thread_num = [128, 256]
     _configs = list(itertools.product(block_M, block_N, num_stages, thread_num))

-    configs = [
-        {
-            "block_M": c[0],
-            "block_N": c[1],
-            "num_stages": c[2],
-            "thread_num": c[3],
-        }
-        for c in _configs
-    ]
+    configs = [{
+        "block_M": c[0],
+        "block_N": c[1],
+        "num_stages": c[2],
+        "thread_num": c[3],
+    } for c in _configs]
     return configs


@@ -48,21 +44,20 @@ def flashattn(batch, heads, seq_len, dim, is_casual):
         atol=0.01,
     )
     def kernel(block_M=None, block_N=None, num_stages=None, thread_num=None):
-        scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
+        scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
         shape = [batch, seq_len, heads, dim]
         dtype = "float16"
         accum_dtype = "float"

         @T.prim_func
         def main(
-            Q: T.Buffer(shape, dtype),  # type: ignore
-            K: T.Buffer(shape, dtype),  # type: ignore
-            V: T.Buffer(shape, dtype),  # type: ignore
-            Output: T.Buffer(shape, dtype),  # type: ignore
+                Q: T.Buffer(shape, dtype),  # type: ignore
+                K: T.Buffer(shape, dtype),  # type: ignore
+                V: T.Buffer(shape, dtype),  # type: ignore
+                Output: T.Buffer(shape, dtype),  # type: ignore
         ):
             with T.Kernel(
-                T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num
-            ) as (bx, by, bz):
+                    T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):
                 Q_shared = T.alloc_shared([block_M, dim], dtype)
                 Q_local = T.alloc_fragment([block_M, dim], dtype)
                 K_shared = T.alloc_shared([block_N, dim], dtype)
@@ -76,27 +71,19 @@ def main(
                 scores_sum = T.alloc_fragment([block_M], accum_dtype)
                 logsum = T.alloc_fragment([block_M], accum_dtype)

-                T.annotate_layout(
-                    {Q_shared: tl.layout.make_swizzled_layout(Q_shared)}
-                )
-                T.copy(
-                    Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared
-                )
+                T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})
+                T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
                 T.fill(acc_o, 0)
                 T.fill(logsum, 0)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.copy(Q_shared, Q_local)
                 for i, j in T.Parallel(block_M, dim):
                     Q_local[i, j] *= scale
                 loop_range = (
-                    T.ceildiv((bx + 1) * block_M, block_N)
-                    if is_casual
-                    else T.ceildiv(seq_len, block_N)
-                )
+                    T.ceildiv(
+                        (bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N))
                 for k in T.Pipelined(loop_range, num_stages=num_stages):
-                    T.copy(
-                        K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared
-                    )
+                    T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
                     if is_casual:
                         for i, j in T.Parallel(block_M, block_N):
                             acc_s[i, j] = T.if_then_else(
@@ -113,15 +100,11 @@ def main(
                         transpose_B=True,
                         policy=T.GemmWarpPolicy.FullRow,
                     )
-                    T.copy(
-                        V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared
-                    )
+                    T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                     T.copy(scores_max, scores_max_prev)
                     T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                     for i in T.Parallel(block_M):
-                        scores_scale[i] = T.exp2(
-                            scores_max_prev[i] - scores_max[i]
-                        )
+                        scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
                     for i, j in T.Parallel(block_M, dim):
                         acc_o[i, j] *= scores_scale[i]
                     for i, j in T.Parallel(block_M, block_N):
@@ -138,9 +121,7 @@ def main(
                         logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                 for i, j in T.Parallel(block_M, dim):
                     acc_o[i, j] /= logsum[i]
-                T.copy(
-                    acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]
-                )
+                T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

         return main

@@ -152,9 +133,7 @@ def main(
     parser.add_argument("--batch", type=int, default=64, help="Batch size")
     parser.add_argument("--h", type=int, default=12, help="Number of heads")
     parser.add_argument("--n_ctx", type=int, default=2048, help="Context size")
-    parser.add_argument(
-        "--d_head", type=int, default=256, help="Head dimension"
-    )
+    parser.add_argument("--d_head", type=int, default=256, help="Head dimension")
     parser.add_argument("--casual", type=bool, default=True, help="Casual flag")
     args = parser.parse_args()
     BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head
@@ -164,9 +143,7 @@ def main(
     if casual:
         total_flops *= 0.5

-    best_latency, best_config, ref_latency = flashattn(
-        BATCH, H, N_CTX, D_HEAD, casual
-    )
+    best_latency, best_config, ref_latency = flashattn(BATCH, H, N_CTX, D_HEAD, casual)
     print(f"Best latency: {best_latency}")
     print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
     print(f"Best config: {best_config}")
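
For reference, a minimal NumPy sketch of why the kernel folds log2(e) into `scale` and then stays in the exp2 domain: since exp(x) = exp2(x * log2(e)), pre-multiplying the QK^T scores by 1/sqrt(dim) * 1.44269504 lets the running-max correction (`scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])`) and the row sums use exp2 instead of exp. The `softmax_exp2` helper below is an illustration only, not part of this file.

import numpy as np

def softmax_exp2(scores, dim):
    # Row-wise softmax computed with exp2 after pre-scaling, mirroring
    # scale = (1.0 / dim)**0.5 * 1.44269504 in the kernel above.
    scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) * log2(e)
    s = scores * scale
    m = s.max(axis=-1, keepdims=True)  # the kernel tracks this max incrementally
    p = np.exp2(s - m)                 # unnormalized probabilities
    return p / p.sum(axis=-1, keepdims=True)

# Agrees with the ordinary exp-based softmax using 1/sqrt(dim) scaling.
x = np.random.randn(4, 8).astype(np.float32)
ref = np.exp(x / np.sqrt(8) - (x / np.sqrt(8)).max(axis=-1, keepdims=True))
ref /= ref.sum(axis=-1, keepdims=True)
assert np.allclose(softmax_exp2(x, 8), ref, atol=1e-5)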