Skip to content

Commit

Permalink
merge PR 1317
Browse files (browse the repository at this point in the history)
  • Loading branch information
BalaBalaYi committed Nov 6, 2024
1 parent 224a68d commit 345583b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
18 changes: 17 additions & 1 deletion dlrover/python/master/node/dist_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,14 +577,30 @@ def _monitor_scale_plan_crd(self):
def _process_list_nodes(self, nodes: List[Node]):
"""Callback with node list by the list api of k8s."""

logger.debug(f"Got list nodes: {nodes}")
exist_nodes: Dict[str, List[int]] = {}
job_nodes = self._job_context.job_nodes()
for node_type in job_nodes.keys():
exist_nodes[node_type] = []

if nodes:
for node in nodes:
exist_nodes[node.type].append(node.id)
node_type = node.type
node_id = node.id
exist_nodes[node_type].append(node_id)

# for nodes not in current 'job_nodes' obj, re add it
if (
node_id not in job_nodes[node_type]
and node.status != NodeStatus.DELETED
):
logger.info(
f"Node {node_type} {node.id} with status {node.status}"
" is re added without the event"
)
new_node = copy.deepcopy(node)
self._job_context.update_job_node(new_node)

if node.status == NodeStatus.DELETED:
event_type = NodeEventType.DELETED
else:
Expand Down
11 changes: 11 additions & 0 deletions dlrover/python/tests/test_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ def test_process_list_nodes(self):
manager = create_job_manager(params, SpeedMonitor())
manager._init_nodes()
job_nodes = self.job_context.job_nodes()
self.assertFalse(4 in job_nodes[NodeType.WORKER])
for node in job_nodes[NodeType.PS].values():
node.status = NodeStatus.PENDING
self.job_context.update_job_node(node)
Expand All @@ -439,11 +440,21 @@ def test_process_list_nodes(self):
max_relaunch_count=1,
)
nodes.append(node)
nodes.append(
Node(
node_type=NodeType.WORKER,
node_id=4,
status=NodeStatus.RUNNING,
config_resource=NodeResource(1, 4096),
max_relaunch_count=1,
)
)
manager._process_list_nodes(nodes)

job_nodes = self.job_context.job_nodes()
ps_ids = list(job_nodes[NodeType.PS].keys())
self.assertListEqual(ps_ids, [0, 1, 2])
self.assertTrue(4 in self.job_context.job_nodes()[NodeType.WORKER])

@patch.object(DistributedJobManager, "_process_event")
def test_process_list_nodes_for_empty_case(self, mock_method):
Expand Down

0 comments on commit 345583b

Please sign in to comment.