From 65166d86a3a38800872a2f6ffc4559c687355d27 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Fri, 22 Nov 2024 00:56:33 +0000 Subject: [PATCH] [MPS] Add regression test for sync deadlock (#141296) See https://github.com/pytorch/pytorch/pull/140725#issuecomment-2492434870 Running `torch.mps.synchronize()` after metal kernel resulted in infinite wait inside `[_MTLCommandBuffer waitUntilCompleted]` ``` (lldb) bt * thread #1, queue = 'com.apple.main-thread', stop reason = signal SIGSTOP * frame #0: 0x00000001aa919084 Metal`pthread_cond_wait + 12 frame #1: 0x00000001aa78b1b4 Metal`-[_MTLCommandBuffer waitUntilCompleted] + 84 frame #2: 0x00000001032bf358 libtorch_python.dylib`torch::mps::MPSModule_deviceSynchronize(_object*, _object*) + 40 frame #3: 0x0000000100e94c20 Python`cfunction_vectorcall_NOARGS + 100 frame #4: 0x0000000100e389b8 Python`PyObject_Vectorcall + 92 frame #5: 0x0000000100f61e38 Python`_PyEval_EvalFrameDefault + 19040 frame #6: 0x0000000100f5d180 Python`PyEval_EvalCode + 200 frame #7: 0x0000000100fcd1a4 Python`run_eval_code_obj + 104 frame #8: 0x0000000100fccbe4 Python`run_mod + 168 frame #9: 0x0000000100fcb518 Python`pyrun_file + 164 frame #10: 0x0000000100fca854 Python`_PyRun_SimpleFileObject + 256 frame #11: 0x0000000100fca4e8 Python`_PyRun_AnyFileObject + 80 frame #12: 0x0000000100ff2028 Python`pymain_run_file_obj + 164 frame #13: 0x0000000100ff1ce4 Python`pymain_run_file + 72 frame #14: 0x0000000100ff0f74 Python`Py_RunMain + 988 frame #15: 0x0000000100ff1564 Python`pymain_main + 304 frame #16: 0x0000000100ff1604 Python`Py_BytesMain + 40 frame #17: 0x000000019f630274 dyld`start + 2840 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/141296 Approved by: https://github.com/huydhn --- test/test_mps.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/test_mps.py b/test/test_mps.py index fe7a65d3696fc..a294b1b0d2bcc 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -8385,6 +8385,14 @@ def test_cumprod_dim_check(self): self.assertRaises(IndexError, lambda: x.cumprod(2)) self.assertRaises(IndexError, lambda: x.cumprod(-3)) + def test_do_sync_thrice_its_all_right(self): + # Regression test for https://github.com/pytorch/pytorch/commit/9bc9d4cdb4355a385a7d7959f07d04d1648d6904 + # That caused sync calls to deadlock + x = torch.nextafter(torch.ones(1024, device='mps'), torch.zeros(1024, device='mps')) + for _ in range(3): + torch.mps.synchronize() + self.assertLess(x.sum().item(), x.numel()) + class TestLogical(TestCaseMPS): def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False): return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)