Skip to content

[rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions #2100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions test/inductor/test_aot_inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
from torch.export import Dim, export, export_for_training
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import SM80OrLater, SM90OrLater
from torch.testing._internal.common_cuda import (
SM80OrLater,
SM90OrLater,
PLATFORM_SUPPORTS_FLASH_ATTENTION
)
from torch.testing._internal.common_device_type import (
_has_sufficient_memory,
skipCUDAIf,
Expand All @@ -44,10 +48,8 @@
IS_MACOS,
IS_WINDOWS,
skipIfRocm,
skipIfRocmArch,
skipIfXpu,
TEST_WITH_ROCM,
NAVI32_ARCH,
)
from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
Expand Down Expand Up @@ -927,10 +929,9 @@ def forward(self, q, k, v):
)
self.check_model(Model(), example_inputs)

# Eager mode produces incorrect tensor values for navi32 during this test
@skipIfRocmArch(NAVI32_ARCH)
@unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
@unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this test fail when using math backend on CUDA as well?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I observed that if we force math backend in MI300 and even at A100 for exactly this test, it fails.
see: https://github.com/ROCm/frameworks-internal/issues/11952
there you can also find an upstream pytorch issue I filed..

But the logic here is clear:

  • If FA (goes into AOTriton) is not supported, then code goes into math backend and fails and there's nothing we can do about it (until upstream fixes math backend) -> must skip the test
  • If FA is enabled (this depends on the architecture and/or if TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL is set, see:
    def evaluate_platform_supports_flash_attention():

def test_sdpa_2(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
Expand Down Expand Up @@ -1037,9 +1038,8 @@ def forward(self, x, y):
)
self.check_model(Repro(), example_inputs)

@skipIfRocmArch(NAVI32_ARCH)
# SDPA is not supported on navi32 arch
@skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
def test_fallback_kernel_with_symexpr_output(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
Expand Down Expand Up @@ -3031,8 +3031,7 @@ def grid(meta):
dynamic_shapes=dynamic_shapes,
)

@skipIfRocmArch(NAVI32_ARCH)
# SDPA is not supported on navi32 arch
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
def test_scaled_dot_product_efficient_attention(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
Expand Down
10 changes: 0 additions & 10 deletions torch/testing/_internal/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,16 +1361,6 @@ def printErrors(self) -> None:
IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
IS_S390X = platform.machine() == "s390x"

NAVI32_ARCH = "gfx1101"

def is_navi_arch():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you have checked nobody is using this function.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it was introduced by me in previous cherry-pick

if torch.cuda.is_available():
prop = torch.cuda.get_device_properties(0)
gfx_arch = prop.gcnArchName.split(":")[0]
if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
return True
return False

def is_avx512_vnni_supported():
if sys.platform != 'linux':
return False
Expand Down