Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Job exit when all nodecheck failed #1323

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dlrover/python/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class JobExitReason(object):
RDZV_TIMEOUT_ERROR = "RdzvTimeout"
PENDING_TIMEOUT = "PendingTimeout"
UNCOMPLETED_TIMEOUT = "UncompletedTimeout"
RDZV_ALL_FAILED = "RdzvAllFailed"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change name to 'NodeCheckFailed'. This is actually a training failure, not a rdzv failure.



class CustomMetricKeys:
Expand Down
35 changes: 35 additions & 0 deletions dlrover/python/master/node/dist_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,42 @@ def is_all_reduce_type_job(self):
== DistributionStrategy.ALLREDUCE
)

def is_all_workers_node_check_failed(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to move this function into WorkerManager, and to apply it to 'AllReduceType' training only.

return all(
[
node.is_node_check_failed()
for _, node in self._job_nodes[NodeType.WORKER].items()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for node in self._job_nodes[NodeType.WORKER].values()

]
)

def should_early_stop(self):
# node-check all failed
if (
self.is_all_reduce_type_job()
and self.is_all_workers_node_check_failed()
):
msg = (
"Stop the training early because all worker nodes has "
"failed the node check in rendezvous."
)

self._process_error(
None,
0,
msg,
level=TrainingExceptionLevel.RDZV_ERROR,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

process_error(message=msg, level=...)


self._report_event(
ErrorMonitorConstants.TYPE_INFO,
"job",
ErrorMonitorConstants.ACTION_EARLY_STOP,
"All node check failed",
{"nodes": json.dumps(self._worker_manager.cur_nodes)},
)

return True, JobExitReason.RDZV_ALL_FAILED, msg

# ps pending judgement: any ps pod pending timeout
timeout_ps_nodes = (
self._ps_manager.get_pending_timeout_oom_recovered_node()
Expand Down
17 changes: 17 additions & 0 deletions dlrover/python/tests/test_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,10 @@ def test_check_worker_status(self):
manager._job_nodes[NodeType.WORKER][0].status = NodeStatus.FINISHED
self.assertTrue(manager.all_critical_node_completed())

for worker in manager._job_nodes[NodeType.WORKER].values():
worker.reported_status = 2
self.assertTrue(manager.is_all_workers_node_check_failed())

def test_tf_ps_node_handling(self):
params = MockK8sPSJobArgs()
params.initilize()
Expand Down Expand Up @@ -730,6 +734,19 @@ def test_early_stop_part3(self):
result, reason, msg = manager.should_early_stop()
self.assertFalse(result)

def test_early_stop_part4(self):
    """Check early stop when every worker is reported as node-check failed.

    A job manager is created from ``MockK8sAllreduceJobArgs``, its
    ``is_all_workers_node_check_failed`` predicate is stubbed to return
    ``True``, and ``should_early_stop`` is then expected to return a
    truthy result together with ``JobExitReason.RDZV_ALL_FAILED``.
    """
    job_args = MockK8sAllreduceJobArgs()
    job_args.initilize()
    manager = create_job_manager(job_args, SpeedMonitor())
    manager._init_nodes()

    # Stub the predicate on this instance only, so the test does not
    # depend on the state of the individual worker nodes.
    all_failed_stub = mock.MagicMock(return_value=True)
    manager.is_all_workers_node_check_failed = all_failed_stub

    # The third element (the message) is not asserted on here.
    should_stop, exit_reason, _ = manager.should_early_stop()
    self.assertTrue(should_stop)
    self.assertEqual(exit_reason, JobExitReason.RDZV_ALL_FAILED)

def test_when_node_not_init(self):
params = MockK8sPSJobArgs()
params.initilize()
Expand Down
Loading