Skip to content

Commit dc07657

Browse files
iupaikov-amdAmdSampsa
authored andcommitted
[rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions (#2100)
It's an internal change, removes unnecessary arch mentions and generalizes conditions. Needs to be cherry-picked into rocm6.5_internal_testing and release/2.7 .and now also into release/2.6
1 parent 684f6f2 commit dc07657

File tree

2 files changed

+12
-20
lines changed

2 files changed

+12
-20
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,15 @@
2828
from torch.export import Dim, export, export_for_training
2929
from torch.testing import FileCheck
3030
from torch.testing._internal import common_utils
31-
from torch.testing._internal.common_cuda import SM80OrLater, SM90OrLater
32-
from torch.testing._internal.common_device_type import skipCUDAIf
31+
from torch.testing._internal.common_cuda import (
32+
SM80OrLater,
33+
SM90OrLater,
34+
PLATFORM_SUPPORTS_FLASH_ATTENTION
35+
)
36+
from torch.testing._internal.common_device_type import (
37+
_has_sufficient_memory,
38+
skipCUDAIf,
39+
)
3340
from torch.testing._internal.common_quantization import (
3441
skip_if_no_torchvision,
3542
skipIfNoFBGEMM,
@@ -41,10 +48,8 @@
4148
IS_MACOS,
4249
IS_WINDOWS,
4350
skipIfRocm,
44-
skipIfRocmArch,
4551
skipIfXpu,
4652
TEST_WITH_ROCM,
47-
NAVI32_ARCH,
4853
)
4954
from torch.testing._internal.inductor_utils import GPU_TYPE
5055
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
@@ -924,10 +929,9 @@ def forward(self, q, k, v):
924929
)
925930
self.check_model(Model(), example_inputs)
926931

927-
# Eager mode produces incorrect tensor values for navi32 during this test
928-
@skipIfRocmArch(NAVI32_ARCH)
929932
@unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
930933
@unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
934+
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
931935
def test_sdpa_2(self):
932936
class Model(torch.nn.Module):
933937
def __init__(self) -> None:
@@ -1034,9 +1038,8 @@ def forward(self, x, y):
10341038
)
10351039
self.check_model(Repro(), example_inputs)
10361040

1037-
@skipIfRocmArch(NAVI32_ARCH)
1038-
# SDPA is not supported on navi32 arch
10391041
@skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
1042+
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
10401043
def test_fallback_kernel_with_symexpr_output(self):
10411044
if self.device != GPU_TYPE:
10421045
raise unittest.SkipTest("requires GPU")
@@ -3031,8 +3034,7 @@ def grid(meta):
30313034
dynamic_shapes=dynamic_shapes,
30323035
)
30333036

3034-
@skipIfRocmArch(NAVI32_ARCH)
3035-
# SDPA is not supported on navi32 arch
3037+
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
30363038
def test_scaled_dot_product_efficient_attention(self):
30373039
if self.device != GPU_TYPE:
30383040
raise unittest.SkipTest("requires GPU")

torch/testing/_internal/common_utils.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,16 +1370,6 @@ def printErrors(self) -> None:
13701370
IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
13711371
IS_S390X = platform.machine() == "s390x"
13721372

1373-
NAVI32_ARCH = "gfx1101"
1374-
1375-
def is_navi_arch():
1376-
if torch.cuda.is_available():
1377-
prop = torch.cuda.get_device_properties(0)
1378-
gfx_arch = prop.gcnArchName.split(":")[0]
1379-
if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
1380-
return True
1381-
return False
1382-
13831373
def is_avx512_vnni_supported():
13841374
if sys.platform != 'linux':
13851375
return False

0 commit comments

Comments
 (0)