From 0af8d0c5bbab8161be1573b84f2064b8bad52bc8 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Thu, 12 Dec 2024 13:38:43 +0000 Subject: [PATCH] [sharktank] Mark test expected to fail Marks `testExportNondecomposed` as expected to fail if running with torch>=2.4.0, see #684. --- sharktank/tests/layers/paged_llama_attention_block_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sharktank/tests/layers/paged_llama_attention_block_test.py b/sharktank/tests/layers/paged_llama_attention_block_test.py index 63251c5a9..d5cb6863d 100644 --- a/sharktank/tests/layers/paged_llama_attention_block_test.py +++ b/sharktank/tests/layers/paged_llama_attention_block_test.py @@ -4,6 +4,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import pytest + import logging logging.basicConfig(level=logging.DEBUG) @@ -118,6 +120,10 @@ def forward(self, h, seq_block_ids, cache_state): asm = str(output.mlir_module) self.assertNotIn("scaled_dot_product_attention", asm) + @pytest.mark.xfail( + torch.__version__ >= (2, 4), + reason="https://github.com/nod-ai/shark-ai/issues/684", + ) def testExportNondecomposed(self): dtype = torch.float32