Apply formatting (NVIDIA#929)
* Apply formatting

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Apply formatting

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman authored Jun 14, 2024
1 parent d99142a commit 9416519
Showing 256 changed files with 35,450 additions and 35,134 deletions.
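The diff below for benchmark_attention.py is representative of the whole commit: string literals move from single to double quotes, multi-line calls are broken out to one argument per line with a trailing comma, and long expressions are wrapped in explicit parentheses. The commit message does not name the formatter; these conventions look like Black with an extended line length, so the following sketch is an assumption about the style rather than a statement of the tooling used, and the function and values in it are hypothetical stand-ins rather than code from this repository.

# Hypothetical illustration of the formatting conventions applied in this
# commit (formatter assumed to be Black-like; the commit does not say).

def run_benchmark(backend, layout, warmup_iters, num_iters, is_training):
    """Toy placeholder for the benchmark entry points reformatted below."""
    return f"{backend} / {layout}: warmup={warmup_iters}, iters={num_iters}, training={is_training}"

# Before formatting, a call might have been written as:
#     result = run_benchmark('FusedAttention', 'bshd_bshd_bshd',
#         3, 20, True)
# After formatting, the same call becomes:
result = run_benchmark(
    "FusedAttention",
    "bshd_bshd_bshd",
    3,
    20,
    True,
)
print(result)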
benchmarks/attention/benchmark_attention.py — 205 changes (132 additions, 73 deletions)
@@ -14,7 +14,7 @@
_is_flash_attention_supported,
_is_fused_attention_supported,
_is_unfused_attention_supported,
_run_dot_product_attention
_run_dot_product_attention,
)

pd.set_option("display.precision", 4)
@@ -28,7 +28,7 @@
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = 'bshd_bshd_bshd'
qkv_layout = "bshd_bshd_bshd"
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
@@ -38,16 +38,17 @@

model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
"test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
}


def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
config = model_configs[model]
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=5e-3, rtol=5e-3)

@@ -57,17 +58,31 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
for i in range(warmup_iters):
if fused_attn_supported:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, config, "FlashAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FlashAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
if fused_attn_supported and flash_attn_supported:
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i,_ in enumerate(flash_attn_bwd):
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)

torch.cuda.cudart().cudaProfilerStart()
@@ -76,8 +91,15 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
if fused_attn_supported:
for i in range(num_iters):
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
torch.cuda.synchronize()
fused_attn_time = time.time() - fused_attn_start if fused_attn_supported else 0
@@ -87,81 +109,113 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
if flash_attn_supported:
for i in range(num_iters):
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, config, "FlashAttention",
ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
dtype,
config,
"FlashAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
swa,
pad_between_seqs,
is_training,
)
torch.cuda.synchronize()
flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0

df = pd.read_csv('times.csv')
df = pd.concat([
df,
pd.DataFrame(
[[fused_attn_time*1e3/num_iters, 0, 0, 0,
flash_attn_time*1e3/num_iters, 0, 0, 0, 0]], columns=df.columns)],
ignore_index=True
)
df.to_csv('times.csv',index=False)
df = pd.read_csv("times.csv")
df = pd.concat(
[
df,
pd.DataFrame(
[
[
fused_attn_time * 1e3 / num_iters,
0,
0,
0,
flash_attn_time * 1e3 / num_iters,
0,
0,
0,
0,
]
],
columns=df.columns,
),
],
ignore_index=True,
)
df.to_csv("times.csv", index=False)
torch.cuda.cudart().cudaProfilerStop()


def parse_results(per_cudnn, per_flash, model):
filename = f'prof_{model}_cuda_gpu_trace.csv'
df = pd.read_csv(os.path.join('./',filename))
df_times = pd.read_csv('times.csv')
row = len(df_times.index)-1
filename = f"prof_{model}_cuda_gpu_trace.csv"
df = pd.read_csv(os.path.join("./", filename))
df_times = pd.read_csv("times.csv")
row = len(df_times.index) - 1

if per_cudnn > 0:
t_cudnn_all = df[df['Name'].str.contains('cudnn')]['Duration (ns)'].to_numpy()
t_cudnn_all = df[df["Name"].str.contains("cudnn")]["Duration (ns)"].to_numpy()
t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn)
t_cudnn_avg = np.average(t_cudnn_all, axis=0)
df_times.loc[row, 'FusedAttention Kernels (fwd)'] = t_cudnn_avg[0]/1e6
df_times.loc[row, 'FusedAttention Kernels (bwd)'] = t_cudnn_avg[1:4].sum()/1e6
df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)'] = t_cudnn_avg.sum()/1e6
df_times.loc[row, "FusedAttention Kernels (fwd)"] = t_cudnn_avg[0] / 1e6
df_times.loc[row, "FusedAttention Kernels (bwd)"] = t_cudnn_avg[1:4].sum() / 1e6
df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6

if per_flash > 0:
t_flash_all = df[df['Name'].str.contains('void flash')]['Duration (ns)'].to_numpy()
t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy()
t_flash_all = t_flash_all.reshape(-1, per_flash)
t_flash_avg = np.average(t_flash_all, axis=0)
df_times.loc[row, 'FlashAttention Kernels (fwd)'] = t_flash_avg[0]/1e6
df_times.loc[row, 'FlashAttention Kernels (bwd)'] = t_flash_avg[1:4].sum()/1e6
df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] = t_flash_avg.sum()/1e6
df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
df_times.loc[row, "FlashAttention Kernels (bwd)"] = t_flash_avg[1:4].sum() / 1e6
df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] = t_flash_avg.sum() / 1e6

if per_cudnn > 0 and per_flash > 0:
df_times.loc[row, 'Fused vs Flash Kernels Speedup (fwd+bwd)'] = \
df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] / \
df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)']
df_times.to_csv('times.csv',index=False)
df_times.loc[row, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = (
df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"]
/ df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"]
)
df_times.to_csv("times.csv", index=False)


def main():
times = pd.DataFrame(
columns=[
'FusedAttention Module',
'FusedAttention Kernels (fwd)',
'FusedAttention Kernels (bwd)',
'FusedAttention Kernels (fwd+bwd)',
'FlashAttention Module',
'FlashAttention Kernels (fwd)',
'FlashAttention Kernels (bwd)',
'FlashAttention Kernels (fwd+bwd)',
'Fused vs Flash Kernels Speedup (fwd+bwd)',
])
times.to_csv('times.csv',index=False)
columns=[
"FusedAttention Module",
"FusedAttention Kernels (fwd)",
"FusedAttention Kernels (bwd)",
"FusedAttention Kernels (fwd+bwd)",
"FlashAttention Module",
"FlashAttention Kernels (fwd)",
"FlashAttention Kernels (bwd)",
"FlashAttention Kernels (fwd+bwd)",
"Fused vs Flash Kernels Speedup (fwd+bwd)",
]
)
times.to_csv("times.csv", index=False)

device_id = torch.cuda.current_device()
device_properties = torch.cuda.get_device_properties(device_id)
print(f"Device {device_id}: "
print(
f"Device {device_id}: "
f"{device_properties.name} GPU, "
f"sm{device_properties.major}{device_properties.minor} compute capability, "
f"{device_properties.total_memory/1024**3:.1f}GB memory")
f"{device_properties.total_memory/1024**3:.1f}GB memory"
)
for model in model_configs.keys():
config = model_configs[model]
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
config, dtype, qkv_layout=qkv_layout,
config,
dtype,
qkv_layout=qkv_layout,
)
fused_attn_supported = fused_attn_supported and not swa
flash_attn_supported = _is_flash_attention_supported(config)
print(f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
f'{" and flash-attention" if flash_attn_supported else ""}...')
print(
f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
f'{" and flash-attention" if flash_attn_supported else ""}...'
)

prof_cmd = [
"nsys",
@@ -175,8 +229,8 @@ def main():
f""" "import benchmark_attention;""",
f"""benchmark_attention.benchmark_dot_product_attention("""
f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """,
]
prof_cmd = ' '.join(prof_cmd)
]
prof_cmd = " ".join(prof_cmd)
subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
stats_cmd = [
"nsys",
@@ -190,36 +244,41 @@ def main():
"--force-export=true",
f"--output=prof_{model}",
f"prof_{model}.nsys-rep",
]
]
if fused_attn_supported:
num_kernels_cudnn = 4
if config.attn_bias_type == 'post_scale_bias':
num_kernels_cudnn = num_kernels_cudnn+1
if config.attn_bias_type == "post_scale_bias":
num_kernels_cudnn = num_kernels_cudnn + 1
if config.num_heads != config.num_gqa_groups:
num_kernels_cudnn = num_kernels_cudnn+2
num_kernels_cudnn = num_kernels_cudnn + 2
else:
num_kernels_cudnn = 0
num_kernels_flash = 4 if flash_attn_supported else 0
stats_cmd = ' '.join(stats_cmd)
stats_cmd = " ".join(stats_cmd)
subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
parse_cmd = [
"python",
"-c",
f""" "import benchmark_attention;""",
f"""benchmark_attention.parse_results("""
f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """,
]
parse_cmd = ' '.join(parse_cmd)
]
parse_cmd = " ".join(parse_cmd)
subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)

df_times = pd.read_csv('times.csv')
df_times = pd.read_csv("times.csv")
df_times.index = list(model_configs.keys())
a=df_times[['FusedAttention Kernels (fwd+bwd)',
'FlashAttention Kernels (fwd+bwd)',
'Fused vs Flash Kernels Speedup (fwd+bwd)']]
a.columns = ['cuDNN fwd+bwd (ms)', 'flash-attn fwd+bwd (ms)', 'cuDNN vs flash speedup']
a = df_times[
[
"FusedAttention Kernels (fwd+bwd)",
"FlashAttention Kernels (fwd+bwd)",
"Fused vs Flash Kernels Speedup (fwd+bwd)",
]
]
a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
print()
print(a)


if __name__ == "__main__":
main()