@@ -14,7 +14,7 @@
     _is_flash_attention_supported,
     _is_fused_attention_supported,
     _is_unfused_attention_supported,
-    _run_dot_product_attention
+    _run_dot_product_attention,
 )

 pd.set_option("display.precision", 4)
@@ -28,7 +28,7 @@
 # workspace optimization path for cuDNN attention
 workspace_opt = True
 # QKV memory layout
-qkv_layout = 'bshd_bshd_bshd'
+qkv_layout = "bshd_bshd_bshd"
 # sliding window attention
 swa = False
 # padding between sequences for qkv_format=thd
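
Note: in the "bshd_bshd_bshd" layout, query, key, and value are three separate
tensors, each shaped [batch, sequence, heads, head_dim]. As a hedged sketch
(the tensor names are illustrative, not part of the script), the inputs for the
"test_0" config defined in the next hunk would look like:

    import torch

    # test_0: b=2, h=16 query heads, hg=16 KV heads, d=64, sq=skv=512
    q = torch.randn(2, 512, 16, 64, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(2, 512, 16, 64, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(2, 512, 16, 64, dtype=torch.bfloat16, device="cuda")

Under GQA (test_3, hg=4), k and v would carry 4 heads while q carries 32.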
@@ -38,16 +38,17 @@

 model_configs = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
-    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
-    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
-    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
+    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
+    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
+    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
+    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
 }

+
 def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
     config = model_configs[model]
     if dtype == torch.bfloat16:
-        tols = dict(atol=2.5e-2, rtol=2.5e-2)
+        tols = dict(atol=2.5e-2, rtol=2.5e-2)
     else:
         tols = dict(atol=5e-3, rtol=5e-3)
@@ -57,17 +58,31 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
     for i in range(warmup_iters):
         if fused_attn_supported:
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FusedAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
         if flash_attn_supported:
             flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FlashAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FlashAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
     if fused_attn_supported and flash_attn_supported:
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
-        for i,_ in enumerate(flash_attn_bwd):
+        for i, _ in enumerate(flash_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)

     torch.cuda.cudart().cudaProfilerStart()
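
Note: torch.testing.assert_close treats two tensors as equal when
|actual - expected| <= atol + rtol * |expected| holds elementwise, so the
looser bf16 tolerances above absorb both a fixed rounding floor and an error
that scales with magnitude. A minimal self-contained illustration (the values
are made up):

    import torch

    atol, rtol = 2.5e-2, 2.5e-2  # the bf16 tolerances used above
    a = torch.tensor([1.00, 100.0])
    b = torch.tensor([1.02, 101.0])
    # elementwise bounds: atol + rtol * |b| = [0.0505, 2.55]; both diffs pass
    torch.testing.assert_close(a, b, atol=atol, rtol=rtol)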
@@ -76,8 +91,15 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
     if fused_attn_supported:
         for i in range(num_iters):
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FusedAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
     torch.cuda.synchronize()
     fused_attn_time = time.time() - fused_attn_start if fused_attn_supported else 0
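
Note: CUDA kernel launches are asynchronous, so torch.cuda.synchronize() must
run before the host clock is read; otherwise the loop would mostly measure
launch overhead. The same wall-clock pattern in isolation, as a hedged sketch
(the helper name is mine, not the script's):

    import time
    import torch

    def time_gpu_op(fn, num_iters=100):
        fn()  # warm up so one-time setup costs don't skew the timing
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(num_iters):
            fn()
        torch.cuda.synchronize()  # wait for all queued kernels to finish
        return (time.time() - start) * 1e3 / num_iters  # ms per iteration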
@@ -87,81 +109,113 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
     if flash_attn_supported:
         for i in range(num_iters):
             flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FlashAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FlashAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
     torch.cuda.synchronize()
     flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0

-    df = pd.read_csv('times.csv')
-    df = pd.concat([
-        df,
-        pd.DataFrame(
-            [[fused_attn_time * 1e3 / num_iters, 0, 0, 0,
-              flash_attn_time * 1e3 / num_iters, 0, 0, 0, 0]], columns=df.columns)],
-        ignore_index=True
-    )
-    df.to_csv('times.csv',index=False)
+    df = pd.read_csv("times.csv")
+    df = pd.concat(
+        [
+            df,
+            pd.DataFrame(
+                [
+                    [
+                        fused_attn_time * 1e3 / num_iters,
+                        0,
+                        0,
+                        0,
+                        flash_attn_time * 1e3 / num_iters,
+                        0,
+                        0,
+                        0,
+                        0,
+                    ]
+                ],
+                columns=df.columns,
+            ),
+        ],
+        ignore_index=True,
+    )
+    df.to_csv("times.csv", index=False)
     torch.cuda.cudart().cudaProfilerStop()

+
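
Note: the module-level timings are appended as a fresh row of times.csv, with
the kernel columns zero-filled; parse_results() overwrites them later from the
nsys trace. The append-a-row idiom in miniature (the file name and two-column
schema are illustrative only; the script uses nine columns):

    import pandas as pd

    df = pd.DataFrame([[0.42, 0.57]], columns=["fused_ms", "flash_ms"])
    new_row = pd.DataFrame([[0.36, 0.51]], columns=df.columns)
    df = pd.concat([df, new_row], ignore_index=True)  # row-wise append
    df.to_csv("times_demo.csv", index=False)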
 def parse_results(per_cudnn, per_flash, model):
-    filename = f'prof_{model}_cuda_gpu_trace.csv'
-    df = pd.read_csv(os.path.join('./', filename))
-    df_times = pd.read_csv('times.csv')
-    row = len(df_times.index)-1
-
+    filename = f"prof_{model}_cuda_gpu_trace.csv"
+    df = pd.read_csv(os.path.join("./", filename))
+    df_times = pd.read_csv("times.csv")
+    row = len(df_times.index) - 1
+
     if per_cudnn > 0:
-        t_cudnn_all = df[df['Name'].str.contains('cudnn')]['Duration (ns)'].to_numpy()
+        t_cudnn_all = df[df["Name"].str.contains("cudnn")]["Duration (ns)"].to_numpy()
         t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn)
         t_cudnn_avg = np.average(t_cudnn_all, axis=0)
-        df_times.loc[row, 'FusedAttention Kernels (fwd)'] = t_cudnn_avg[0]/1e6
-        df_times.loc[row, 'FusedAttention Kernels (bwd)'] = t_cudnn_avg[1:4].sum()/1e6
-        df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)'] = t_cudnn_avg.sum()/1e6
+        df_times.loc[row, "FusedAttention Kernels (fwd)"] = t_cudnn_avg[0] / 1e6
+        df_times.loc[row, "FusedAttention Kernels (bwd)"] = t_cudnn_avg[1:4].sum() / 1e6
+        df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6

     if per_flash > 0:
-        t_flash_all = df[df['Name'].str.contains('void flash')]['Duration (ns)'].to_numpy()
+        t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy()
         t_flash_all = t_flash_all.reshape(-1, per_flash)
         t_flash_avg = np.average(t_flash_all, axis=0)
-        df_times.loc[row, 'FlashAttention Kernels (fwd)'] = t_flash_avg[0]/1e6
-        df_times.loc[row, 'FlashAttention Kernels (bwd)'] = t_flash_avg[1:4].sum()/1e6
-        df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] = t_flash_avg.sum()/1e6
+        df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
+        df_times.loc[row, "FlashAttention Kernels (bwd)"] = t_flash_avg[1:4].sum() / 1e6
+        df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] = t_flash_avg.sum() / 1e6

     if per_cudnn > 0 and per_flash > 0:
-        df_times.loc[row, 'Fused vs Flash Kernels Speedup (fwd+bwd)'] = \
-            df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] / \
-            df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)']
-        df_times.to_csv('times.csv',index=False)
+        df_times.loc[row, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = (
+            df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"]
+            / df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"]
+        )
+        df_times.to_csv("times.csv", index=False)
+

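
Note: parse_results() relies on the trace listing kernel launches in issue
order. With per_cudnn kernels per iteration, reshape(-1, per_cudnn) stacks one
iteration per row, and np.average(..., axis=0) gives each kernel's mean
duration across iterations; index 0 is the forward kernel and 1:4 the backward
kernels. A small illustration with made-up durations:

    import numpy as np

    # 2 iterations x 4 kernels, flattened in launch order, durations in ns
    t_all = np.array([100, 210, 220, 230, 102, 212, 222, 228])
    t_avg = np.average(t_all.reshape(-1, 4), axis=0)  # per-kernel means
    fwd_ms = t_avg[0] / 1e6          # first kernel: forward pass
    bwd_ms = t_avg[1:4].sum() / 1e6  # next three kernels: backward pass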
 def main():
     times = pd.DataFrame(
-        columns=[
-            'FusedAttention Module',
-            'FusedAttention Kernels (fwd)',
-            'FusedAttention Kernels (bwd)',
-            'FusedAttention Kernels (fwd+bwd)',
-            'FlashAttention Module',
-            'FlashAttention Kernels (fwd)',
-            'FlashAttention Kernels (bwd)',
-            'FlashAttention Kernels (fwd+bwd)',
-            'Fused vs Flash Kernels Speedup (fwd+bwd)',
-        ])
-    times.to_csv('times.csv',index=False)
+        columns=[
+            "FusedAttention Module",
+            "FusedAttention Kernels (fwd)",
+            "FusedAttention Kernels (bwd)",
+            "FusedAttention Kernels (fwd+bwd)",
+            "FlashAttention Module",
+            "FlashAttention Kernels (fwd)",
+            "FlashAttention Kernels (bwd)",
+            "FlashAttention Kernels (fwd+bwd)",
+            "Fused vs Flash Kernels Speedup (fwd+bwd)",
+        ]
+    )
+    times.to_csv("times.csv", index=False)

     device_id = torch.cuda.current_device()
     device_properties = torch.cuda.get_device_properties(device_id)
-    print(f"Device {device_id}: "
+    print(
+        f"Device {device_id}: "
         f"{device_properties.name} GPU, "
         f"sm{device_properties.major}{device_properties.minor} compute capability, "
-        f"{device_properties.total_memory/1024**3:.1f}GB memory")
+        f"{device_properties.total_memory / 1024**3:.1f}GB memory"
+    )
     for model in model_configs.keys():
         config = model_configs[model]
         fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
-            config, dtype, qkv_layout=qkv_layout,
+            config,
+            dtype,
+            qkv_layout=qkv_layout,
         )
         fused_attn_supported = fused_attn_supported and not swa
         flash_attn_supported = _is_flash_attention_supported(config)
-        print(f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
-              f'{" and flash-attention" if flash_attn_supported else ""}...')
+        print(
+            f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
+            f'{" and flash-attention" if flash_attn_supported else ""}...'
+        )

         prof_cmd = [
             "nsys",
@@ -175,8 +229,8 @@ def main():
             f""" "import benchmark_attention;""",
             f"""benchmark_attention.benchmark_dot_product_attention("""
             f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """,
-            ]
-        prof_cmd = ' '.join(prof_cmd)
+        ]
+        prof_cmd = " ".join(prof_cmd)
         subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
         stats_cmd = [
             "nsys",
@@ -190,36 +244,41 @@
             "--force-export=true",
             f"--output=prof_{model}",
             f"prof_{model}.nsys-rep",
-            ]
+        ]
         if fused_attn_supported:
             num_kernels_cudnn = 4
-            if config.attn_bias_type == 'post_scale_bias':
-                num_kernels_cudnn = num_kernels_cudnn + 1
+            if config.attn_bias_type == "post_scale_bias":
+                num_kernels_cudnn = num_kernels_cudnn + 1
             if config.num_heads != config.num_gqa_groups:
-                num_kernels_cudnn = num_kernels_cudnn + 2
+                num_kernels_cudnn = num_kernels_cudnn + 2
         else:
             num_kernels_cudnn = 0
         num_kernels_flash = 4 if flash_attn_supported else 0
-        stats_cmd = ' '.join(stats_cmd)
+        stats_cmd = " ".join(stats_cmd)
         subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
         parse_cmd = [
             "python",
             "-c",
             f""" "import benchmark_attention;""",
             f"""benchmark_attention.parse_results("""
             f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """,
-            ]
-        parse_cmd = ' '.join(parse_cmd)
+        ]
+        parse_cmd = " ".join(parse_cmd)
         subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

-    df_times = pd.read_csv('times.csv')
+    df_times = pd.read_csv("times.csv")
     df_times.index = list(model_configs.keys())
-    a = df_times[['FusedAttention Kernels (fwd+bwd)',
-                  'FlashAttention Kernels (fwd+bwd)',
-                  'Fused vs Flash Kernels Speedup (fwd+bwd)']]
-    a.columns = ['cuDNN fwd+bwd (ms)', 'flash-attn fwd+bwd (ms)', 'cuDNN vs flash speedup']
+    a = df_times[
+        [
+            "FusedAttention Kernels (fwd+bwd)",
+            "FlashAttention Kernels (fwd+bwd)",
+            "Fused vs Flash Kernels Speedup (fwd+bwd)",
+        ]
+    ]
+    a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
     print()
     print(a)

+
 if __name__ == "__main__":
     main()