diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index a94d7f272df5..ae3ddadf3f9c 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -28,7 +28,11 @@
 from torch.export import Dim, export, export_for_training
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_cuda import SM80OrLater, SM90OrLater
+from torch.testing._internal.common_cuda import (
+    SM80OrLater,
+    SM90OrLater,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION
+)
 from torch.testing._internal.common_device_type import (
     _has_sufficient_memory,
     skipCUDAIf,
@@ -44,10 +48,8 @@
     IS_MACOS,
     IS_WINDOWS,
     skipIfRocm,
-    skipIfRocmArch,
     skipIfXpu,
     TEST_WITH_ROCM,
-    NAVI32_ARCH,
 )
 from torch.testing._internal.inductor_utils import GPU_TYPE
 from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
@@ -927,10 +929,9 @@ def forward(self, q, k, v):
         )
         self.check_model(Model(), example_inputs)
 
-    # Eager mode produces incorrect tensor values for navi32 during this test
-    @skipIfRocmArch(NAVI32_ARCH)
     @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
     @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
@@ -1037,9 +1038,8 @@ def forward(self, x, y):
         )
         self.check_model(Repro(), example_inputs)
 
-    @skipIfRocmArch(NAVI32_ARCH)
-    # SDPA is not supported on navi32 arch
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
@@ -3031,8 +3031,7 @@ def grid(meta):
             dynamic_shapes=dynamic_shapes,
         )
 
-    @skipIfRocmArch(NAVI32_ARCH)
-    # SDPA is not supported on navi32 arch
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_scaled_dot_product_efficient_attention(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 20edb817a4b6..c248411cf8e1 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1361,16 +1361,6 @@ def printErrors(self) -> None:
 IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
 IS_S390X = platform.machine() == "s390x"
 
-NAVI32_ARCH = "gfx1101"
-
-def is_navi_arch():
-    if torch.cuda.is_available():
-        prop = torch.cuda.get_device_properties(0)
-        gfx_arch = prop.gcnArchName.split(":")[0]
-        if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
-            return True
-    return False
-
 def is_avx512_vnni_supported():
     if sys.platform != 'linux':
         return False