
Commit 9416519

Apply formatting (NVIDIA#929)

Authored by Kirthi Shankar Sivamani

* Apply formatting

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Apply formatting

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
1 parent d99142a commit 9416519
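A minimal sketch of reproducing this kind of formatting pass locally. The double quotes, magic trailing commas, and one-argument-per-line call sites in the diff below are characteristic of black, but the exact tool, version, and line length are assumptions here, not stated in the commit; check the repository's lint/format configuration before relying on them.

# Hedged sketch (not part of the commit): re-run the assumed formatter on the
# benchmark script shown in this diff. "black" and the 100-character line length
# are guesses inferred from the formatting style.
import subprocess

subprocess.run(
    ["black", "--line-length", "100", "benchmarks/attention/benchmark_attention.py"],
    check=True,
)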

File tree

256 files changed: +35450 −35134 lines


benchmarks/attention/benchmark_attention.py

Lines changed: 132 additions & 73 deletions
@@ -14,7 +14,7 @@
     _is_flash_attention_supported,
     _is_fused_attention_supported,
     _is_unfused_attention_supported,
-    _run_dot_product_attention
+    _run_dot_product_attention,
 )
 
 pd.set_option("display.precision", 4)
@@ -28,7 +28,7 @@
 # workspace optimization path for cuDNN attention
 workspace_opt = True
 # QKV memory layout
-qkv_layout = 'bshd_bshd_bshd'
+qkv_layout = "bshd_bshd_bshd"
 # sliding window attention
 swa = False
 # padding between sequences for qkv_format=thd
@@ -38,16 +38,17 @@
 
 model_configs = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
-    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
-    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
-    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
+    "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),  # short seq
+    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),  # longer seq, mask
+    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),  # bias
+    "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),  # GQA
 }
 
+
 def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
     config = model_configs[model]
     if dtype == torch.bfloat16:
-        tols = dict(atol=2.5e-2, rtol=2.5e-2)
+        tols = dict(atol=2.5e-2, rtol=2.5e-2)
     else:
         tols = dict(atol=5e-3, rtol=5e-3)
 
@@ -57,17 +58,31 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
     for i in range(warmup_iters):
         if fused_attn_supported:
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FusedAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
         if flash_attn_supported:
             flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FlashAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FlashAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
     if fused_attn_supported and flash_attn_supported:
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
-        for i,_ in enumerate(flash_attn_bwd):
+        for i, _ in enumerate(flash_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
 
     torch.cuda.cudart().cudaProfilerStart()
@@ -76,8 +91,15 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
     if fused_attn_supported:
         for i in range(num_iters):
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FusedAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
     torch.cuda.synchronize()
     fused_attn_time = time.time() - fused_attn_start if fused_attn_supported else 0
@@ -87,81 +109,113 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
     if flash_attn_supported:
         for i in range(num_iters):
             flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
-                dtype, config, "FlashAttention",
-                ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+                dtype,
+                config,
+                "FlashAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                swa,
+                pad_between_seqs,
+                is_training,
             )
     torch.cuda.synchronize()
     flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0
 
-    df = pd.read_csv('times.csv')
-    df = pd.concat([
-        df,
-        pd.DataFrame(
-            [[fused_attn_time*1e3/num_iters, 0, 0, 0,
-            flash_attn_time*1e3/num_iters, 0, 0, 0, 0]], columns=df.columns)],
-        ignore_index=True
-    )
-    df.to_csv('times.csv',index=False)
+    df = pd.read_csv("times.csv")
+    df = pd.concat(
+        [
+            df,
+            pd.DataFrame(
+                [
+                    [
+                        fused_attn_time * 1e3 / num_iters,
+                        0,
+                        0,
+                        0,
+                        flash_attn_time * 1e3 / num_iters,
+                        0,
+                        0,
+                        0,
+                        0,
+                    ]
+                ],
+                columns=df.columns,
+            ),
+        ],
+        ignore_index=True,
+    )
+    df.to_csv("times.csv", index=False)
     torch.cuda.cudart().cudaProfilerStop()
 
+
 def parse_results(per_cudnn, per_flash, model):
-    filename = f'prof_{model}_cuda_gpu_trace.csv'
-    df = pd.read_csv(os.path.join('./',filename))
-    df_times = pd.read_csv('times.csv')
-    row = len(df_times.index)-1
-
+    filename = f"prof_{model}_cuda_gpu_trace.csv"
+    df = pd.read_csv(os.path.join("./", filename))
+    df_times = pd.read_csv("times.csv")
+    row = len(df_times.index) - 1
+
     if per_cudnn > 0:
-        t_cudnn_all = df[df['Name'].str.contains('cudnn')]['Duration (ns)'].to_numpy()
+        t_cudnn_all = df[df["Name"].str.contains("cudnn")]["Duration (ns)"].to_numpy()
         t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn)
         t_cudnn_avg = np.average(t_cudnn_all, axis=0)
-        df_times.loc[row, 'FusedAttention Kernels (fwd)'] = t_cudnn_avg[0]/1e6
-        df_times.loc[row, 'FusedAttention Kernels (bwd)'] = t_cudnn_avg[1:4].sum()/1e6
-        df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)'] = t_cudnn_avg.sum()/1e6
+        df_times.loc[row, "FusedAttention Kernels (fwd)"] = t_cudnn_avg[0] / 1e6
+        df_times.loc[row, "FusedAttention Kernels (bwd)"] = t_cudnn_avg[1:4].sum() / 1e6
+        df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6
 
     if per_flash > 0:
-        t_flash_all = df[df['Name'].str.contains('void flash')]['Duration (ns)'].to_numpy()
+        t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy()
         t_flash_all = t_flash_all.reshape(-1, per_flash)
         t_flash_avg = np.average(t_flash_all, axis=0)
-        df_times.loc[row, 'FlashAttention Kernels (fwd)'] = t_flash_avg[0]/1e6
-        df_times.loc[row, 'FlashAttention Kernels (bwd)'] = t_flash_avg[1:4].sum()/1e6
-        df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] = t_flash_avg.sum()/1e6
+        df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
+        df_times.loc[row, "FlashAttention Kernels (bwd)"] = t_flash_avg[1:4].sum() / 1e6
+        df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] = t_flash_avg.sum() / 1e6
 
     if per_cudnn > 0 and per_flash > 0:
-        df_times.loc[row, 'Fused vs Flash Kernels Speedup (fwd+bwd)'] = \
-            df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] / \
-            df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)']
-    df_times.to_csv('times.csv',index=False)
+        df_times.loc[row, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = (
+            df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"]
+            / df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"]
+        )
+    df_times.to_csv("times.csv", index=False)
+
 
 def main():
     times = pd.DataFrame(
-        columns=[
-        'FusedAttention Module',
-        'FusedAttention Kernels (fwd)',
-        'FusedAttention Kernels (bwd)',
-        'FusedAttention Kernels (fwd+bwd)',
-        'FlashAttention Module',
-        'FlashAttention Kernels (fwd)',
-        'FlashAttention Kernels (bwd)',
-        'FlashAttention Kernels (fwd+bwd)',
-        'Fused vs Flash Kernels Speedup (fwd+bwd)',
-        ])
-    times.to_csv('times.csv',index=False)
+        columns=[
+            "FusedAttention Module",
+            "FusedAttention Kernels (fwd)",
+            "FusedAttention Kernels (bwd)",
+            "FusedAttention Kernels (fwd+bwd)",
+            "FlashAttention Module",
+            "FlashAttention Kernels (fwd)",
+            "FlashAttention Kernels (bwd)",
+            "FlashAttention Kernels (fwd+bwd)",
+            "Fused vs Flash Kernels Speedup (fwd+bwd)",
+        ]
+    )
+    times.to_csv("times.csv", index=False)
 
     device_id = torch.cuda.current_device()
     device_properties = torch.cuda.get_device_properties(device_id)
-    print(f"Device {device_id}: "
+    print(
+        f"Device {device_id}: "
         f"{device_properties.name} GPU, "
         f"sm{device_properties.major}{device_properties.minor} compute capability, "
-        f"{device_properties.total_memory/1024**3:.1f}GB memory")
+        f"{device_properties.total_memory/1024**3:.1f}GB memory"
+    )
     for model in model_configs.keys():
         config = model_configs[model]
         fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
-            config, dtype, qkv_layout=qkv_layout,
+            config,
+            dtype,
+            qkv_layout=qkv_layout,
         )
         fused_attn_supported = fused_attn_supported and not swa
         flash_attn_supported = _is_flash_attention_supported(config)
-        print(f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
-              f'{" and flash-attention" if flash_attn_supported else ""}...')
+        print(
+            f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
+            f'{" and flash-attention" if flash_attn_supported else ""}...'
+        )
 
         prof_cmd = [
             "nsys",
@@ -175,8 +229,8 @@ def main():
             f""" "import benchmark_attention;""",
             f"""benchmark_attention.benchmark_dot_product_attention("""
             f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """,
-            ]
-        prof_cmd = ' '.join(prof_cmd)
+        ]
+        prof_cmd = " ".join(prof_cmd)
         subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
         stats_cmd = [
             "nsys",
@@ -190,36 +244,41 @@ def main():
             "--force-export=true",
             f"--output=prof_{model}",
             f"prof_{model}.nsys-rep",
-            ]
+        ]
         if fused_attn_supported:
             num_kernels_cudnn = 4
-            if config.attn_bias_type == 'post_scale_bias':
-                num_kernels_cudnn = num_kernels_cudnn+1
+            if config.attn_bias_type == "post_scale_bias":
+                num_kernels_cudnn = num_kernels_cudnn + 1
             if config.num_heads != config.num_gqa_groups:
-                num_kernels_cudnn = num_kernels_cudnn+2
+                num_kernels_cudnn = num_kernels_cudnn + 2
         else:
             num_kernels_cudnn = 0
         num_kernels_flash = 4 if flash_attn_supported else 0
-        stats_cmd = ' '.join(stats_cmd)
+        stats_cmd = " ".join(stats_cmd)
         subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
         parse_cmd = [
             "python",
             "-c",
             f""" "import benchmark_attention;""",
             f"""benchmark_attention.parse_results("""
             f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """,
-            ]
-        parse_cmd = ' '.join(parse_cmd)
+        ]
+        parse_cmd = " ".join(parse_cmd)
         subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
 
-    df_times = pd.read_csv('times.csv')
+    df_times = pd.read_csv("times.csv")
     df_times.index = list(model_configs.keys())
-    a=df_times[['FusedAttention Kernels (fwd+bwd)',
-                'FlashAttention Kernels (fwd+bwd)',
-                'Fused vs Flash Kernels Speedup (fwd+bwd)']]
-    a.columns = ['cuDNN fwd+bwd (ms)', 'flash-attn fwd+bwd (ms)', 'cuDNN vs flash speedup']
+    a = df_times[
+        [
+            "FusedAttention Kernels (fwd+bwd)",
+            "FlashAttention Kernels (fwd+bwd)",
+            "Fused vs Flash Kernels Speedup (fwd+bwd)",
+        ]
+    ]
+    a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
     print()
     print(a)
 
+
 if __name__ == "__main__":
     main()
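For context on what the reformatted script does end to end, a minimal usage sketch follows. It only restates what the diff itself shows: main() writes times.csv, shells out to nsys for each model config, parses the exported GPU traces, and prints a cuDNN-vs-flash summary table. Running it assumes a CUDA GPU, nsys on PATH, and the tests/pytorch attention helpers importable; none of that is asserted by the commit.

# Hedged usage sketch (not part of the commit): drive the benchmark through its
# own entry point, exactly as the if __name__ == "__main__" block does.
import benchmark_attention

benchmark_attention.main()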
