From 85a179c6cdc92bfe8e44265c70a03fb7830ed6bf Mon Sep 17 00:00:00 2001
From: Jack Taylor
Date: Wed, 15 Mar 2023 10:28:57 +0000
Subject: [PATCH 1/7] Added msg arg for skipIfRocm

This change will provide extra information on skips.
---
 torch/testing/_internal/common_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index f6161990ce13c..5026aa61ef57c 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1058,11 +1058,11 @@ def has_corresponding_torch_dtype(np_dtype):
     torch.complex32: np.complex64
 })
 
-def skipIfRocm(fn):
+def skipIfRocm(fn, msg="test doesn't currently work on the ROCm stack"):
     @wraps(fn)
     def wrapper(*args, **kwargs):
         if TEST_WITH_ROCM:
-            raise unittest.SkipTest("test doesn't currently work on the ROCm stack")
+            raise unittest.SkipTest(msg)
         else:
             fn(*args, **kwargs)
     return wrapper

From 7491580c604c4b68712f7aae1e514f3901764907 Mon Sep 17 00:00:00 2001
From: Jack Taylor
Date: Wed, 15 Mar 2023 14:40:27 +0000
Subject: [PATCH 2/7] Skip test example

---
 test/test_optim.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_optim.py b/test/test_optim.py
index aeb23eaa5f515..de737f1612152 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -1509,7 +1509,7 @@ def test_asgd(self):
         with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"):
             optim.ASGD(None, lr=1e-2, weight_decay=-0.5, foreach=foreach)
 
-    @skipIfRocm
+    @skipIfRocm("Testing skip message :)")
     def test_rprop(self):
         is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(
             0
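Why PATCH 3/7 has to rework the helper: PATCH 2/7 passes the message positionally, but against the PATCH 1/7 signature that string binds to the fn parameter, so skipIfRocm("...") returns a wrapper around the message itself, and that wrapper is then applied to the test function when the class body is evaluated. A minimal standalone sketch of the failure mode (TEST_WITH_ROCM is stubbed here; in PyTorch it comes from the build environment):

    import unittest
    from functools import wraps

    TEST_WITH_ROCM = False  # stubbed to exercise the non-ROCm path

    # The PATCH 1/7 version of the decorator, copied for illustration.
    def skipIfRocm(fn, msg="test doesn't currently work on the ROCm stack"):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            if TEST_WITH_ROCM:
                raise unittest.SkipTest(msg)
            else:
                fn(*args, **kwargs)
        return wrapper

    wrapper = skipIfRocm("Testing skip message :)")  # the message binds to fn
    try:
        wrapper(lambda self: None)  # decoration applies wrapper to the test
    except TypeError as err:
        print(err)  # 'str' object is not callable

On a ROCm build the wrapper instead raises SkipTest while the test class is still being defined, which is equally wrong. Turning skipIfRocm into a decorator factory, as the next patch does, fixes both paths.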
From b5c39dda2b2e16cff7616adf04e0be508c4bc209 Mon Sep 17 00:00:00 2001
From: Jack Taylor
Date: Wed, 22 Mar 2023 00:04:13 +0000
Subject: [PATCH 3/7] Fixes

---
 test/test_nn.py                         |  2 +-
 test/test_optim.py                      |  2 +-
 torch/testing/_internal/common_utils.py | 17 +++++++++--------
 3 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/test/test_nn.py b/test/test_nn.py
index 90bafbb4e59d7..de65a1eceebd7 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3555,7 +3555,7 @@ def test_cudnn_rnn_dropout_states_device(self):
         output = rnn(input, hx)
 
     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
-    @skipIfRocm
+    @skipIfRocm("Skipped on ROCm as explicit cudnn test")
     def test_cudnn_weight_format(self):
         rnns = [
             nn.LSTM(10, 20, batch_first=True),
diff --git a/test/test_optim.py b/test/test_optim.py
index de737f1612152..cdf31de5e4e69 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -1509,7 +1509,7 @@ def test_asgd(self):
         with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"):
             optim.ASGD(None, lr=1e-2, weight_decay=-0.5, foreach=foreach)
 
-    @skipIfRocm("Testing skip message :)")
+    @skipIfRocm("Skipped on ROCm due to some reason...")
     def test_rprop(self):
         is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(
             0
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 5026aa61ef57c..376a140018c0f 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1058,14 +1058,15 @@ def has_corresponding_torch_dtype(np_dtype):
     torch.complex32: np.complex64
 })
 
-def skipIfRocm(fn, msg="test doesn't currently work on the ROCm stack"):
-    @wraps(fn)
-    def wrapper(*args, **kwargs):
-        if TEST_WITH_ROCM:
-            raise unittest.SkipTest(msg)
-        else:
-            fn(*args, **kwargs)
-    return wrapper
+def skipIfRocm(msg="test doesn't currently work on the ROCm stack"):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                raise unittest.SkipTest(msg)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
 
 def skipIfMps(fn):
     @wraps(fn)

From c8e60b5052962f45607fcb30be7f6ef6ce860683 Mon Sep 17 00:00:00 2001
From: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Date: Wed, 29 Mar 2023 11:17:18 +0100
Subject: [PATCH 4/7] Update common_device_type.py

---
 torch/testing/_internal/common_device_type.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index 0585ee0820e7e..edc9174fdf08c 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -1252,8 +1252,19 @@ def skipCUDAIfNoMagmaAndNoCusolver(fn):
         return skipCUDAIfNoMagma(fn)
 
 # Skips a test on CUDA when using ROCm.
-def skipCUDAIfRocm(fn):
-    return skipCUDAIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")(fn)
+def skipCUDAIfRocm(msg="test doesn't currently work on the ROCm stack"):
+
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                if self.device_type == 'cuda':
+                    raise unittest.SkipTest(f"skipCUDAIfRocm: {msg}")
+
+            return fn(self, *args, **kwargs)
+
+        return wrap_fn
+    return dec_fn
 
 # Skips a test on CUDA when not using ROCm.
 def skipCUDAIfNotRocm(fn):

From 283b2a6d2e2a43b8975e1414a74d673bfa5a1ee4 Mon Sep 17 00:00:00 2001
From: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Date: Wed, 29 Mar 2023 11:18:20 +0100
Subject: [PATCH 5/7] Update common_utils.py

---
 torch/testing/_internal/common_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 376a140018c0f..364817b0b7206 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1063,7 +1063,7 @@ def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
             if TEST_WITH_ROCM:
-                raise unittest.SkipTest(msg)
+                raise unittest.SkipTest(f"skipIfRocm: {msg}")
             return fn(self, *args, **kwargs)
         return wrap_fn
     return dec_fn
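After PATCH 3/7 and PATCH 5/7, skipIfRocm is a decorator factory: it is applied with parentheses, takes an optional reason, and prefixes the reason with "skipIfRocm:" so skips are easy to grep in CI logs. A standalone usage sketch (TEST_WITH_ROCM is stubbed to True so the skip path fires; the test class and message are illustrative):

    import unittest
    from functools import wraps

    TEST_WITH_ROCM = True  # stubbed so the skip path is exercised

    # The PATCH 3/7 + 5/7 version of the decorator, copied for illustration.
    def skipIfRocm(msg="test doesn't currently work on the ROCm stack"):
        def dec_fn(fn):
            @wraps(fn)
            def wrap_fn(self, *args, **kwargs):
                if TEST_WITH_ROCM:
                    raise unittest.SkipTest(f"skipIfRocm: {msg}")
                return fn(self, *args, **kwargs)
            return wrap_fn
        return dec_fn

    class Demo(unittest.TestCase):
        @skipIfRocm("illustrative reason for the skip")
        def test_something(self):
            self.assertTrue(True)

    if __name__ == "__main__":
        # Reports something like:
        #   test_something (__main__.Demo) ... skipped 'skipIfRocm: illustrative reason for the skip'
        unittest.main(verbosity=2)

Note that the factory form means the bare @skipIfRocm spelling no longer works; every call site has to apply the decorator with parentheses, with or without a message, which is what the remaining patches do for the call sites they touch.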
From 9ce082c34b4b41fd4920857f793b582f559f00bd Mon Sep 17 00:00:00 2001
From: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Date: Wed, 29 Mar 2023 11:20:46 +0100
Subject: [PATCH 6/7] Update test_sparse_csr.py

---
 test/test_sparse_csr.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index eb0270058ea1c..dbe5f3e434e84 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -2232,7 +2232,7 @@ def run_test(n, k, upper, unitriangular, transpose, zero):
                 itertools.product([True, False], repeat=4)):
             run_test(n, k, upper, unitriangular, transpose, zero)
 
-    @skipCUDAIfRocm
+    @skipCUDAIfRocm("sampled_addmm issue")
     @skipCUDAIf(
         not _check_cusparse_sddmm_available(),
         "cuSparse Generic API SDDMM is not available"
@@ -2287,7 +2287,7 @@ def run_test(c, a, b, op_a, op_b, *, alpha=None, beta=None):
         for op_a, op_b in itertools.product([True, False], repeat=2):
             run_test(c, a, b, op_a, op_b)
 
-    @skipCUDAIfRocm
+    @skipCUDAIfRocm("sddmm known issue")
     @skipCUDAIf(
         not _check_cusparse_sddmm_available(),
         "cuSparse Generic API SDDMM is not available"
@@ -2318,7 +2318,7 @@ def test_sampled_addmm_autograd(self, device, dtype):
         self.assertEqual(a.grad, a1.grad)
         self.assertEqual(b.grad, b1.grad)
 
-    @skipCUDAIfRocm
+    @skipCUDAIfRocm("sddmm known issue")
     @onlyCUDA
     @skipCUDAIf(True, "Causes CUDA memory exception, see https://github.com/pytorch/pytorch/issues/72177")
     @skipCUDAIf(
@@ -2526,7 +2526,7 @@ def fn(input):
         dense_output.backward(dense_covector)
         self.assertEqual(sparse_input.grad, dense_input.grad)
 
-    @skipCUDAIfRocm
+    @skipCUDAIfRocm("addmm known issues")
     @skipCUDAIf(
         not _check_cusparse_sddmm_available(),
         "cuSparse Generic API SDDMM is not available"

From 78d8901723df4574edb86f1c26f0380731ce764c Mon Sep 17 00:00:00 2001
From: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Date: Wed, 29 Mar 2023 11:22:16 +0100
Subject: [PATCH 7/7] Update test_cuda.py

---
 test/test_cuda.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/test_cuda.py b/test/test_cuda.py
index 47610a979d81c..dca6b618291a9 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1192,7 +1192,7 @@ def _test_stream_event_nogil(self, sync_func, p2c, c2p):
         c2p.put(sync_func(self, TestCuda.FIFTY_MIL_CYCLES))
 
     # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
-    @skipIfRocm
+    @skipIfRocm("as per pytorch/issues/53190")
     @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
     def test_stream_event_nogil(self):
         for sync_func in [TestCuda._stream_synchronize,
@@ -1230,7 +1230,7 @@ def test_stream_event_nogil(self):
         self.assertGreater(parent_time + child_time, total_time * 1.4)
 
     # This test is flaky for ROCm, see issue #62602
-    @skipIfRocm
+    @skipIfRocm("flaky on ROCm")
     @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
     def test_events_wait(self):
         d0 = torch.device('cuda:0')
@@ -1340,7 +1340,7 @@ def test_events_multi_gpu_elapsed_time(self):
         self.assertGreater(e0.elapsed_time(e2), 0)
 
     # XXX: this test only fails with hip-clang. revisit this once the dust has settled there.
-    @skipIfRocm
+    @skipIfRocm("fails on hip-clang")
     def test_record_stream(self):
         cycles_per_ms = get_cycles_per_ms()
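One detail of PATCH 4/7 worth spelling out: skipCUDAIfRocm checks self.device_type, so in the device-generic test harness only the CUDA instantiation of a test is skipped on ROCm, while the CPU instantiation keeps running. A standalone sketch of that guard (the Fake*Test classes below are illustrative stand-ins for PyTorch's generated device-type test classes, and TEST_WITH_ROCM is stubbed):

    import unittest
    from functools import wraps

    TEST_WITH_ROCM = True  # stubbed; pretend this is a ROCm build

    # The PATCH 4/7 version of the decorator, copied for illustration.
    def skipCUDAIfRocm(msg="test doesn't currently work on the ROCm stack"):
        def dec_fn(fn):
            @wraps(fn)
            def wrap_fn(self, *args, **kwargs):
                if TEST_WITH_ROCM:
                    if self.device_type == 'cuda':
                        raise unittest.SkipTest(f"skipCUDAIfRocm: {msg}")
                return fn(self, *args, **kwargs)
            return wrap_fn
        return dec_fn

    class FakeCpuTest:
        device_type = 'cpu'

    class FakeCudaTest:
        device_type = 'cuda'

    @skipCUDAIfRocm("sampled_addmm issue")
    def test_sampled_addmm(self):
        return "ran"

    print(test_sampled_addmm(FakeCpuTest()))  # "ran": the CPU variant still runs
    try:
        test_sampled_addmm(FakeCudaTest())
    except unittest.SkipTest as err:
        print("skipped:", err)  # the CUDA variant is skipped on ROCm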