diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py
old mode 100644
new mode 100755
index d36caaf61952..ece254119d80
--- a/python/perf-kernels/flash-attention.py
+++ b/python/perf-kernels/flash-attention.py
@@ -1166,7 +1166,8 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout,
 ])
 @pytest.mark.parametrize('causal', [True, False])
 @pytest.mark.parametrize('use_bias', [True])
-def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16):
+@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
+def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
     torch.manual_seed(20)
     sm_scale = D_HEAD**-0.5
     input_metadata = MetaData(sm_scale=sm_scale)
@@ -1174,7 +1175,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor
     if causal:
         input_metadata.need_causal()
     if use_bias:
-        bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda")
+        bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda")
         input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K)
     else:
         bias = None
@@ -1197,7 +1198,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor
     # this by converting the NaNs to 0s, which is what they should be out of the softmax.
     nan_mask = torch.isnan(p)
     p[nan_mask == 1] = 0
-    ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v)
+    ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v)
     # compare
     torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)

diff --git a/scripts b/scripts
new file mode 160000
index 000000000000..963bb7524376
--- /dev/null
+++ b/scripts
@@ -0,0 +1 @@
+Subproject commit 963bb752437693d29d7f62f4ef8acbed8167f71a
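The flash-attention.py hunks above generalize the bias test from a hard-coded torch.float16 to a pytest-parametrized dtype, building both the bias and the reference output in that dtype (p.to(dtype) instead of p.half()). The snippet below is a minimal, standalone sketch of that pattern, not part of the repository's test file: the test name, tensor shapes, and skip marker are illustrative assumptions; the tolerances mirror the 2e-2 values used in the diff.

# Minimal sketch (hypothetical test, not flash-attention.py) of dtype parametrization:
# pytest injects each dtype, the reference contraction is cast with .to(dtype),
# and torch.testing.assert_close compares against a float32 reference path.
import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_reference_matmul_dtype(dtype):  # hypothetical name, chosen for this sketch
    torch.manual_seed(20)
    # softmax probabilities computed in float32, values in the parametrized dtype
    logits = torch.randn((1, 2, 16, 16), dtype=torch.float32, device="cuda")
    p = torch.softmax(logits, dim=-1)
    v = torch.randn((1, 2, 16, 32), dtype=dtype, device="cuda")
    # cast p to the test dtype before the contraction, mirroring p.half() -> p.to(dtype)
    ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v)
    assert ref_out.dtype == dtype
    # a float32 reference path, rounded once at the end; the loose tolerances
    # absorb the fp16/bf16 rounding introduced by casting p before the einsum
    ref32 = torch.einsum('bhqk,bhkd->bhqd', p, v.to(torch.float32)).to(dtype)
    torch.testing.assert_close(ref_out, ref32, atol=2e-2, rtol=2e-2)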