Fix FA tutorial #485

Merged
merged 1 commit into from
Jan 25, 2024
72 changes: 39 additions & 33 deletions python/tutorials/06-fused-attention.py
@@ -17,16 +17,16 @@
import triton
import triton.language as tl

torch_dtype:tl.constexpr = torch.float16
TORCH_HAS_FP8 = False
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2')
TORCH_HAS_FP8E5FNUZ = hasattr(torch, 'float8_e5m2fnuz')
if TORCH_HAS_FP8E5:
    torch_dtype:tl.constexpr = torch.float8_e5m2
    TORCH_HAS_FP8 = True
if TORCH_HAS_FP8E5FNUZ:
    torch_dtype:tl.constexpr = torch.float8_e5m2fnuz
    TORCH_HAS_FP8 = True
# Pick the fp8 data type

# AMD E4M3B8
# Note: When picking this f8 data type, scaling is required when using f8
# for the second gemm
#TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')

# AMD E5M2B16
TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')


@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,
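Note (not part of the diff): the new selection logic above replaces the old torch_dtype / TORCH_HAS_FP8 globals with per-type availability checks, so callers decide explicitly when to cast to the AMD E5M2B16 type (and the commented-out E4M3B8 type, which needs extra scaling for the second gemm). A minimal sketch of how a caller might use that flag, assuming a PyTorch build that exposes float8_e5m2fnuz and using a hypothetical helper name:

import torch

TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')

def maybe_cast_qk_to_fp8(q, k):
    # Hypothetical helper: cast q/k to the AMD E5M2B16 fp8 type when the
    # running PyTorch build exposes it; otherwise return the inputs unchanged.
    if TORCH_HAS_FP8E5B16:
        return q.to(torch.float8_e5m2fnuz), k.to(torch.float8_e5m2fnuz)
    return q, k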
@@ -555,7 +555,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        o = torch.empty_like(q, dtype=v.dtype)
        if torch.version.hip is None:
            BLOCK_M = 128
            BLOCK_N = 64 if Lk <= 64 else 32
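Note (not part of the diff): tying the output buffer to v.dtype matters once q/k can be fp8 while v stays fp16; torch.empty_like(q) alone would allocate an fp8 output. A minimal sketch of the behaviour, assuming float8_e5m2fnuz is available and a GPU is present:

import torch

if hasattr(torch, 'float8_e5m2fnuz') and torch.cuda.is_available():
    q = torch.randn((4, 48, 1024, 64), dtype=torch.float16, device="cuda").to(torch.float8_e5m2fnuz)
    v = torch.randn((4, 48, 1024, 64), dtype=torch.float16, device="cuda")
    o = torch.empty_like(q, dtype=v.dtype)  # follows v, so the output stays fp16
    assert o.dtype == torch.float16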
Expand Down Expand Up @@ -642,26 +642,33 @@ def backward(ctx, do):

attention = _attention.apply


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                         [(4, 48, 1024, 64),
                          (4, 48, 2048, 64),
                          (4, 48, 4096, 64),
                          (4, 48, 1024, 128),
                          (4, 48, 2048, 128),
                          (4, 48, 4096, 128),
                          #(4, 48, 8192, 64),
                          #(4, 48, 16384, 64)
                          ])
name_to_torch_types = {
    'fp16': torch.float16,
}

if TORCH_HAS_FP8E5B16:
    name_to_torch_types['fp8'] = torch.float8_e5m2fnuz

@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, dtype',
                         [(*shape, dtype)
                          for shape in [(4, 48, 1024, 64),
                                        (4, 48, 2048, 64),
                                        (4, 48, 4096, 64),
                                        (4, 48, 1024, 128),
                                        (4, 48, 2048, 128),
                                        (4, 48, 4096, 128)]
                          for dtype in ['fp16', 'fp8']])
@pytest.mark.parametrize('causal', [False, True])
def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype):
    if dtype == 'fp8' and not TORCH_HAS_FP8E5B16:
        pytest.skip("fp8 not supported")
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    if TORCH_HAS_FP8:
        q = q.to(torch_dtype)
        k = k.to(torch_dtype)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()

    q = q.to(name_to_torch_types[dtype])
    k = k.to(name_to_torch_types[dtype])
    sm_scale = 0.5
    dout = torch.randn_like(q, dtype=torch.float16)
    # reference implementation
@@ -674,7 +681,9 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
    # triton implementation
    tri_out = attention(q, k, v, causal, sm_scale)
    # compare
    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2)
    atol = 1.4e-1 if dtype == 'fp8' else 1e-2
    rtol = 1e-2 if dtype == 'fp8' else 0
    torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=rtol)
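Note (not part of the diff): E5M2 fp8 carries only 2 mantissa bits, so inputs quantized to it differ from their fp16 originals by up to roughly 1e-1 in relative terms, which is why the fp8 cases get a looser atol/rtol than fp16. A minimal sketch estimating that round-trip error, assuming float8_e5m2fnuz is available in the installed PyTorch:

import torch

if hasattr(torch, 'float8_e5m2fnuz'):
    x = (0.5 * torch.randn(1 << 20)).to(torch.float16)
    x_rt = x.to(torch.float8_e5m2fnuz).to(torch.float16)  # quantize to fp8, cast back
    rel_err = ((x_rt - x).abs() / x.abs().clamp_min(1e-3)).max()
    print(f"max relative round-trip error: {rel_err.item():.3f}")  # on the order of 1e-1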


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
@@ -775,9 +784,6 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
    q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    if mode == "fwd" and TORCH_HAS_FP8:
        q = q.to(torch_dtype)
        k = k.to(torch_dtype)
    sm_scale = D_HEAD ** -0.5
    fn = lambda: attention(q, k, v, causal, sm_scale)
    if mode == 'bwd':
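Note (not part of the diff): with the fp8 downcast removed here, the forward benchmark now measures the fp16 path only. A minimal sketch of how one such configuration could be timed and reported in TFLOP/s, assuming `attention` (the wrapper defined in this tutorial) is in scope and that triton.testing.do_bench returns a time in milliseconds (its exact return convention varies across Triton versions):

import torch
import triton

# assumes `attention` from this tutorial is in scope (e.g. run inside the script)
BATCH, H, N_CTX, D_HEAD, causal = 4, 48, 4096, 64, True
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda", requires_grad=True)
sm_scale = D_HEAD ** -0.5
fn = lambda: attention(q, k, v, causal, sm_scale)

ms = triton.testing.do_bench(fn, warmup=25, rep=100)
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 2 * flops_per_matmul   # one QK^T matmul and one PV matmul
if causal:
    total_flops *= 0.5               # causal masking skips roughly half the work
print(total_flops / (ms * 1e-3) * 1e-12, "TFLOP/s")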