Commit 46481c2

dtype params

micmelesse committed Jul 16, 2024
1 parent 3e1d75c

Showing 2 changed files with 5 additions and 3 deletions.
python/perf-kernels/flash-attention.py (7 changes: 4 additions & 3 deletions; file mode 100644 → 100755)
@@ -1166,15 +1166,16 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout,
 ])
 @pytest.mark.parametrize('causal', [True, False])
 @pytest.mark.parametrize('use_bias', [True])
-def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16):
+@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
+def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
     torch.manual_seed(20)
     sm_scale = D_HEAD**-0.5
     input_metadata = MetaData(sm_scale=sm_scale)
     q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout='bhsd')
     if causal:
         input_metadata.need_causal()
     if use_bias:
-        bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda")
+        bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda")
         input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K)
     else:
         bias = None
@@ -1197,7 +1198,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor
     # this by converting the NaNs to 0s, which is what they should be out of the softmax.
     nan_mask = torch.isnan(p)
     p[nan_mask == 1] = 0
-    ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v)
+    ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v)
     # compare
     torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)

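The hunks above parametrize the existing bias test over torch.float16 and torch.bfloat16: the dtype now comes from pytest instead of a hard-coded default, the bias tensor is created in that dtype, and the reference output is cast with p.to(dtype) rather than p.half(). Below is a minimal, self-contained sketch of the same pattern, not the repo's test: the test name, tensor shapes, and scale factor are invented for illustration, and it assumes a CUDA device is available, as the repo's tests do.

import pytest
import torch


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_softmax_matmul_dtype(dtype):
    # Hypothetical toy test, not part of flash-attention.py.
    torch.manual_seed(20)
    q = torch.randn((2, 4, 64, 16), dtype=dtype, device="cuda")
    k = torch.randn((2, 4, 64, 16), dtype=dtype, device="cuda")
    v = torch.randn((2, 4, 64, 16), dtype=dtype, device="cuda")
    sm_scale = 16**-0.5

    # Low-precision path: scores, softmax, and the final matmul all stay in `dtype`.
    p = torch.softmax(torch.einsum('bhqd,bhkd->bhqk', q, k) * sm_scale, dim=-1)
    out = torch.einsum('bhqk,bhkd->bhqd', p, v)

    # Reference path: compute in float32, then cast to the test dtype before comparing,
    # mirroring the p.to(dtype) change above.
    p_ref = torch.softmax(torch.einsum('bhqd,bhkd->bhqk', q.float(), k.float()) * sm_scale, dim=-1)
    ref = torch.einsum('bhqk,bhkd->bhqd', p_ref, v.float()).to(dtype)

    torch.testing.assert_close(ref, out, atol=2e-2, rtol=2e-2)

Creating the bias in the test dtype presumably keeps it consistent with q, k, and v for both float16 and bfloat16 runs, while casting the float32 reference only at the end compares the kernel output in its own precision without accumulating rounding error in the reference path.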
scripts (1 change: 1 addition & 0 deletions)
Submodule scripts added at 963bb7
