Parameterize D_HEAD when benchmarking
zhanglx13 committed Sep 28, 2023
1 parent cc03a05 commit a2e55a7
Showing 1 changed file with 3 additions and 3 deletions.
python/tutorials/06-fused-attention.py — 6 changes: 3 additions & 3 deletions
@@ -654,7 +654,7 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16):
 FLASH_VER = None
 HAS_FLASH = FLASH_VER is not None
 
-BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
+BATCH, N_HEADS, N_CTX = 4, 48, 4096
 # vary seq length for fixed head and batch=4
 configs = [triton.testing.Benchmark(
     x_names=['N_CTX'],
@@ -664,9 +664,9 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16):
     line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
     styles=[('red', '-'), ('blue', '-')],
     ylabel='ms',
-    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
+    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
     args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode, 'causal': causal}
-) for mode in ['fwd', 'bwd'] for causal in [False, True]]
+) for mode in ['fwd', 'bwd'] for D_HEAD in [64, 128] for causal in [False, True]]
 
 
 @triton.testing.perf_report(configs)
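For context, a minimal sketch (plain Python, using the BATCH and N_HEADS values from the diff above) of the config grid the edited comprehension now produces: adding D_HEAD as a loop variable yields one benchmark config, and one plot name, per (mode, D_HEAD, causal) combination.

# Sketch of the plot names generated by the parameterized configs above.
# Plain Python; BATCH and N_HEADS match the values set in the tutorial.
BATCH, N_HEADS = 4, 48

plot_names = [
    f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}'
    for mode in ['fwd', 'bwd']
    for D_HEAD in [64, 128]
    for causal in [False, True]
]

# 2 modes x 2 head dims x 2 causal settings = 8 benchmark plots
assert len(plot_names) == 8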
