Merge improvements of 0.7.1b release into main #46

Open · wants to merge 3 commits into main
4 changes: 4 additions & 0 deletions bindings/CMakeLists.txt
@@ -17,3 +17,7 @@ if(AOTRITON_BUILD_FOR_TUNING)
else(AOTRITON_BUILD_FOR_TUNING)
target_compile_definitions(pyaotriton PRIVATE -DAOTRITON_BUILD_FOR_TUNING=0)
endif(AOTRITON_BUILD_FOR_TUNING)

set_target_properties(pyaotriton PROPERTIES INSTALL_RPATH "$ORIGIN")
include(GNUInstallDirs)
install(TARGETS pyaotriton LIBRARY DESTINATION lib)
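Note on the rpath change above: INSTALL_RPATH "$ORIGIN" makes the installed pyaotriton extension resolve its shared-library dependencies relative to its own location rather than via LD_LIBRARY_PATH. A minimal sketch for checking that layout after installation, assuming the install directory is on sys.path (the import path and the .so listing are illustrative, not part of this PR):

import os
import importlib

# The extension and the libraries it links against should sit side by side;
# that is exactly what the "$ORIGIN" rpath relies on at load time.
mod = importlib.import_module("pyaotriton")   # assumes the installed lib/ dir is importable
module_dir = os.path.dirname(mod.__file__)
print("loaded from:", module_dir)
print("shared objects next to it:", [f for f in os.listdir(module_dir) if ".so" in f])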
1 change: 1 addition & 0 deletions test/mptune/flash/db_accessor.py
@@ -25,6 +25,7 @@ def constrcut_inputs(self, request):
head_dim_rounded = max(16, head_dim_rounded)
inputs = {
'Q_dtype': str(dtype),
'BATCH' : BATCH,
'N_HEADS': N_HEADS,
'D_HEAD': D_HEAD,
'max_seqlen_q': seqlen_q,
30 changes: 24 additions & 6 deletions test/performance_backward.py
@@ -2,10 +2,12 @@
# Copyright © 2023-2024 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

import os
import pytest
import torch

import triton
from collections import defaultdict
from attn_torch_function import attention, AttentionExtraArgs

try:
@@ -19,23 +21,36 @@
except BaseException:
FLASH_VER = None
HAS_FLASH = FLASH_VER is not None
USE_TFLOPS = bool(int(os.getenv('USE_TFLOPS', default='1')))
print(f'{USE_TFLOPS=}')

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 64, 4096, 64
d_heads = os.getenv('D_HEADS', default='64,128')
d_heads = list(map(lambda x: int(x), d_heads.split(',')))

n_ctx = os.getenv('N_CTX', default=list(range(10, 14)))
if isinstance(n_ctx, str):
n_ctx = map(lambda x: int(x), n_ctx.split(','))
X_VALS = list(map(lambda x: 2 ** x, n_ctx))
print(f'{X_VALS=}')

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# BATCH, N_HEADS, N_CTX, D_HEAD = 512, 32, 512, 64
# vary seq length for fixed head and batch=4
configs = []
for mode in ['bwd']:
# for causal in [False, True]:
for causal in [False]:
for D_HEAD in [64, 128]:
for D_HEAD in d_heads:
configs.append(triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=list(X_VALS),
# x_vals=[2**i for i in range(10, 15)],
x_vals=[2**13],
# x_vals=[2**13],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
line_names=['Triton(TFLOPS)' if USE_TFLOPS else 'Triton(ms)'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
ylabel='TFLOPS' if USE_TFLOPS else 'ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
args={
'H': N_HEADS,
@@ -97,7 +112,10 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
total_flops *= 0.5
if mode == 'bwd':
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
return total_flops / ms * 1e-9
if USE_TFLOPS:
return total_flops / ms * 1e-9
else:
return ms


# only works on post-Ampere GPUs right now
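Usage note for the new knobs in performance_backward.py: the benchmark is now configured through environment variables rather than hard-coded constants, and they must be set before the module is imported because the values are read at import time. A short sketch (the specific values are illustrative; the variable names and defaults come from the diff above):

# Equivalent shell invocation: USE_TFLOPS=0 D_HEADS=64 N_CTX=10,12 python test/performance_backward.py
import os

os.environ['USE_TFLOPS'] = '0'   # report raw milliseconds instead of TFLOPS
os.environ['D_HEADS'] = '64'     # benchmark only head dim 64 (default '64,128')
os.environ['N_CTX'] = '10,12'    # log2 of sequence lengths, i.e. N_CTX = 1024 and 4096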
29 changes: 23 additions & 6 deletions test/performance_forward.py
@@ -5,6 +5,7 @@
import pytest
import torch

import os
import triton
from attn_torch_function import attention, AttentionExtraArgs

@@ -19,23 +20,35 @@
except BaseException:
FLASH_VER = None
HAS_FLASH = FLASH_VER is not None
USE_TFLOPS = bool(int(os.getenv('USE_TFLOPS', default='1')))
print(f'{USE_TFLOPS=}')

BATCH, N_HEADS, N_CTX, D_HEAD = 8, 64, 4096, 64
d_heads = os.getenv('D_HEADS', default='64,128')
d_heads = list(map(lambda x: int(x), d_heads.split(',')))

n_ctx = os.getenv('N_CTX', default=list(range(10, 14)))
if isinstance(n_ctx, str):
n_ctx = map(lambda x: int(x), n_ctx.split(','))
X_VALS = list(map(lambda x: 2 ** x, n_ctx))
print(f'{X_VALS=}')

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = []
for mode in ['fwd']:
# for causal in [False, True]:
for causal in [False]:
for D_HEAD in [64, 128]:
for D_HEAD in d_heads:
configs.append(triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=list(X_VALS),
# x_vals=[2**i for i in range(10, 15)],
x_vals=[2**13],
# x_vals=[2**13],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
line_names=['Triton(TFLOPS)' if USE_TFLOPS else 'Triton(ms)'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
ylabel='TFLOPS' if USE_TFLOPS else 'ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
args={
'H': N_HEADS,
@@ -47,6 +60,7 @@
})
)


@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"):
print(f"{N_CTX=}")
@@ -97,7 +111,10 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
total_flops *= 0.5
if mode == 'bwd':
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
return total_flops / ms * 1e-9
if USE_TFLOPS:
return total_flops / ms * 1e-9
else:
return ms


# only works on post-Ampere GPUs right now
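For reference, the number plotted when USE_TFLOPS=1 follows the usual two-matmul FLOP model for attention; the base count below is an assumption (it is not visible in this hunk), while the causal/backward multipliers and the 1e-9 conversion are taken from the diff. A worked sketch:

BATCH, H, N_CTX, D_HEAD = 4, 48, 8192, 64
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD  # one (N x N x d) matmul
total_flops = 2 * flops_per_matmul                           # QK^T plus PV
ms = 50.0                                                    # illustrative measured runtime
tflops = total_flops / ms * 1e-9                             # ms->s (1e3) and FLOPs->TFLOPs (1e-12) combine to 1e-9
print(f'{tflops:.1f} TFLOPS')                                # roughly 66 TFLOPS for these numbers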
2 changes: 1 addition & 1 deletion test/test_backward.py
@@ -76,7 +76,7 @@ def _do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale
is_allclose, adiff, grads_allclose, grads_adiff, tfts = ctx.validate_with_reference(tri_out, ctx.dout_tensors, return_target_fudge_factors=True)
ctx.display_validation_results(tri_out, is_allclose, adiff, grads_allclose, grads_adiff)

assert is_allclose, f'Forward pass {is_allclose=}'
assert is_allclose, f'Forward pass {is_allclose=} {tfts=}'
dq_allclose, dk_allclose, dv_allclose, db_allclose = grads_allclose
tri_dq, tri_dk, tri_dv, tri_db = ctx.dout_tensors
ref_dq, ref_dk, ref_dv, ref_db = ctx.dref_tensors
8 changes: 6 additions & 2 deletions test/tune_flash.py
@@ -67,7 +67,11 @@ def next_index(self, kig: KernelIndexProress) -> bool:
class FlashTunerSource(MonadService):
def gen(self):
a = self._args
yield from itertools.product(a.batch, a.n_heads, a.d_head, a.seqlen_q, a.seqlen_k, a.causal, a.sm_scale, a.dropout_p, a.return_encoded_softmax, a.dtype, a.bias_type)
for tup in itertools.product(a.batch, a.n_heads, a.d_head, a.seqlen_q, a.seqlen_k, a.causal, a.sm_scale, a.dropout_p, a.return_encoded_softmax, a.dtype, a.bias_type):
BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, return_encoded_softmax, dtype, bias_type = tup
if seqlen_q > 4096 and seqlen_k > 4096:
BATCH = 2
yield (BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, return_encoded_softmax, dtype, bias_type)

def init(self, _):
pass
@@ -227,7 +231,7 @@ def make_ui(manager : TunerManager):

def parse():
p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
p.add_argument('--batch', type=int, nargs=1, default=[1], help='(Not a functional) Batch size.')
p.add_argument('--batch', type=int, nargs=1, default=[8], help='(Not a functional) Batch size.')
p.add_argument('--n_heads', type=int, nargs=1, default=[12], help='(Not a functional) Number of heads')
p.add_argument('--sm_scale', type=float, nargs=1, default=[1.2], help='(Not a functional) Softmax Scale')
p.add_argument('--return_encoded_softmax', type=bool, default=[False],
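A note on the generator change in tune_flash.py above: capping BATCH at 2 once both sequence lengths exceed 4096 presumably keeps per-configuration memory bounded, since buffers shaped (BATCH, N_HEADS, seqlen_q, seqlen_k) grow quadratically with sequence length. A rough sizing sketch under that assumption (the buffer shape and fp16 element size are assumptions, not taken from this PR):

def score_matrix_bytes(batch, n_heads, seqlen_q, seqlen_k, bytes_per_elem=2):
    # Size of a (batch, n_heads, seqlen_q, seqlen_k) fp16 buffer, e.g. an
    # encoded-softmax or attention-score tensor.
    return batch * n_heads * seqlen_q * seqlen_k * bytes_per_elem

GiB = 2 ** 30
print(score_matrix_bytes(8, 12, 8192, 8192) / GiB)  # 12.0 GiB at the new default batch of 8
print(score_matrix_bytes(2, 12, 8192, 8192) / GiB)  # 3.0 GiB with the long-sequence cap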
24 changes: 13 additions & 11 deletions tritonsrc/_common_test.py
@@ -127,7 +127,9 @@ def __init__(self, BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, dtype,

# Maximal value from tune_flash.py and table_tool.py --fudge_factor_tolerance 5.0
# Note: Navi 3x is experimental and YMMV
self.OUT_FUDGE_FACTOR = 6.0 if dtype != torch.float32 else 10.0
self.OUT_FUDGE_FACTOR = 3.0
if dtype == torch.float32:
self.OUT_FUDGE_FACTOR = 12.0

'''
Create Tensors that will be kept b/w forward and backward pass
@@ -245,18 +247,17 @@ def _compute_fudge_factors(self, p : SdpaParams):

# Maximal value from tune_flash.py and table_tool.py --fudge_factor_tolerance 5.0
# Note: Navi 3x is experimental and YMMV
query_fudge_factor = 180.0
key_fudge_factor = 16.0
value_fudge_factor = 32.0
query_fudge_factor = 32.0
key_fudge_factor = 48.0
value_fudge_factor = 16.0
bias_fudge_factor = 16.0
print(f'{torch.cuda.get_device_properties(0).gcnArchName=}')
# print(f'{torch.cuda.get_device_properties(0).gcnArchName=}')
if torch.version.hip:
if 'gfx90a' in torch.cuda.get_device_properties(0).gcnArchName:
key_fudge_factor = max(8.0, (seqlen_k + seqlen_q) / 16.0) # TODO: Check why
bias_fudge_factor = 32.0
query_fudge_factor = 80.0
if dtype == torch.float32:
key_fudge_factor = 180.0
bias_fudge_factor = 32.0
bias_fudge_factor = 24.0
return (query_fudge_factor, key_fudge_factor, value_fudge_factor, bias_fudge_factor)

@staticmethod
@@ -321,9 +322,10 @@ def lmax(x) -> float:
atol = default_atol[torch.float32]
threshold = max(atol, ref_error * fudge_factor)
valid = test_error <= threshold
tft = test_error / ref_error if ref_error > atol else 1.0
tft = test_error / ref_error if ref_error * fudge_factor > atol else 1.0
if not valid:
print(f'For {tname}, Consider bump fudge_factor to {tft} = {test_error=} / {ref_error=}. So that {test_error=} < max({atol=}, {ref_error=} * {tft=})')
pass
# print(f'For {tname}, Consider bump fudge_factor to {tft} = {test_error=} / {ref_error=}. So that {test_error=} < max({atol=}, {ref_error=} * {tft=})')
if return_target_fudge_factors:
return valid, max_adiff, tft
else:
@@ -351,7 +353,7 @@ def validate_with_reference(self, out, grads,
return out_allclose, out_adiff, [], []
grads_allclose = []
grads_adiff = []
print(f'using {self.fudge_factors=}')
# print(f'using {self.fudge_factors=}')
for grad, ref, lp_ref, fudge_factor, tname in zip(grads, self.dref_tensors, self.lp_dref_tensors, self.fudge_factors, self.TENSOR_NAMES):
allclose, adiff, tft = self._validate(grad,
ref,
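The validation logic touched in _common_test.py compares each test tensor's error against the reference implementation's own error scaled by a per-tensor fudge factor, floored at atol. A minimal standalone sketch of that rule (variable names mirror _validate; the numeric values are illustrative only):

atol = 1e-5          # stand-in for default_atol[torch.float32]
fudge_factor = 48.0  # e.g. the updated key_fudge_factor
ref_error = 2.0e-4   # max |low-precision torch reference - fp32 reference|, illustrative
test_error = 6.0e-3  # max |triton output - fp32 reference|, illustrative

threshold = max(atol, ref_error * fudge_factor)
valid = test_error <= threshold
# Target fudge factor: the multiplier that would have been needed for this case; the
# per-tensor values of this quantity are the tfts now shown in test_backward.py's assert.
tft = test_error / ref_error if ref_error * fudge_factor > atol else 1.0
print(valid, threshold, tft)  # -> valid, threshold 9.6e-3, target fudge factor ~30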
1 change: 0 additions & 1 deletion tritonsrc/attn_torch_function.py
@@ -13,7 +13,6 @@
bwd_preprocess as bare_bwd_preprocess,
bwd_kernel_dk_dv as bare_bwd_kernel_dk_dv,
bwd_kernel_dq as bare_bwd_kernel_dq,
attn_bwd as bare_attn_bwd,
)
from tuned_bwd import (
tuned_bwd_kernel_dk_dv,
84 changes: 0 additions & 84 deletions tritonsrc/bwd_inner_dkdv.py

This file was deleted.

64 changes: 0 additions & 64 deletions tritonsrc/bwd_inner_dq.py

This file was deleted.
