Skip to content

Commit

Permalink
added skipping tests for all frameworks except jax
Browse files Browse the repository at this point in the history
  • Loading branch information
hazemessamm committed Oct 7, 2024
1 parent 43938ca commit 738375f
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion keras/src/ops/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2208,9 +2208,12 @@ def test_psnr(self):
bias=(None, True),
scale=(None, 1.0),
mask_and_is_causal=((None, False), (True, False), (None, True)),
flash_attention=(True, False),
)
)
def test_dot_product_attention(self, bias, scale, mask_and_is_causal):
def test_dot_product_attention(
self, bias, scale, mask_and_is_causal, flash_attention
):
mask, is_causal = mask_and_is_causal
query_shape = (2, 3, 4, 5)
key_shape = (2, 6, 4, 5)
Expand All @@ -2232,6 +2235,10 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal):
mask_shape
)

if flash_attention and backend.backend() in ["torch", "tensorflow", "numpy"]:
self.skipTest("Not supported in TF and NumPy and supported for "
"PyTorch with specific requirements.")

expected = _dot_product_attention(
query,
key,
Expand All @@ -2249,6 +2256,7 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal):
mask=mask,
scale=scale,
is_causal=is_causal,
flash_attention=flash_attention,
)
self.assertAllClose(outputs, expected)

Expand Down

0 comments on commit 738375f

Please sign in to comment.