From 5f76c574bad8ace0ede4c278b40de4397c147077 Mon Sep 17 00:00:00 2001 From: iupaikov-amd Date: Mon, 12 May 2025 21:19:07 +0200 Subject: [PATCH 1/4] Cherry-picked commit with merge conflict --- test/inductor/test_aot_inductor.py | 95 +++++++++++++++++++++++++ torch/testing/_internal/common_utils.py | 4 ++ 2 files changed, 99 insertions(+) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 60cffb05d63e9f..6e88b816c7c191 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -24,7 +24,19 @@ from torch.export import Dim, export from torch.testing import FileCheck from torch.testing._internal import common_utils +<<<<<<< HEAD from torch.testing._internal.common_cuda import SM80OrLater, SM90OrLater +======= +from torch.testing._internal.common_cuda import ( + SM80OrLater, + SM90OrLater, + PLATFORM_SUPPORTS_FLASH_ATTENTION +) +from torch.testing._internal.common_device_type import ( + _has_sufficient_memory, + skipCUDAIf, +) +>>>>>>> 4e4e3395e6 ([rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions (#2100)) from torch.testing._internal.common_quantization import ( skip_if_no_torchvision, skipIfNoFBGEMM, @@ -38,6 +50,10 @@ IS_SANDCASTLE, IS_WINDOWS, skipIfRocm, +<<<<<<< HEAD +======= + skipIfXpu, +>>>>>>> 4e4e3395e6 ([rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions (#2100)) TEST_WITH_ROCM, ) from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda @@ -968,6 +984,7 @@ def forward(self, q, k, v): @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode") @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+") + @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA") def test_sdpa_2(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -1055,6 +1072,80 @@ def forward(self, x, y): ) self.check_model(Repro(), example_inputs) +<<<<<<< HEAD +======= + @config.patch({"triton.autotune_at_compile_time": None}) + def test_stride_with_unbacked_expr(self): + class Repro(torch.nn.Module): + def forward(self, x, y): + u0 = x.item() + torch._check(u0 >= 1) + s0 = y.size(0) + expr = u0 * s0 + sevens = torch.empty_strided( + size=(10, expr, 32), stride=(expr * 32, 32, 1), device=x.device + ).fill_(7) + return sevens * 3 + + example_inputs = ( + torch.scalar_tensor(2, dtype=torch.int, device=self.device), + torch.ones(8, device=self.device), + ) + self.check_model(Repro(), example_inputs) + + @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet") + @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA") + def test_fallback_kernel_with_symexpr_output(self): + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + + class Module(torch.nn.Module): + def forward(self, q, k, v): + q = q.reshape( + q.shape[0], + 2, + q.shape[2] * q.shape[3], + q.shape[1] // 2, + ) + k = k.reshape( + k.shape[0], + 2, + k.shape[2] * k.shape[3], + k.shape[1] // 2, + ) + v = v.reshape( + v.shape[0], + 2, + v.shape[2] * v.shape[3], + v.shape[1] // 2, + ) + + res = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, + k, + v, + ) + return res[0] + + m = Module().to(device=self.device) + tensor_shape = (4, 32, 4, 4) + inputs = ( + torch.randn(tensor_shape, dtype=torch.float16, device=self.device), + torch.randn(tensor_shape, dtype=torch.float16, device=self.device), + torch.randn(tensor_shape, dtype=torch.float16, device=self.device), + ) + + 
dynamic_shapes = { + "q": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC}, + "k": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC}, + "v": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC}, + } + ep = torch.export.export(m, inputs, dynamic_shapes=dynamic_shapes, strict=False) + path = torch._inductor.aot_compile(ep.module(), inputs) + aot_model = torch._export.aot_load(path, device=self.device) + torch.testing.assert_close(m(*inputs), aot_model(*inputs)) + +>>>>>>> 4e4e3395e6 ([rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions (#2100)) def test_large_grid(self): if self.device != "cuda": raise unittest.SkipTest("requires CUDA") @@ -2838,7 +2929,11 @@ def grid(meta): dynamic_shapes=dynamic_shapes, ) +<<<<<<< HEAD @skipIfRocm # USE_MEM_EFF_ATTENTION was not enabled for build. +======= + @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA") +>>>>>>> 4e4e3395e6 ([rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions (#2100)) def test_scaled_dot_product_efficient_attention(self): if self.device != "cuda": raise unittest.SkipTest("requires CUDA") diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 8ef0747f2742a5..e74f24615e0910 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1288,6 +1288,10 @@ def printErrors(self) -> None: IS_PPC = platform.machine() == "ppc64le" IS_X86 = platform.machine() in ('x86_64', 'i386') IS_ARM64 = platform.machine() in ('arm64', 'aarch64') +<<<<<<< HEAD +======= +IS_S390X = platform.machine() == "s390x" +>>>>>>> 4e4e3395e6 ([rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions (#2100)) def is_avx512_vnni_supported(): if sys.platform != 'linux': From 21dfa30fcb9708cf2a81a748d3c11031e8756147 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Tue, 10 Jun 2025 13:59:09 +0100 Subject: [PATCH 2/4] Resolve conflicts --- test/inductor/test_aot_inductor.py | 89 +----------------------------- 1 file changed, 1 insertion(+), 88 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 6e88b816c7c191..e9cd2435f18361 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -24,19 +24,11 @@ from torch.export import Dim, export from torch.testing import FileCheck from torch.testing._internal import common_utils -<<<<<<< HEAD -from torch.testing._internal.common_cuda import SM80OrLater, SM90OrLater -======= from torch.testing._internal.common_cuda import ( SM80OrLater, SM90OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION ) -from torch.testing._internal.common_device_type import ( - _has_sufficient_memory, - skipCUDAIf, -) ->>>>>>> 4e4e3395e6 ([rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions (#2100)) from torch.testing._internal.common_quantization import ( skip_if_no_torchvision, skipIfNoFBGEMM, @@ -50,10 +42,6 @@ IS_SANDCASTLE, IS_WINDOWS, skipIfRocm, -<<<<<<< HEAD -======= - skipIfXpu, ->>>>>>> 4e4e3395e6 ([rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions (#2100)) TEST_WITH_ROCM, ) from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda @@ -1072,80 +1060,7 @@ def forward(self, x, y): ) self.check_model(Repro(), example_inputs) -<<<<<<< HEAD -======= - @config.patch({"triton.autotune_at_compile_time": None}) - def test_stride_with_unbacked_expr(self): - class Repro(torch.nn.Module): - def forward(self, x, y): - u0 
= x.item() - torch._check(u0 >= 1) - s0 = y.size(0) - expr = u0 * s0 - sevens = torch.empty_strided( - size=(10, expr, 32), stride=(expr * 32, 32, 1), device=x.device - ).fill_(7) - return sevens * 3 - - example_inputs = ( - torch.scalar_tensor(2, dtype=torch.int, device=self.device), - torch.ones(8, device=self.device), - ) - self.check_model(Repro(), example_inputs) - - @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet") - @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA") - def test_fallback_kernel_with_symexpr_output(self): - if self.device != GPU_TYPE: - raise unittest.SkipTest("requires GPU") - - class Module(torch.nn.Module): - def forward(self, q, k, v): - q = q.reshape( - q.shape[0], - 2, - q.shape[2] * q.shape[3], - q.shape[1] // 2, - ) - k = k.reshape( - k.shape[0], - 2, - k.shape[2] * k.shape[3], - k.shape[1] // 2, - ) - v = v.reshape( - v.shape[0], - 2, - v.shape[2] * v.shape[3], - v.shape[1] // 2, - ) - - res = torch.ops.aten._scaled_dot_product_flash_attention.default( - q, - k, - v, - ) - return res[0] - - m = Module().to(device=self.device) - tensor_shape = (4, 32, 4, 4) - inputs = ( - torch.randn(tensor_shape, dtype=torch.float16, device=self.device), - torch.randn(tensor_shape, dtype=torch.float16, device=self.device), - torch.randn(tensor_shape, dtype=torch.float16, device=self.device), - ) - - dynamic_shapes = { - "q": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC}, - "k": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC}, - "v": {2: Dim.DYNAMIC, 3: Dim.DYNAMIC}, - } - ep = torch.export.export(m, inputs, dynamic_shapes=dynamic_shapes, strict=False) - path = torch._inductor.aot_compile(ep.module(), inputs) - aot_model = torch._export.aot_load(path, device=self.device) - torch.testing.assert_close(m(*inputs), aot_model(*inputs)) ->>>>>>> 4e4e3395e6 ([rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions (#2100)) def test_large_grid(self): if self.device != "cuda": raise unittest.SkipTest("requires CUDA") @@ -2929,11 +2844,9 @@ def grid(meta): dynamic_shapes=dynamic_shapes, ) -<<<<<<< HEAD + @skipIfRocm # USE_MEM_EFF_ATTENTION was not enabled for build. -======= @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA") ->>>>>>> 4e4e3395e6 ([rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions (#2100)) def test_scaled_dot_product_efficient_attention(self): if self.device != "cuda": raise unittest.SkipTest("requires CUDA") From e8fac90ee992ca9ee51f46a9618d1f0573cf6457 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Tue, 10 Jun 2025 14:00:34 +0100 Subject: [PATCH 3/4] Conflicts fix again --- test/inductor/test_aot_inductor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index e9cd2435f18361..5bace9a34263fe 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1060,7 +1060,6 @@ def forward(self, x, y): ) self.check_model(Repro(), example_inputs) - def test_large_grid(self): if self.device != "cuda": raise unittest.SkipTest("requires CUDA") @@ -2844,7 +2843,6 @@ def grid(meta): dynamic_shapes=dynamic_shapes, ) - @skipIfRocm # USE_MEM_EFF_ATTENTION was not enabled for build. 
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA") def test_scaled_dot_product_efficient_attention(self): From c95a29781e23d436cee829e1628c6e1a52e5097e Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Tue, 10 Jun 2025 14:01:23 +0100 Subject: [PATCH 4/4] Conflicts --- torch/testing/_internal/common_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index e74f24615e0910..8ef0747f2742a5 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1288,10 +1288,6 @@ def printErrors(self) -> None: IS_PPC = platform.machine() == "ppc64le" IS_X86 = platform.machine() in ('x86_64', 'i386') IS_ARM64 = platform.machine() in ('arm64', 'aarch64') -<<<<<<< HEAD -======= -IS_S390X = platform.machine() == "s390x" ->>>>>>> 4e4e3395e6 ([rocm6.4_internal_testing] Replaced ROCm specific skips to generalized conditions (#2100)) def is_avx512_vnni_supported(): if sys.platform != 'linux':
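
Note on the pattern this series converges on: SDPA-dependent tests in test_aot_inductor.py are gated on a capability flag (PLATFORM_SUPPORTS_FLASH_ATTENTION) rather than relying solely on vendor-specific skips such as @skipIfRocm. The sketch below is a minimal, self-contained illustration of that decorator usage; the test class, test body, and file name are illustrative assumptions and are not taken from the test suite.

    import unittest

    import torch
    from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION


    class SDPASkipExample(unittest.TestCase):
        # Same capability-based gate the patches add to test_sdpa_2 and
        # test_scaled_dot_product_efficient_attention: skip on any platform
        # that lacks flash attention, regardless of vendor.
        @unittest.skipIf(
            not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
        )
        def test_sdpa_smoke(self):
            # Mirrors the suite's "requires CUDA" guard inside the test body.
            if not torch.cuda.is_available():
                raise unittest.SkipTest("requires CUDA")
            q = k = v = torch.randn(1, 1, 8, 8, dtype=torch.float16, device="cuda")
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            self.assertEqual(out.shape, q.shape)


    if __name__ == "__main__":
        unittest.main()

On platforms where the flag is False (for example, ROCm builds without flash attention enabled) the test is reported as skipped with the same message used throughout this series, instead of being hard-coded to a specific vendor.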