
Commit 03c4453

[release/2.6][SWDEV-523736] Fix some unittests for Navi4x
* Most testcases work properly on Navi48 (gfx1201) with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1, so this commit enables that variable for this arch (a simplified sketch of the gating follows below). There is currently no AOTriton support for Navi44 (gfx1200), so those testcases are skipped.
* test_qconv2d_int8_mixed_bf16 is skipped because it was originally skipped in pytorch#112550 but the skip was later lost.
* test_sac_ilp_case1 is skipped as per SWDEV-509011.
* test_distributed_checkpoint_state_dict_type[0-1]_cuda: fixed a bug with arguments.
1 parent eb37e58 commit 03c4453
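
For context, here is a minimal, self-contained sketch of the per-arch gating this commit relies on. It is not the actual PyTorch helper; names and control flow are simplified from the common_cuda.py diff further down.

# Sketch only: shows how ROCm arch detection drives the AOTriton experimental
# flag for Navi48 versus the skip on Navi44. The real logic lives in
# torch/testing/_internal/common_cuda.py (see the diff below).
import os
import torch

def rocm_gcn_arch() -> str:
    # Returns e.g. "gfx1201:sramecc+:xnack-" on ROCm builds, "" elsewhere.
    if torch.version.hip is None or not torch.cuda.is_available():
        return ""
    return torch.cuda.get_device_properties(0).gcnArchName

arch = rocm_gcn_arch()
if arch.startswith("gfx1201"):
    # Navi48: AOTriton-backed SDPA works when the experimental flag is set.
    os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"
elif arch.startswith("gfx1200"):
    # Navi44: no AOTriton support yet, so SDPA-dependent tests are skipped
    # (e.g. via skipIfRocmArch(NAVI44_ARCH)).
    pass

The actual change sets the variable inside evaluate_gfx_arch_within() and extends the flash-attention arch list, as shown in the common_cuda.py diff.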

File tree

11 files changed: +51, -15 lines


test/distributed/_tools/test_sac_ilp.py

Lines changed: 9 additions & 1 deletion

@@ -19,7 +19,14 @@
     sac_milp,
 )
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
+from torch.testing._internal.common_utils import (
+    run_tests,
+    skipIfTorchDynamo,
+    TestCase,
+    skipIfRocmArch,
+    NAVI4_ARCH,
+)
+
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
     Transformer,
@@ -131,6 +138,7 @@ def _collect_module_info_with_fake_tensor_mode(self) -> ModuleInfo:
 
     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @skipIfRocmArch(NAVI4_ARCH)
     def test_sac_ilp_case1(self):
         """
         This is a case where the memory budget is either binding or too tight,

test/distributed/fsdp/test_distributed_checkpoint.py

Lines changed: 9 additions & 9 deletions

@@ -30,22 +30,22 @@
     )
     sys.exit(0)
 
-
-_DISTRIBUTED_STATE_DICT_IMPLS = {
-    StateDictType.LOCAL_STATE_DICT,
-    StateDictType.SHARDED_STATE_DICT,
-}
-
-
 class TestDistributedCheckpoint(FSDPTest):
     @property
     def world_size(self):
         return 2
 
     @skip_if_lt_x_gpu(2)
     @with_temp_dir
-    @parametrize("state_dict_type", _DISTRIBUTED_STATE_DICT_IMPLS)
-    def test_distributed_checkpoint(self, state_dict_type) -> None:
+    def test_distributed_checkpoint_state_dict_type0(self) -> None:
+        self._test_distributed_checkpoint(StateDictType.LOCAL_STATE_DICT)
+
+    @skip_if_lt_x_gpu(2)
+    @with_temp_dir
+    def test_distributed_checkpoint_state_dict_type1(self) -> None:
+        self._test_distributed_checkpoint(StateDictType.SHARDED_STATE_DICT)
+
+    def _test_distributed_checkpoint(self, state_dict_type) -> None:
         with enable_wrap(wrapper_cls=FSDP):
             torch.manual_seed(100)
             model = wrap(SkipModel(double_nest=True))

test/dynamo/test_activation_checkpointing.py

Lines changed: 2 additions & 0 deletions

@@ -20,6 +20,7 @@
 from torch._higher_order_ops.wrap import tag_activation_checkpoint
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
     SM90OrLater,
 )
 from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm
@@ -1279,6 +1280,7 @@ def fn(x, ys):
         self.assertEqual(ref, res)
 
     @requires_cuda
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_pattern_matcher(self):
         # Check that the sdpa op is recomputed in the backward graph
         # tests percolate_tags

test/dynamo/test_repros.py

Lines changed: 3 additions & 0 deletions

@@ -47,6 +47,8 @@
     parametrize,
     skipIfWindows,
     TEST_WITH_ROCM,
+    skipIfRocmArch,
+    NAVI44_ARCH,
 )
 from torch.testing._internal.two_tensor import TwoTensor
 
@@ -6408,6 +6410,7 @@ def fn(x):
         self.assertEqual(fn(inp), opt_fn(inp))
 
     @requires_cuda
+    @skipIfRocmArch(NAVI44_ARCH)
     def test_sdpa_dynamic_shapes(self):
         def f(x, s0, s1, s2):
             q = x.view(2, s0, s2, s0)

test/higher_order_ops/test_invoke_subgraph.py

Lines changed: 2 additions & 0 deletions

@@ -17,6 +17,7 @@
     TEST_WITH_CROSSREF,
     TestCase,
 )
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.inductor_utils import HAS_CUDA
 
 
@@ -167,6 +168,7 @@ def fn(x):
         self.assertEqual(x.grad, x_clone.grad)
 
     @requires_cuda
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_sdpa(self):
         @mark_compile_region
         def gn(q, k, v):

test/inductor/test_flex_attention.py

Lines changed: 6 additions & 1 deletion

@@ -34,7 +34,11 @@
 )
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, TEST_MULTIGPU
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_BF16,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    TEST_MULTIGPU,
+)
 from torch.testing._internal.common_device_type import (
     flex_attention_supported_platform as supported_platform,
 )
@@ -2610,6 +2614,7 @@ def test_kernel_options_argument_is_respected(self):
         FileCheck().check("BLOCK_M : tl.constexpr = 16").run(code[0])
 
     @supported_platform
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_comparison_vs_sdpa(self):
         def causal(score, b, h, q_idx, kv_idx):
             return torch.where(q_idx >= kv_idx, score, -float("inf"))

test/inductor/test_flex_decoding.py

Lines changed: 10 additions & 2 deletions

@@ -4,7 +4,7 @@
 import functools
 from collections import namedtuple
 from typing import Callable, Optional, Tuple, Union
-from unittest import expectedFailure, skipUnless
+from unittest import expectedFailure, skipUnless, skipIf
 from unittest.mock import patch
 
 import torch
@@ -21,7 +21,10 @@
 )
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_BF16,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+)
 from torch.testing._internal.common_utils import skipIfRocm
 from torch.utils._triton import has_triton
 
@@ -1342,6 +1345,7 @@ def test_windowed_no_mask_vs_sdpa(self):
         self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)
 
     @supported_platform
+    @skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_windowed_full_mask_vs_sdpa(self):
         def mask_mod(b, h, q, kv):
             return q + 1000 >= kv
@@ -1361,6 +1365,7 @@ def mask_mod(b, h, q, kv):
         self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)
 
     @supported_platform
+    @skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_windowed_partial_block_vs_sdpa(self):
         def mask_mod(b, h, q, kv):
             return q + 1000 >= kv
@@ -1376,6 +1381,7 @@ def mask_mod(b, h, q, kv):
         self.run_test_with_call(attention, sdpa_attention, Q_H=16, KV_H=16, Q_S=8)
 
     @supported_platform
+    @skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_windowed_no_mask_vs_sdpa_paged_attention(self):
         score_mod = _generate_windowed(1000)
 
@@ -1386,6 +1392,7 @@ def test_windowed_no_mask_vs_sdpa_paged_attention(self):
         )
 
     @supported_platform
+    @skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_windowed_full_mask_vs_sdpa_paged_attention(self):
         def mask_mod(b, h, q, kv):
             return q + 1000 >= kv
@@ -1397,6 +1404,7 @@ def mask_mod(b, h, q, kv):
         )
 
     @supported_platform
+    @skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_windowed_partial_block_vs_sdpa_paged_attention(self):
         def mask_mod(b, h, q, kv):
             return q + 1000 >= kv

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 1 addition & 0 deletions

@@ -1035,6 +1035,7 @@ def test_qconv2d_xpu(self):
     @skipIfNoDynamoSupport
     @skipIfNoONEDNNBF16
     @skipIfNoONEDNN
+    @skipIfRocm
     def test_qconv2d_int8_mixed_bf16(self):
         r"""
         This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.

test/inductor/test_torchinductor.py

Lines changed: 1 addition & 0 deletions

@@ -10611,6 +10611,7 @@ def fn(q, k, v):
         )
 
     @expectedFailureXPU
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_scaled_dot_product_efficient_attention(self):
         if self.device == "cpu":
             raise unittest.SkipTest(f"requires {GPU_TYPE}")

torch/testing/_internal/common_cuda.py

Lines changed: 7 additions & 2 deletions

@@ -43,14 +43,19 @@ def evaluate_gfx_arch_within(arch_list):
     effective_arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
     # gcnArchName can be complicated strings like gfx90a:sramecc+:xnack-
     # Hence the matching should be done reversely
-    return any(arch in effective_arch for arch in arch_list)
+    result = any(arch in effective_arch for arch in arch_list)
+
+    if result and gcn_arch_name == "gfx1201":
+        os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
+
+    return result
 
 def CDNA2OrLater():
     return evaluate_gfx_arch_within(["gfx90a", "gfx942"])
 
 def evaluate_platform_supports_flash_attention():
     if TEST_WITH_ROCM:
-        arch_list = ["gfx90a", "gfx942", "gfx1100"]
+        arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201"]
         return evaluate_gfx_arch_within(arch_list)
     if TEST_CUDA:
         return not IS_WINDOWS and SM80OrLater

torch/testing/_internal/common_utils.py

Lines changed: 1 addition & 0 deletions

@@ -111,6 +111,7 @@
 NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101", "gfx1200", "gfx1201")
 NAVI3_ARCH = ("gfx1100", "gfx1101")
 NAVI4_ARCH = ("gfx1200", "gfx1201")
+NAVI44_ARCH = "gfx1200"
 
 def is_navi3_arch():
     if torch.cuda.is_available():
