diff --git a/dlrover/python/master/shard/task_manager.py b/dlrover/python/master/shard/task_manager.py
index cc6ff2c3c..6fc2ad3e9 100644
--- a/dlrover/python/master/shard/task_manager.py
+++ b/dlrover/python/master/shard/task_manager.py
@@ -227,9 +227,7 @@ def _check_and_reassign_timeout_tasks(self):
                         doing_task.node_id,
                         task_id,
                     )
-                    self.recover_tasks(
-                        doing_task.node_type, doing_task.node_id
-                    )
+                    dataset.report_task_status(task_id, success=False)
                     self._invoke_task_timeout_callback(doing_task.node_id)
                     break
             time.sleep(30)
diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py
index 654604edb..bccadac99 100644
--- a/dlrover/python/tests/test_job_manager.py
+++ b/dlrover/python/tests/test_job_manager.py
@@ -15,6 +15,7 @@
 import unittest
 from unittest import mock
 
+from dlrover.proto import elastic_training_pb2
 from dlrover.python.common.constants import (
     DistributionStrategy,
     JobExitReason,
@@ -48,6 +49,7 @@
     MockK8sPSJobArgs,
     create_task_manager,
     mock_k8s_client,
+    new_dataset_splitter,
 )
 
 _MOCK_JOB_UUID = "11111"
@@ -242,8 +244,21 @@ def test_recover_tasks_for_failed_workers(self):
         ds_name_0 = "test-0"
         ds_name_1 = "test-1"
         task_manager = create_task_manager(ds_name_0)
-        task_manager1 = create_task_manager(ds_name_1)
-        task_manager._datasets.update(task_manager1._datasets)
+        splitter = new_dataset_splitter(
+            False,
+            100,
+            1000,
+            1,
+            ds_name_1,
+            "table",
+        )
+        task_manager.new_dataset(
+            batch_size=10,
+            dataset_size=1000,
+            dataset_name=ds_name_1,
+            dataset_splitter=splitter,
+            task_type=elastic_training_pb2.EVALUATION,
+        )
         task_callback = TaskRescheduleCallback(task_manager)
 
         params = MockK8sPSJobArgs()
diff --git a/dlrover/python/tests/test_worker_manager.py b/dlrover/python/tests/test_worker_manager.py
index d688d94d3..a193d9b5d 100644
--- a/dlrover/python/tests/test_worker_manager.py
+++ b/dlrover/python/tests/test_worker_manager.py
@@ -20,8 +20,8 @@
     NodeType,
     PlatformType,
 )
-from dlrover.python.common.node import NodeGroupResource, NodeResource
-from dlrover.python.master.node.worker import WorkerManager
+from dlrover.python.common.node import Node, NodeGroupResource, NodeResource
+from dlrover.python.master.node.worker import ChiefManager, WorkerManager
 from dlrover.python.master.resource.job import JobResource
 from dlrover.python.scheduler.factory import new_elastic_job
 from dlrover.python.tests.test_utils import mock_k8s_client
@@ -106,6 +106,23 @@ def test_relaunch_node(self):
         self.assertEqual(plan.launch_nodes[0].config_resource.cpu, 16)
         self.assertEqual(worker_manager._nodes[5].id, 5)
 
+    def test_relaunch_chief_node(self):
+        tf_master_node = Node(
+            NodeType.MASTER,
+            node_id=0,
+            config_resource=NodeResource(cpu=16, memory=10240),
+        )
+        manager = ChiefManager(
+            {0: tf_master_node},
+            self._job_resource,
+            3,
+            self._elastic_job.get_node_service_addr,
+            self._elastic_job.get_node_name,
+        )
+        plan = manager.relaunch_node(tf_master_node)
+        self.assertEqual(plan.launch_nodes[0].config_resource.cpu, 16)
+        self.assertEqual(manager._nodes[1].id, 1)
+
     def test_cut_pending_node_cpu(self):
         worker_manager = WorkerManager(
             self._job_nodes[NodeType.WORKER],