diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 449e89f93e29..6ad01cbadc83 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -17,16 +17,16 @@
 import triton
 import triton.language as tl
 
-torch_dtype:tl.constexpr = torch.float16
-TORCH_HAS_FP8 = False
-TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2')
-TORCH_HAS_FP8E5FNUZ = hasattr(torch, 'float8_e5m2fnuz')
-if TORCH_HAS_FP8E5:
-    torch_dtype:tl.constexpr = torch.float8_e5m2
-    TORCH_HAS_FP8 = True
-if TORCH_HAS_FP8E5FNUZ:
-    torch_dtype:tl.constexpr = torch.float8_e5m2fnuz
-    TORCH_HAS_FP8 = True
+# Pick the fp8 data type
+
+# AMD E4M3B8
+# Note: When picking this f8 data type, scaling is required when using f8
+# for the second gemm
+#TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')
+
+# AMD E5M2B16
+TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz')
+
 
 @triton.jit
 def _attn_fwd_inner(acc, l_i, m_i, q,
@@ -555,7 +555,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
         assert Lk in {16, 32, 64, 128}
-        o = torch.empty_like(q)
+        o = torch.empty_like(q, dtype=v.dtype)
         if torch.version.hip is None:
             BLOCK_M = 128
             BLOCK_N = 64 if Lk <= 64 else 32
@@ -642,26 +642,33 @@ def backward(ctx, do):
 
 attention = _attention.apply
 
-
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
-                         [(4, 48, 1024, 64),
-                          (4, 48, 2048, 64),
-                          (4, 48, 4096, 64),
-                          (4, 48, 1024, 128),
-                          (4, 48, 2048, 128),
-                          (4, 48, 4096, 128),
-                          #(4, 48, 8192, 64),
-                          #(4, 48, 16384, 64)
-                          ])
+name_to_torch_types = {
+    'fp16': torch.float16,
+}
+
+if TORCH_HAS_FP8E5B16:
+    name_to_torch_types['fp8'] = torch.float8_e5m2fnuz
+
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, dtype',
+[ (*shape, dtype)
+    for shape in [(4, 48, 1024, 64),
+                  (4, 48, 2048, 64),
+                  (4, 48, 4096, 64),
+                  (4, 48, 1024, 128),
+                  (4, 48, 2048, 128),
+                  (4, 48, 4096, 128)]
+    for dtype in ['fp16', 'fp8']])
 @pytest.mark.parametrize('causal', [False, True])
-def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
+def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype):
+    if dtype == 'fp8' and not TORCH_HAS_FP8E5B16:
+        pytest.skip("fp8 not supported")
     torch.manual_seed(20)
-    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
-    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
-    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
-    if TORCH_HAS_FP8:
-        q = q.to(torch_dtype)
-        k = k.to(torch_dtype)
+    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
+    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
+    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
+
+    q = q.to(name_to_torch_types[dtype])
+    k = k.to(name_to_torch_types[dtype])
     sm_scale = 0.5
     dout = torch.randn_like(q, dtype=torch.float16)
     # reference implementation
@@ -674,7 +681,9 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
     # triton implementation
     tri_out = attention(q, k, v, causal, sm_scale)
     # compare
-    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2)
+    atol = 1.4e-1 if dtype == 'fp8' else 1e-2
+    rtol = 1e-2 if dtype == 'fp8' else 0
+    torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=rtol)
 
 
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
@@ -775,9 +784,6 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
     q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
     k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
     v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-    if mode == "fwd" and TORCH_HAS_FP8:
-        q = q.to(torch_dtype)
-        k = k.to(torch_dtype)
     sm_scale = D_HEAD ** -0.5
     fn = lambda: attention(q, k, v, causal, sm_scale)
     if mode == 'bwd':
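
For reference, the sketch below exercises the new fp8 forward path by hand, mirroring the updated test_op_fwd. It is an illustrative sketch rather than part of the patch: it assumes a ROCm build of PyTorch that exposes torch.float8_e5m2fnuz and that attention (the _attention.apply wrapper from this tutorial) is already in scope, e.g. in an interactive session after running the tutorial file.

import torch

Z, H, N_CTX, D_HEAD = 4, 48, 1024, 64
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)

# Only q and k are cast to fp8; v stays fp16, which is why forward() now
# allocates the output as torch.empty_like(q, dtype=v.dtype): the result
# comes back as fp16 rather than fp8.
if hasattr(torch, 'float8_e5m2fnuz'):
    q = q.to(torch.float8_e5m2fnuz)
    k = k.to(torch.float8_e5m2fnuz)

sm_scale = 0.5
out = attention(q, k, v, False, sm_scale)  # causal=False
print(out.dtype, out.shape)  # torch.float16 torch.Size([4, 48, 1024, 64])

Against a float16 reference, the updated test compares this output with atol=1.4e-1 for fp8 inputs, reflecting the coarser e5m2fnuz quantization of q and k.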