Skip to content

Commit

Permalink
Fix: recover timeout tasks. (#522)
Browse files Browse the repository at this point in the history
* Fix recovering tasks

* Fix test cases
  • Loading branch information
workingloong authored Jul 24, 2023
1 parent 3928253 commit 8f888c7
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
19 changes: 11 additions & 8 deletions dlrover/python/master/shard/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
)
from dlrover.python.master.shard.batch_dataset_manager import (
BatchDatasetManager,
DoingTask,
)
from dlrover.python.master.shard.dataset_splitter import DatasetSplitter

Expand Down Expand Up @@ -63,10 +62,8 @@ def new_dataset(
task_type=elastic_training_pb2.NONE,
):
logger.info(
"New %s dataset with batch size = %s, dataset size = %s",
dataset_name,
batch_size,
dataset_size,
f"New {task_type} dataset {dataset_name} with, "
f"batch size = {batch_size} dataset size = {dataset_size}"
)

with self._lock:
Expand Down Expand Up @@ -161,7 +158,7 @@ def finished(self):
def recover_tasks(self, node_type, node_id):
"""Recover doing tasks for a dead worker if needed"""
for name, dataset in self._datasets.items():
doing_tasks: Dict[int, DoingTask] = dataset.get_doing_tasks()
doing_tasks = dataset.doing
if not doing_tasks:
return
ids = [
Expand All @@ -173,12 +170,15 @@ def recover_tasks(self, node_type, node_id):
if not ids:
return
request = elastic_training_pb2.ReportTaskResultRequest()
recover_tasks = []
for id in ids:
request.task_id = id
request.dataset_name = name
recover_tasks.append(id)
self.report_dataset_task(request, False)
logger.info(
"Recover tasks of dataset %s assigned to %s-%d",
"Recover tasks %s of dataset %s assigned to %s-%d",
recover_tasks,
name,
node_type,
node_id,
Expand All @@ -204,7 +204,7 @@ def _invoke_task_timeout_callback(self, worker_id):

def _check_and_reassign_timeout_tasks(self):
"""Check whether there are timeout tasks periodically."""
logger.info("Start the thread to monitor timeout tasks")
logger.info("Start the thread to monitor timeout tasks.")
while True:
for _, dataset in self._datasets.items():
doing_tasks = dataset.doing.copy()
Expand All @@ -227,6 +227,9 @@ def _check_and_reassign_timeout_tasks(self):
doing_task.node_id,
task_id,
)
self.recover_tasks(
doing_task.node_type, doing_task.node_id
)
self._invoke_task_timeout_callback(doing_task.node_id)
break
time.sleep(30)
Expand Down
17 changes: 12 additions & 5 deletions dlrover/python/tests/test_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,25 +239,32 @@ def test_create_allreduce_job_manager(self):
self.assertEqual(len(manager._job_nodes[NodeType.WORKER]), 3)

def test_recover_tasks_for_failed_workers(self):
dataset_name = "test"
task_manager = create_task_manager()
ds_name_0 = "test-0"
ds_name_1 = "test-1"
task_manager = create_task_manager(ds_name_0)
task_manager1 = create_task_manager(ds_name_1)
task_manager._datasets.update(task_manager1._datasets)

task_callback = TaskRescheduleCallback(task_manager)
params = MockK8sPSJobArgs()
params.initilize()
manager = create_job_manager(params, SpeedMonitor())
manager._init_nodes()
manager.add_node_event_callback(task_callback)

dataset = task_manager.get_dataset(dataset_name)
task_manager.get_dataset_task(NodeType.WORKER, 0, dataset_name)
dataset_0 = task_manager.get_dataset(ds_name_0)
dataset_1 = task_manager.get_dataset(ds_name_1)
task_manager.get_dataset_task(NodeType.WORKER, 0, ds_name_0)
task_manager.get_dataset_task(NodeType.WORKER, 0, ds_name_1)
node = Node(
node_type=NodeType.WORKER,
node_id=0,
status=NodeStatus.RUNNING,
config_resource=NodeResource(1, 4096),
)
manager._process_node_events(NODE_STATE_FLOWS[9], node)
self.assertEqual(len(dataset.doing), 0)
self.assertEqual(len(dataset_0.doing), 0)
self.assertEqual(len(dataset_1.doing), 0)

def test_create_initial_nodes(self):
params = MockK8sPSJobArgs()
Expand Down
3 changes: 1 addition & 2 deletions dlrover/python/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,8 @@ def mock_list_namespaced_pod(label_selector):
)


def create_task_manager():
def create_task_manager(dataset_name="test"):
task_manager = TaskManager(False, SpeedMonitor())
dataset_name = "test"
splitter = new_dataset_splitter(
False,
100,
Expand Down

0 comments on commit 8f888c7

Please sign in to comment.