diff --git a/test/test_cuda.py b/test/test_cuda.py index 47610a979d81c..dca6b618291a9 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1192,7 +1192,7 @@ def _test_stream_event_nogil(self, sync_func, p2c, c2p): c2p.put(sync_func(self, TestCuda.FIFTY_MIL_CYCLES)) # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190 - @skipIfRocm + @skipIfRocm("as per pytorch/issues/53190") @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") def test_stream_event_nogil(self): for sync_func in [TestCuda._stream_synchronize, @@ -1230,7 +1230,7 @@ def test_stream_event_nogil(self): self.assertGreater(parent_time + child_time, total_time * 1.4) # This test is flaky for ROCm, see issue #62602 - @skipIfRocm + @skipIfRocm("flakey on rocm") @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") def test_events_wait(self): d0 = torch.device('cuda:0') @@ -1340,7 +1340,7 @@ def test_events_multi_gpu_elapsed_time(self): self.assertGreater(e0.elapsed_time(e2), 0) # XXX: this test only fails with hip-clang. revisit this once the dust has settled there. - @skipIfRocm + @skipIfRocm("fails on hip-clang") def test_record_stream(self): cycles_per_ms = get_cycles_per_ms() diff --git a/test/test_nn.py b/test/test_nn.py index 90bafbb4e59d7..de65a1eceebd7 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3555,7 +3555,7 @@ def test_cudnn_rnn_dropout_states_device(self): output = rnn(input, hx) @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') - @skipIfRocm + @skipIfRocm("Skipped on ROCm as explicit cudnn test") def test_cudnn_weight_format(self): rnns = [ nn.LSTM(10, 20, batch_first=True), diff --git a/test/test_optim.py b/test/test_optim.py index aeb23eaa5f515..cdf31de5e4e69 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -1509,7 +1509,7 @@ def test_asgd(self): with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"): optim.ASGD(None, lr=1e-2, weight_decay=-0.5, foreach=foreach) - @skipIfRocm + @skipIfRocm("Skipped on ROCm due to some reason...") def test_rprop(self): is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability( 0 diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index eb0270058ea1c..dbe5f3e434e84 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -2232,7 +2232,7 @@ def run_test(n, k, upper, unitriangular, transpose, zero): itertools.product([True, False], repeat=4)): run_test(n, k, upper, unitriangular, transpose, zero) - @skipCUDAIfRocm + @skipCUDAIfRocm("sampled_addmm issue") @skipCUDAIf( not _check_cusparse_sddmm_available(), "cuSparse Generic API SDDMM is not available" @@ -2287,7 +2287,7 @@ def run_test(c, a, b, op_a, op_b, *, alpha=None, beta=None): for op_a, op_b in itertools.product([True, False], repeat=2): run_test(c, a, b, op_a, op_b) - @skipCUDAIfRocm + @skipCUDAIfRocm("sddmm known issue") @skipCUDAIf( not _check_cusparse_sddmm_available(), "cuSparse Generic API SDDMM is not available" @@ -2318,7 +2318,7 @@ def test_sampled_addmm_autograd(self, device, dtype): self.assertEqual(a.grad, a1.grad) self.assertEqual(b.grad, b1.grad) - @skipCUDAIfRocm + @skipCUDAIfRocm("sddmm known issue") @onlyCUDA @skipCUDAIf(True, "Causes CUDA memory exception, see https://github.com/pytorch/pytorch/issues/72177") @skipCUDAIf( @@ -2526,7 +2526,7 @@ def fn(input): dense_output.backward(dense_covector) self.assertEqual(sparse_input.grad, dense_input.grad) - @skipCUDAIfRocm + @skipCUDAIfRocm("addmm known issues") @skipCUDAIf( not _check_cusparse_sddmm_available(), "cuSparse Generic API SDDMM is not available" diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 0585ee0820e7e..edc9174fdf08c 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -1252,8 +1252,19 @@ def skipCUDAIfNoMagmaAndNoCusolver(fn): return skipCUDAIfNoMagma(fn) # Skips a test on CUDA when using ROCm. -def skipCUDAIfRocm(fn): - return skipCUDAIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")(fn) +def skipCUDAIfRocm(msg="test doesn't currently work on the ROCm stack"): + + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if TEST_WITH_ROCM: + if self.device_type == 'cuda': + raise unittest.SkipTest(f"skipCUDAIfRocm: {msg}") + + return fn(self, *args, **kwargs) + + return wrap_fn + return dec_fn # Skips a test on CUDA when not using ROCm. def skipCUDAIfNotRocm(fn): diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index f6161990ce13c..364817b0b7206 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1058,14 +1058,15 @@ def has_corresponding_torch_dtype(np_dtype): torch.complex32: np.complex64 }) -def skipIfRocm(fn): - @wraps(fn) - def wrapper(*args, **kwargs): - if TEST_WITH_ROCM: - raise unittest.SkipTest("test doesn't currently work on the ROCm stack") - else: - fn(*args, **kwargs) - return wrapper +def skipIfRocm(msg="test doesn't currently work on the ROCm stack"): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if TEST_WITH_ROCM: + raise unittest.SkipTest(f"skipIfRocm: {msg}") + return fn(self, *args, **kwargs) + return wrap_fn + return dec_fn def skipIfMps(fn): @wraps(fn)