@@ -28,8 +28,15 @@
 from torch.export import Dim, export, export_for_training
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_cuda import SM80OrLater, SM90OrLater
-from torch.testing._internal.common_device_type import skipCUDAIf
+from torch.testing._internal.common_cuda import (
+    SM80OrLater,
+    SM90OrLater,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+)
+from torch.testing._internal.common_device_type import (
+    _has_sufficient_memory,
+    skipCUDAIf,
+)
 from torch.testing._internal.common_quantization import (
     skip_if_no_torchvision,
     skipIfNoFBGEMM,
@@ -41,10 +48,8 @@
     IS_MACOS,
     IS_WINDOWS,
     skipIfRocm,
-    skipIfRocmArch,
     skipIfXpu,
     TEST_WITH_ROCM,
-    NAVI32_ARCH,
 )
 from torch.testing._internal.inductor_utils import GPU_TYPE
 from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
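The import change replaces the per-arch deny list (skipIfRocmArch(NAVI32_ARCH)) with a capability gate, so any platform whose flash-attention SDPA backend is unavailable is skipped automatically rather than enumerated by name. A minimal standalone sketch of how the imported flag is consumed; the test name and tensor shapes below are illustrative, not part of this PR:

    import unittest

    import torch
    from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION


    class SDPASmokeTest(unittest.TestCase):
        # Skipped wherever the flash-attention SDPA backend is unavailable
        # (e.g. pre-SM80 CUDA, or ROCm archs such as navi32).
        @unittest.skipIf(
            not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
        )
        def test_sdpa_smoke(self):
            q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
            out = torch.nn.functional.scaled_dot_product_attention(q, q, q)
            self.assertEqual(out.shape, q.shape)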
@@ -924,10 +929,9 @@ def forward(self, q, k, v):
         )
         self.check_model(Model(), example_inputs)

-    # Eager mode produces incorrect tensor values for navi32 during this test
-    @skipIfRocmArch(NAVI32_ARCH)
     @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
     @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
@@ -1034,9 +1038,8 @@ def forward(self, x, y):
         )
         self.check_model(Repro(), example_inputs)

-    @skipIfRocmArch(NAVI32_ARCH)
-    # SDPA is not supported on navi32 arch
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
@@ -3031,8 +3034,7 @@ def grid(meta):
             dynamic_shapes=dynamic_shapes,
         )

-    @skipIfRocmArch(NAVI32_ARCH)
-    # SDPA is not supported on navi32 arch
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_scaled_dot_product_efficient_attention(self):
         if self.device != GPU_TYPE:
             raise unittest.SkipTest("requires GPU")
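The same one-line swap is applied at every affected call site; schematically (identifiers as in the hunks above):

    # Before: deny-list a single ROCm architecture.
    @skipIfRocmArch(NAVI32_ARCH)
    def test_scaled_dot_product_efficient_attention(self): ...

    # After: query the platform capability, covering every arch without
    # flash-attention support, on ROCm and CUDA alike.
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
    def test_scaled_dot_product_efficient_attention(self): ...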