diff --git a/dlrover/python/common/constants.py b/dlrover/python/common/constants.py
index e1c81e5ba..dd9b752ba 100644
--- a/dlrover/python/common/constants.py
+++ b/dlrover/python/common/constants.py
@@ -118,6 +118,7 @@ class JobExitReason(object):
     RDZV_TIMEOUT_ERROR = "RdzvTimeout"
     PENDING_TIMEOUT = "PendingTimeout"
     UNCOMPLETED_TIMEOUT = "UncompletedTimeout"
+    RDZV_ALL_FAILED = "RdzvAllFailed"
 
 
 class CustomMetricKeys:
diff --git a/dlrover/python/master/node/dist_job_manager.py b/dlrover/python/master/node/dist_job_manager.py
index 848a9706e..f7994cc89 100644
--- a/dlrover/python/master/node/dist_job_manager.py
+++ b/dlrover/python/master/node/dist_job_manager.py
@@ -243,7 +243,42 @@ def is_all_reduce_type_job(self):
             == DistributionStrategy.ALLREDUCE
         )
 
+    def is_all_workers_node_check_failed(self):
+        return all(
+            [
+                node.is_node_check_failed()
+                for _, node in self._job_nodes[NodeType.WORKER].items()
+            ]
+        )
+
     def should_early_stop(self):
+        # node-check all failed
+        if (
+            self.is_all_reduce_type_job()
+            and self.is_all_workers_node_check_failed()
+        ):
+            msg = (
+                "Stop the training early because all worker nodes have "
+                "failed the node check in rendezvous."
+            )
+
+            self._process_error(
+                None,
+                0,
+                msg,
+                level=TrainingExceptionLevel.RDZV_ERROR,
+            )
+
+            self._report_event(
+                ErrorMonitorConstants.TYPE_INFO,
+                "job",
+                ErrorMonitorConstants.ACTION_EARLY_STOP,
+                "All node check failed",
+                {"nodes": json.dumps(self._worker_manager.cur_nodes)},
+            )
+
+            return True, JobExitReason.RDZV_ALL_FAILED, msg
+
         # ps pending judgement: any ps pod pending timeout
         timeout_ps_nodes = (
             self._ps_manager.get_pending_timeout_oom_recovered_node()
diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py
index 50fed89c7..e9ea486f0 100644
--- a/dlrover/python/tests/test_job_manager.py
+++ b/dlrover/python/tests/test_job_manager.py
@@ -575,6 +575,10 @@ def test_check_worker_status(self):
         manager._job_nodes[NodeType.WORKER][0].status = NodeStatus.FINISHED
         self.assertTrue(manager.all_critical_node_completed())
 
+        for worker in manager._job_nodes[NodeType.WORKER].values():
+            worker.reported_status = 2
+        self.assertTrue(manager.is_all_workers_node_check_failed())
+
     def test_tf_ps_node_handling(self):
         params = MockK8sPSJobArgs()
         params.initilize()
@@ -730,6 +734,19 @@ def test_early_stop_part3(self):
         result, reason, msg = manager.should_early_stop()
         self.assertFalse(result)
 
+    def test_early_stop_part4(self):
+        params = MockK8sAllreduceJobArgs()
+        params.initilize()
+        manager = create_job_manager(params, SpeedMonitor())
+        manager._init_nodes()
+
+        manager.is_all_workers_node_check_failed = mock.MagicMock(
+            return_value=True
+        )
+        result, reason, msg = manager.should_early_stop()
+        self.assertTrue(result)
+        self.assertEqual(reason, JobExitReason.RDZV_ALL_FAILED)
+
     def test_when_node_not_init(self):
         params = MockK8sPSJobArgs()
         params.initilize()