diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py
index fe8d34fc656..114f997c44b 100644
--- a/keras/src/ops/nn_test.py
+++ b/keras/src/ops/nn_test.py
@@ -2208,9 +2208,12 @@ def test_psnr(self):
             bias=(None, True),
             scale=(None, 1.0),
             mask_and_is_causal=((None, False), (True, False), (None, True)),
+            flash_attention=(True, False),
         )
     )
-    def test_dot_product_attention(self, bias, scale, mask_and_is_causal):
+    def test_dot_product_attention(
+        self, bias, scale, mask_and_is_causal, flash_attention
+    ):
         mask, is_causal = mask_and_is_causal
         query_shape = (2, 3, 4, 5)
         key_shape = (2, 6, 4, 5)
@@ -2232,6 +2235,10 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal):
             mask_shape
         )
+        no_flash = backend.backend() in ["torch", "tensorflow", "numpy"]
+        if flash_attention and no_flash:
+            self.skipTest("Not supported in TF and NumPy and supported for "
+                "PyTorch with specific requirements.")
         expected = _dot_product_attention(
             query,
             key,
             value,
@@ -2249,6 +2256,7 @@ def test_dot_product_attention(self, bias, scale, mask_and_is_causal):
             bias=bias,
             mask=mask,
             scale=scale,
             is_causal=is_causal,
+            flash_attention=flash_attention,
         )
         self.assertAllClose(outputs, expected)