From a38cf8242bc7060b6be8763a1a0680842280119d Mon Sep 17 00:00:00 2001 From: Bo Sang <31138775+samplise@users.noreply.github.com> Date: Wed, 6 Nov 2024 18:52:04 -0800 Subject: [PATCH] Job context implementation (#1319) * split job context impl * optimization and ut fix * fix * fix * fix ut * fix ut * fix ut * fix ut * merge PR 1317 * fix ut * fix ut * sync * Refactor node event report and report succeeded. (#1315) * refactor node event report * fix ut * merge check report * lint * refactor node succeeded field into a more common status field * fix ut * optimize ut * no updating if already succeeded * fix ut (cherry picked from commit d9dc1c1be74ff1ebaaea5582dfd60a9ba724a330) * fix * fix ut * lint * lint * fix * optimization * more ut * optimize action queue * optimize action response * add ut * optimized * optimized --------- Co-authored-by: BO SANG Co-authored-by: chentianyi.cty Co-authored-by: Tianyi Chen --- dlrover/python/common/grpc.py | 11 + dlrover/python/common/node.py | 5 + dlrover/python/diagnosis/common/constants.py | 16 +- .../diagnosis/common/diagnosis_action.py | 215 ++++++++++++++++++ .../diagnosis/inferencechain/coordinator.py | 6 +- .../worker.py => elastic_agent/context.py} | 4 +- .../diagnosis/diagnosis_agent.py | 32 +-- dlrover/python/elastic_agent/master_client.py | 28 ++- .../python/elastic_agent/monitor/training.py | 4 +- .../python/elastic_agent/torch/training.py | 16 +- .../master/elastic_training/rdzv_manager.py | 4 +- .../python/master/node/dist_job_manager.py | 197 +++++++++------- dlrover/python/master/node/job_auto_scaler.py | 27 ++- dlrover/python/master/node/job_context.py | 114 ++++++++++ dlrover/python/master/node/job_manager.py | 15 +- .../python/master/node/local_job_manager.py | 40 +++- dlrover/python/master/node/ps.py | 48 ++-- dlrover/python/master/node/training_node.py | 65 ++++-- dlrover/python/master/node/worker.py | 68 +++--- dlrover/python/master/servicer.py | 25 +- dlrover/python/tests/test_common_util.py | 2 +- dlrover/python/tests/test_diagnosis.py | 75 +++++- dlrover/python/tests/test_diagnosis_agent.py | 28 +-- .../tests/test_elastic_training_agent.py | 1 + dlrover/python/tests/test_job_auto_scaler.py | 43 +++- dlrover/python/tests/test_job_manager.py | 147 ++++++++---- dlrover/python/tests/test_log.py | 1 + dlrover/python/tests/test_master.py | 16 +- dlrover/python/tests/test_node.py | 4 + dlrover/python/tests/test_pod_scaler.py | 3 + dlrover/python/tests/test_ps_manager.py | 70 ++++-- dlrover/python/tests/test_servicer.py | 60 +++-- dlrover/python/tests/test_sync_service.py | 17 +- dlrover/python/tests/test_time_util.py | 31 +++ dlrover/python/tests/test_worker_manager.py | 101 ++++---- .../diagnose_action.py => util/time_util.py} | 12 +- 36 files changed, 1161 insertions(+), 390 deletions(-) create mode 100644 dlrover/python/diagnosis/common/diagnosis_action.py rename dlrover/python/{common/worker.py => elastic_agent/context.py} (93%) create mode 100644 dlrover/python/master/node/job_context.py create mode 100644 dlrover/python/tests/test_time_util.py rename dlrover/python/{diagnosis/common/diagnose_action.py => util/time_util.py} (71%) diff --git a/dlrover/python/common/grpc.py b/dlrover/python/common/grpc.py index 17a7d0540..8d691f3a7 100644 --- a/dlrover/python/common/grpc.py +++ b/dlrover/python/common/grpc.py @@ -506,3 +506,14 @@ class ElasticRunConfigRequest(Message): @dataclass class ElasticRunConfig(Message): configs: Dict[str, str] = field(default_factory=dict) + + +@dataclass +class DiagnosisAction(Message): + action_cls: str = 
"" + action_content: str = "" + + +@dataclass +class HeartbeatResponse(Message): + action: DiagnosisAction = field(default_factory=DiagnosisAction) diff --git a/dlrover/python/common/node.py b/dlrover/python/common/node.py index 0fc3c7b4c..cf89e8aad 100644 --- a/dlrover/python/common/node.py +++ b/dlrover/python/common/node.py @@ -384,3 +384,8 @@ def to_dict(self): d.pop("config_resource", None) d.pop("used_resource", None) return d + + def update_from_node(self, node): + if self == node: + return + self.__dict__.update(node.__dict__) diff --git a/dlrover/python/diagnosis/common/constants.py b/dlrover/python/diagnosis/common/constants.py index 601bb730d..e190a8298 100644 --- a/dlrover/python/diagnosis/common/constants.py +++ b/dlrover/python/diagnosis/common/constants.py @@ -24,6 +24,10 @@ class InferenceConfigKey(object): class DiagnosisConstant(object): MASTER_DIAGNOSIS_OBSERVING_INTERVAL_SECS = 180 AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS = 60 + MASTER = -1 + ANY_INSTANCE = -2 + LOCAL_INSTANCE = -3 + ACTION_EXPIRED_TIME_PERIOD_DEFAULT = 60 * 5 class DiagnosisDataType(object): @@ -32,7 +36,15 @@ class DiagnosisDataType(object): XPU_TIMER_METRIC = "XPU_TIMER_METRIC" -class DiagnosisAction(object): - NO_ACTION = "no_action" +class DiagnosisActionType(object): + # common + NONE = "no_action" + ANY = "any_action" + + # master operation + MASTER_RELAUNCH_WORKER = "master_relaunch_worker" + EVENT = "event" + + # node operation RESTART_WORKER = "restart_worker" RELAUNCH_WORKER = "relaunch_worker" diff --git a/dlrover/python/diagnosis/common/diagnosis_action.py b/dlrover/python/diagnosis/common/diagnosis_action.py new file mode 100644 index 000000000..968263859 --- /dev/null +++ b/dlrover/python/diagnosis/common/diagnosis_action.py @@ -0,0 +1,215 @@ +# Copyright 2024 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import queue +import threading +from abc import ABCMeta +from datetime import datetime +from queue import Queue +from typing import Dict + +from dlrover.python.common.log import default_logger as logger +from dlrover.python.diagnosis.common.constants import ( + DiagnosisActionType, + DiagnosisConstant, +) +from dlrover.python.util.time_util import has_expired + + +class DiagnosisAction(metaclass=ABCMeta): + """ + Operation to be done after diagnostician. + + Args: + action_type (DiagnosisActionType): Type of action. + instance (Optional): Instance id(nod id). Defaults to -1(Master). + timestamp (Optional[datetime.datetime]): Timestamp of action. Unit: ms. + Defaults to current time. + expired_time_period (Optional): Milliseconds of expired time period. + Unit: ms. Defaults to 60 seconds. 
+ """ + + def __init__( + self, + action_type=DiagnosisActionType.NONE, + instance=DiagnosisConstant.MASTER, + timestamp=0, + expired_time_period=60 * 1000, + ): + self._action_type = action_type + self._instance = instance + if timestamp == 0: + self._timestamp: float = datetime.now().timestamp() + else: + self._timestamp: float = timestamp + + if expired_time_period == 0: + self._expired_time_period = ( + DiagnosisConstant.ACTION_EXPIRED_TIME_PERIOD_DEFAULT + ) + else: + self._expired_time_period = expired_time_period + + @property + def action_type(self): + return self._action_type + + @property + def instance(self): + return self._instance + + @property + def timestamp(self): + return self._timestamp + + @property + def expired_timestamp(self): + return self._timestamp + self._expired_time_period + + def is_expired(self) -> bool: + return has_expired(self._timestamp, self._expired_time_period) + + def to_json(self): + data = {k.lstrip("_"): v for k, v in self.__dict__.items()} + return json.dumps(data) + + @classmethod + def from_json(cls, json_data): + return cls(**json.loads(json_data)) + + +class NoAction(DiagnosisAction): + def __init__(self): + super(NoAction, self).__init__( + action_type=DiagnosisActionType.NONE, instance=-999 + ) + + +class EventAction(DiagnosisAction): + """Output the specified event.""" + + def __init__( + self, + event_type: str = "", + event_instance: str = "", + event_action: str = "", + event_msg: str = "", + event_labels: Dict[str, str] = {}, + timestamp=0, + expired_time_period=0, + ): + super().__init__( + DiagnosisActionType.EVENT, + timestamp=timestamp, + expired_time_period=expired_time_period, + ) + self._event_type = event_type + self._event_instance = event_instance + self._event_action = event_action + self._event_msg = event_msg + self._event_labels = event_labels + + @property + def event_type(self): + return self._event_type + + @property + def event_instance(self): + return self._event_instance + + @property + def event_action(self): + return self._event_action + + @property + def event_msg(self): + return self._event_msg + + @property + def event_labels(self): + return self._event_labels + + +class NodeAction(DiagnosisAction): + def __init__( + self, + node_id, + node_status, + reason, + action_type=DiagnosisActionType.MASTER_RELAUNCH_WORKER, + timestamp=0, + expired_time_period=0, + ): + super().__init__( + action_type, + node_id, + timestamp, + expired_time_period, + ) + self._node_status = node_status + self._reason = reason + + @property + def node_id(self): + return self.instance + + @property + def node_status(self): + return self._node_status + + @property + def reason(self): + return self._reason + + +def is_same_action(action1: DiagnosisAction, action2: DiagnosisAction) -> bool: + return False + + +class DiagnosisActionQueue: + def __init__(self): + self._actions: Dict[int, Queue[DiagnosisAction]] = {} + self._lock = threading.Lock() + + def add_action(self, new_action: DiagnosisAction): + with self._lock: + instance = new_action.instance + if instance not in self._actions: + self._actions[instance] = Queue(maxsize=10) + ins_actions = self._actions[instance] + try: + ins_actions.put(new_action, timeout=3) + logger.info(f"New diagnosis action {new_action}") + except queue.Full: + logger.warning( + f"Diagnosis actions for {instance} is full, " + f"skip action: {new_action}." 
+ ) + + def next_action( + self, + instance=DiagnosisConstant.LOCAL_INSTANCE, + ) -> DiagnosisAction: + with self._lock: + while True: + if ( + instance not in self._actions + or self._actions[instance].empty() + ): + return DiagnosisAction() + action = self._actions[instance].get() + if not action.is_expired(): + return action + else: + logger.info(f"Skip expired diagnosis action: {action}.") diff --git a/dlrover/python/diagnosis/inferencechain/coordinator.py b/dlrover/python/diagnosis/inferencechain/coordinator.py index 07cb70326..f092fa95a 100644 --- a/dlrover/python/diagnosis/inferencechain/coordinator.py +++ b/dlrover/python/diagnosis/inferencechain/coordinator.py @@ -13,9 +13,9 @@ from typing import List -from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction +from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction from dlrover.python.diagnosis.common.inference_chain import Inference -def coordinate_inferences(observations: List[Inference]) -> DiagnoseAction: - return DiagnoseAction() +def coordinate_inferences(observations: List[Inference]) -> DiagnosisAction: + return DiagnosisAction() diff --git a/dlrover/python/common/worker.py b/dlrover/python/elastic_agent/context.py similarity index 93% rename from dlrover/python/common/worker.py rename to dlrover/python/elastic_agent/context.py index c09c4fdaa..bbfd1f4db 100644 --- a/dlrover/python/common/worker.py +++ b/dlrover/python/elastic_agent/context.py @@ -11,10 +11,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from torch.distributed.elastic.agent.server.api import RunResult, WorkerSpec +from torch.distributed.elastic.agent.server import RunResult, WorkerSpec -class WorkerContext: +class AgentContext(object): def __init__( self, worker_spec: WorkerSpec, diff --git a/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py b/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py index 7b1619829..b03feda42 100644 --- a/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py +++ b/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py @@ -23,13 +23,12 @@ from dlrover.python.common.error import ProcessError from dlrover.python.common.log import default_logger as logger from dlrover.python.common.singleton import Singleton -from dlrover.python.common.worker import WorkerContext from dlrover.python.diagnosis.common.constants import ( - DiagnosisAction, + DiagnosisActionType, DiagnosisConstant, InferenceConfigKey, ) -from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction +from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric from dlrover.python.diagnosis.common.inference_chain import ( Inference, @@ -50,6 +49,7 @@ get_worker_diagnosis_operators, get_worker_observe_operators, ) +from dlrover.python.elastic_agent.context import AgentContext from dlrover.python.elastic_agent.master_client import MasterClient @@ -105,7 +105,7 @@ def _observe(self) -> List[Inference]: def _diagnose_observations( self, observations: List[Inference] - ) -> DiagnoseAction: + ) -> DiagnosisAction: conclusions: List[Inference] = [] for ob in observations: ic = InferenceChain([ob], self._diagnosis_operators) @@ -133,9 +133,9 @@ def _periodically_diagnosis(self): DiagnosisConstant.AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS ) - def diagnose_training_failure(self, worker_context: WorkerContext) -> str: + def diagnose_training_failure(self, 
agent_context: AgentContext) -> str: self._report_failure_to_master( - worker_context.run_result.failures, worker_context.restart_count + agent_context.run_result.failures, agent_context.restart_count ) # check if the node is failed inference = Inference( @@ -156,25 +156,25 @@ def diagnose_training_failure(self, worker_context: WorkerContext) -> str: ) failure_node = is_inference_included(infer_results, failure_inf) - if worker_context.remaining_failovers > 0 and not failure_node: + if agent_context.remaining_failovers > 0 and not failure_node: logger.info( - f"[{worker_context.worker_spec.role}] Worker group " - f"{worker_context.run_result.state.name}, " + f"[{agent_context.worker_spec.role}] Worker group " + f"{agent_context.run_result.state.name}, " f"is failure node: {failure_node}," - f"{worker_context.remaining_failovers}/" - f"{worker_context.worker_spec.max_restarts} " + f"{agent_context.remaining_failovers}/" + f"{agent_context.worker_spec.max_restarts} " f"attempts left; will restart worker group." ) - return DiagnosisAction.RESTART_WORKER + return DiagnosisActionType.RESTART_WORKER else: logger.info( - f"[{worker_context.worker_spec.role}] Worker group " - f"{worker_context.run_result.state.name}, " + f"[{agent_context.worker_spec.role}] Worker group " + f"{agent_context.run_result.state.name}, " f"is failure node: {failure_node}, " - f"no attempts({worker_context.worker_spec.max_restarts}) " + f"no attempts({agent_context.worker_spec.max_restarts}) " "left; will relaunch." ) - return DiagnosisAction.RELAUNCH_WORKER + return DiagnosisActionType.RELAUNCH_WORKER def _report_failure_to_master( self, failures: Dict[int, ProcessFailure], restart_count: int diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index c59c5593f..0cd03fe63 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -11,12 +11,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
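A minimal usage sketch of the DiagnosisAction classes and DiagnosisActionQueue added above, assuming only the module paths introduced in this patch (illustrative values only):

    # Serialize a base DiagnosisAction and buffer an EventAction per instance.
    from dlrover.python.diagnosis.common.constants import (
        DiagnosisActionType,
        DiagnosisConstant,
    )
    from dlrover.python.diagnosis.common.diagnosis_action import (
        DiagnosisAction,
        DiagnosisActionQueue,
        EventAction,
    )

    # The base class round-trips through JSON; report_heart_beat below uses the
    # same from_json() path to rebuild the action carried in HeartbeatResponse.
    action = DiagnosisAction(
        action_type=DiagnosisActionType.MASTER_RELAUNCH_WORKER,
        instance=DiagnosisConstant.MASTER,
        expired_time_period=60 * 1000,
    )
    restored = DiagnosisAction.from_json(action.to_json())
    assert restored.action_type == DiagnosisActionType.MASTER_RELAUNCH_WORKER

    # Actions are queued per instance id and consumed in FIFO order.
    queue = DiagnosisActionQueue()
    queue.add_action(
        EventAction(event_type="warn", event_msg="demo", expired_time_period=60 * 1000)
    )
    next_action = queue.next_action(instance=DiagnosisConstant.MASTER)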
+import importlib import os import socket import threading import time from contextlib import closing -from typing import Dict +from typing import Dict, Optional from dlrover.proto import elastic_training_pb2, elastic_training_pb2_grpc from dlrover.python.common import env_utils, grpc @@ -27,6 +28,10 @@ ) from dlrover.python.common.log import default_logger as logger from dlrover.python.common.singleton import Singleton +from dlrover.python.diagnosis.common.diagnosis_action import ( + DiagnosisAction, + NoAction, +) from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData @@ -89,6 +94,10 @@ def __init__(self, master_addr, node_id, node_type, timeout=5): self._worker_local_process_id = int(os.getenv("LOCAL_RANK", 0)) self._ddp_server_port = self.find_free_port() + self._diagnosis_action_module = importlib.import_module( + "dlrover.python.diagnosis.common.diagnosis_action" + ) + def __del__(self): if self._channel: self._channel.close() @@ -235,9 +244,22 @@ def report_global_step( ) return self._report(message) - def report_heart_beat(self, timestamp): + def report_heart_beat(self, timestamp) -> DiagnosisAction: message = grpc.HeartBeat(timestamp=timestamp) - return self._report(message) + response: grpc.HeartbeatResponse = self._get(message) + action = NoAction() + action_cls: Optional[DiagnosisData] = getattr( + self._diagnosis_action_module, + response.action.action_cls, + ) + if action_cls is None: + logger.warning( + "Invalid diagnosis action " + f"action type: {response.action.action_cls}" + ) + else: + action = action_cls.from_json(response.action.action_content) + return action def get_cluster_version(self, version_type, task_type, task_id): request = grpc.ClusterVersionRequest( diff --git a/dlrover/python/elastic_agent/monitor/training.py b/dlrover/python/elastic_agent/monitor/training.py index d60396639..975aa2adb 100644 --- a/dlrover/python/elastic_agent/monitor/training.py +++ b/dlrover/python/elastic_agent/monitor/training.py @@ -126,7 +126,9 @@ def report_resource_with_step(self): def send_heartbeat(self): try: ts = int(time.time()) - self._master_client.report_heart_beat(ts) + action = self._master_client.report_heart_beat(ts) + if action: + pass except Exception: logger.warning("Fail to report a heartbeat.") diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index 656e0166a..a45e8b928 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -87,11 +87,11 @@ find_free_port_in_set, ) from dlrover.python.common.log import default_logger as logger -from dlrover.python.common.worker import WorkerContext -from dlrover.python.diagnosis.common.constants import DiagnosisAction +from dlrover.python.diagnosis.common.constants import DiagnosisActionType from dlrover.python.elastic_agent.config.paral_config_tuner import ( ParalConfigTuner, ) +from dlrover.python.elastic_agent.context import AgentContext from dlrover.python.elastic_agent.diagnosis.diagnosis_agent import ( DiagnosisAgent, ) @@ -876,7 +876,7 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: logger.error(f"The worker fails with {run_result.failures}") self._save_ckpt_to_storage() - worker_context = WorkerContext( + context = AgentContext( worker_spec=self._worker_group.spec, remaining_failovers=self._remaining_failovers, restart_count=self._restart_count, @@ -884,14 +884,14 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: ) try: action = 
self._diagnose_agent.diagnose_training_failure( - worker_context + context ) except Exception as e: logger.warning(f"Failed to diagnose errors: {e}") if self._remaining_failovers > 0: - action = DiagnosisAction.RESTART_WORKER + action = DiagnosisActionType.RESTART_WORKER else: - action = DiagnosisAction.RELAUNCH_WORKER + action = DiagnosisActionType.RELAUNCH_WORKER self._process_diagnose_action(action) if self._worker_group.state == WorkerState.FAILED: return run_result @@ -904,10 +904,10 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: raise Exception(f"[{role}] worker group in {state.name} state") def _process_diagnose_action(self, action: str): - if action == DiagnosisAction.RESTART_WORKER: + if action == DiagnosisActionType.RESTART_WORKER: self._remaining_failovers -= 1 self._restart_workers(self._worker_group) - elif action == DiagnosisAction.RELAUNCH_WORKER: + elif action == DiagnosisActionType.RELAUNCH_WORKER: self._stop_workers(self._worker_group) self._worker_group.state = WorkerState.FAILED diff --git a/dlrover/python/master/elastic_training/rdzv_manager.py b/dlrover/python/master/elastic_training/rdzv_manager.py index db3ef9afd..d4c870a21 100644 --- a/dlrover/python/master/elastic_training/rdzv_manager.py +++ b/dlrover/python/master/elastic_training/rdzv_manager.py @@ -287,8 +287,8 @@ def join_rendezvous( psw=psw, ) logger.info( - f"Worker node with id: {meta.node_id} " - f"and rank: {meta.node_rank} " + f"Worker node with id: {meta.node_id}, " + f"rank: {meta.node_rank} and ip: {meta.node_ip} " f"joining rendezvous for round: {self._rdzv_round}." ) self._waiting_nodes[node_rank] = meta diff --git a/dlrover/python/master/node/dist_job_manager.py b/dlrover/python/master/node/dist_job_manager.py index 848a9706e..164cd4b04 100644 --- a/dlrover/python/master/node/dist_job_manager.py +++ b/dlrover/python/master/node/dist_job_manager.py @@ -36,6 +36,10 @@ from dlrover.python.common.grpc import ParallelConfig from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node, NodeGroupResource +from dlrover.python.diagnosis.common.diagnosis_action import ( + DiagnosisAction, + NoAction, +) from dlrover.python.master.monitor.error_monitor import K8sJobErrorMonitor from dlrover.python.master.node.event_callback import ( ClusterContext, @@ -210,8 +214,8 @@ def start(self): target=self._monitor_nodes, name="node_monitor", daemon=True ).start() threading.Thread( - target=self._monitor_node_heart_beat, - name="node_heart_beat_monitor", + target=self._diagnose_job, + name="diagnose_job", daemon=True, ).start() if os.getenv("KUBERNETES_SERVICE_HOST"): @@ -340,32 +344,25 @@ def _create_initial_scale_plan(self): def _init_training_node_manager(self): self._ps_manager = ParameterServerManager( - self._job_nodes.get(NodeType.PS, {}), self._job_resource, self._ps_relaunch_max_num, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - chief_nodes = self._job_nodes.get(NodeType.CHIEF, {}) - if not chief_nodes: - chief_nodes = self._job_nodes.get(NodeType.MASTER, {}) self._chief_manager = ChiefManager( - chief_nodes, self._job_resource, self._relaunch_on_worker_failure, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) self._worker_manager = WorkerManager( - self._job_nodes.get(NodeType.WORKER, {}), self._job_resource, self._relaunch_on_worker_failure, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) self._evaluator_manager = EvaluatorManager( - 
self._job_nodes.get(NodeType.EVALUATOR, {}), self._job_resource, self._relaunch_on_worker_failure, self._elastic_job.get_node_service_addr, @@ -376,7 +373,7 @@ def add_node_event_callback(self, node_event_callback): self._node_event_callbacks.append(node_event_callback) def _init_nodes(self): - self._job_nodes = self._job_resource.init_job_node_meta( + job_nodes = self._job_resource.init_job_node_meta( self._relaunch_on_worker_failure, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, @@ -390,28 +387,23 @@ def _init_nodes(self): self._pending_relaunch_count = 0 set_critical_node( - self._job_nodes, + job_nodes, self._ps_is_critical, self._ps_relaunch_max_num, self._critical_worker_index, ) - update_nodes_priority(self._job_nodes) - - self._ps_manager.update_nodes(self._job_nodes.get(NodeType.PS, {})) - chief_nodes = self._job_nodes.get(NodeType.CHIEF, {}) - if not chief_nodes: - chief_nodes = self._job_nodes.get(NodeType.MASTER, {}) - self._chief_manager.update_nodes(chief_nodes) - workers = self._job_nodes.get(NodeType.WORKER, {}) - self._worker_manager.update_nodes(workers) - evaluators = self._job_nodes.get(NodeType.EVALUATOR, {}) - self._evaluator_manager.update_nodes(evaluators) + update_nodes_priority(job_nodes) + self._job_context.update_job_nodes(job_nodes) + + self._ps_manager.update_nodes() + self._chief_manager.update_nodes() + self._worker_manager.update_nodes() + self._evaluator_manager.update_nodes() def _init_job_auto_scaler(self): self._job_autoscaler: JobAutoScaler = new_job_auto_scaler( self._job_args.distribution_strategy, self._job_resource, - self._job_nodes, self._job_optimizer, self._speed_monitor, self._ps_manager, @@ -465,11 +457,34 @@ def _monitor_node_heart_beat(self): logger.warning(detail_trace_back) time.sleep(15) + def _diagnose_job(self): + logger.info("Start diagnosing the job.") + while True: + if self._stopped: + logger.info("Stop diagnosing job.") + break + with self._lock: + try: + events = self._get_dead_node_event() + except Exception as e: + logger.warning(e) + events = [] + for event in events: + try: + self._process_event(event) + except Exception as e: + logger.warning(e) + detail_trace_back = traceback.format_exc() + logger.warning(detail_trace_back) + self._process_diagnosis_action(self._job_context.next_action()) + time.sleep(15) + def _get_dead_node_event(self, window_interval=900) -> List[NodeEvent]: now = time.time() dead_events: List[NodeEvent] = [] - logger.debug(f"Current job nodes are: {self._job_nodes}.") - for _, nodes in self._job_nodes.items(): + job_nodes = self.get_job_nodes() + logger.debug(f"Current job nodes are: {job_nodes}.") + for _, nodes in job_nodes.items(): for _, node in nodes.items(): if ( node.heartbeat_time > 0 @@ -519,7 +534,8 @@ def _get_dead_node_event(self, window_interval=900) -> List[NodeEvent]: def _get_nodes_time_info(self): result = {} - for _, nodes in self._job_nodes.items(): + job_nodes = self.get_job_nodes() + for _, nodes in job_nodes.items(): for _, node in nodes.items(): if node.heartbeat_time == 0: heartbeat_time = 0 @@ -566,7 +582,8 @@ def _process_list_nodes(self, nodes: List[Node]): logger.debug(f"Got list nodes: {nodes}") exist_nodes: Dict[str, List[int]] = {} - for node_type in self._job_nodes.keys(): + job_nodes = self.get_job_nodes() + for node_type in job_nodes.keys(): exist_nodes[node_type] = [] if nodes: @@ -577,7 +594,7 @@ def _process_list_nodes(self, nodes: List[Node]): # for nodes not in current 'job_nodes' obj, re add it if ( - node_id not in 
self._job_nodes[node_type] + node_id not in job_nodes[node_type] and node.status != NodeStatus.DELETED ): logger.info( @@ -585,7 +602,7 @@ def _process_list_nodes(self, nodes: List[Node]): " is re added without the event" ) new_node = copy.deepcopy(node) - self._job_nodes[node_type][node_id] = new_node + self._job_context.update_job_node(new_node) if node.status == NodeStatus.DELETED: event_type = NodeEventType.DELETED @@ -594,10 +611,11 @@ def _process_list_nodes(self, nodes: List[Node]): # Mock event to avoid missing events event = NodeEvent(event_type, node) self._process_event(event) + logger.debug(f"Got list nodes: {exist_nodes}") - for node_type in self._job_nodes.keys(): + for node_type in job_nodes.keys(): # Avoid dictionary keys changed during iteration - type_nodes = list(self._job_nodes[node_type].values()) + type_nodes = list(job_nodes[node_type].values()) for node in type_nodes: if ( node.status != NodeStatus.INITIAL @@ -632,10 +650,14 @@ def _get_pod_unique_labels(self, node: Node): ElasticJobLabel.RANK_INDEX_KEY: node.rank_index, } + def _process_diagnosis_action(self, action: DiagnosisAction): + pass + def _process_event(self, event: NodeEvent): node_type = event.node.type node_status = event.node.status node_id = event.node.id + job_nodes = self.get_job_nodes() # Skip deleted event of pod if the cluster has relaunched a new pod # with the same type and rank as the deleted pod. @@ -666,11 +688,11 @@ def _process_event(self, event: NodeEvent): ) return - if node_id not in self._job_nodes[node_type]: + if node_id not in job_nodes[node_type]: logger.info(f"The node {event.node.name} is released.") return else: - cur_node = self._job_nodes[node_type][node_id] + cur_node = job_nodes[node_type][node_id] logger.debug( f"Update node({cur_node.id}), " f"name: {cur_node.name}->{event.node.name}, " @@ -696,6 +718,7 @@ def _process_event(self, event: NodeEvent): restart_training=event.node.restart_training, relaunch_count=event.node.relaunch_count, ) + self._job_context.update_job_node(cur_node) # For the given node id, check whether it meets # the state change condition @@ -726,6 +749,8 @@ def _process_event(self, event: NodeEvent): cur_node.update_status(new_status) new_status = status_change_flow.to_status cur_node.set_exit_reason(event.node.exit_reason) + self._job_context.update_job_node(cur_node) + self._process_node_events(status_change_flow, cur_node) should_relaunch = self._should_relaunch( @@ -755,7 +780,6 @@ def _process_event(self, event: NodeEvent): "exit reason": cur_node.exit_reason, }, ) - if should_relaunch: self._relaunch_node(cur_node) @@ -885,27 +909,32 @@ def _relaunch_node(self, node: Node): self._set_ps_addrs_in_plan(plan) if self._remove_exited_node: plan.remove_nodes.append(node) - node.relaunchable = False # Avoid repeatedly relaunching the node. + # Avoid repeatedly relaunching the node. 
+ node.relaunchable = False + self._job_context.update_job_node(node) self._scaler.scale(plan) def clear_exited_nodes(self): if not self._remove_exited_node: return + job_nodes = self.get_job_nodes() scale_plan = ScalePlan() with self._lock: - for _, nodes in self._job_nodes.items(): + for _, nodes in job_nodes.items(): for _, node in nodes.items(): if not node.is_released and node.exited(): scale_plan.remove_nodes.append(node) node.is_released = True + self._job_context.update_job_node(node) if len(scale_plan.remove_nodes) > 0: logger.info(f"Remove exited nodes {scale_plan.remove_nodes}") self._scaler.scale(scale_plan) def clear_all_nodes(self): scale_plan = ScalePlan() + job_nodes = self.get_job_nodes() with self._lock: - for _, nodes in self._job_nodes.items(): + for _, nodes in job_nodes.items(): for _, node in nodes.items(): if not node.is_released: scale_plan.remove_nodes.append(node) @@ -936,7 +965,8 @@ def all_workers_deleted(self): def all_critical_node_completed(self): alive_critical_nodes = [] - for _, nodes in self._job_nodes.items(): + job_nodes = self.get_job_nodes() + for _, nodes in job_nodes.items(): for node in nodes.values(): if node.critical and node.status in [ NodeStatus.INITIAL, @@ -951,7 +981,8 @@ def all_critical_node_completed(self): return completed def remove_worker(self, worker_id): - if self._job_nodes[NodeType.WORKER][worker_id].critical: + job_nodes = self.get_job_nodes() + if job_nodes[NodeType.WORKER][worker_id].critical: logger.info("Skip the critical worker %s", worker_id) else: logger.info("Delete worker %s", worker_id) @@ -981,37 +1012,30 @@ def post_ps_ready(self): def stop(self): self._enable_relaunch_node = False + job_nodes = self.get_job_nodes() with self._lock: - for node_type in self._job_nodes.keys(): - for node in self._job_nodes[node_type].values(): + for node_type in job_nodes.keys(): + for node in job_nodes[node_type].values(): node.critical = False node.is_released = True node.relaunchable = False - for node in self._job_nodes[NodeType.WORKER].values(): + self._job_context.update_job_node(node) + for node in job_nodes[NodeType.WORKER].values(): node.eval_time = self._speed_monitor.get_worker_eval_time( node.id ) + self._job_context.update_job_node(node) self._stopped = True def update_node_resource_usage( self, node_type, node_id, cpu, memory, gpu_stats=[] ): - if not self._job_nodes: - logger.warning( - "Skip updating node resource usage for job_nodes " - "hasn't been initialized." - ) - return - if ( - node_type not in self._job_nodes - or node_id not in self._job_nodes[node_type] - ): + node = self._job_context.job_node(node_type, node_id) + if node is None: logger.warning( - "Skip updating node resource usage for node " - f"{node_type}-{node_id} can not be found." 
+ f"Skip update node[{node_type}][{node_id}] resources" ) return - node = self._job_nodes[node_type][node_id] node.update_resource_usage(cpu, memory, gpu_stats) cpu_percent = node.used_resource.cpu / node.config_resource.cpu if cpu_percent < _dlrover_context.hang_cpu_usage_percentage: @@ -1022,17 +1046,21 @@ def update_node_resource_usage( if node.start_hang_time > 0: now = datetime.now() node.start_hang_time = 0 + self._job_context.update_job_node(node) def update_node_service_addr(self, node_type, node_id, service_addr): - node = self._job_nodes[node_type][node_id] + node = self._job_context.job_node(node_type, node_id) + if node is None: + logger.error(f"no Node[{node_type}][{node_id}] found") + return node.update_service_address(service_addr) node.status = NodeStatus.RUNNING node.is_released = False - self._job_nodes[node_type][node_id] = node + self._job_context.update_job_node(node) def get_cur_cluster_ps(self): """Get PS nodes in the current training cluster.""" - logger.info("job nodes are {}".format(self._job_nodes)) + logger.info("job nodes are {}".format(self.get_job_nodes())) return self._ps_manager.get_training_ps_cluster() def get_next_cluster_ps(self): @@ -1051,9 +1079,10 @@ def remove_training_nodes(self): """Remove all PS and workers""" self._job_autoscaler.stop_auto_scaling() plan = ScalePlan() - training_nodes = list( - self._job_nodes[NodeType.WORKER].values() - ) + list(self._job_nodes[NodeType.PS].values()) + job_nodes = self.get_job_nodes() + training_nodes = list(job_nodes[NodeType.WORKER].values()) + list( + job_nodes[NodeType.PS].values() + ) for node in training_nodes: if ( node.status in [NodeStatus.RUNNING, NodeStatus.PENDING] @@ -1064,6 +1093,7 @@ def remove_training_nodes(self): node.is_released = True node.status = NodeStatus.DELETED logger.info("Remove node %s", node.name) + self._job_context.update_job_node(node) plan.remove_nodes.append(node) self._scaler.scale(plan) @@ -1089,9 +1119,13 @@ def _report_event( ) def _process_error( - self, node: Node, restart_count: int, error_data: str, level: str + self, + node: Optional[Node], + restart_count: int, + error_data: str, + level: str, ) -> bool: - if self._error_monitor: + if self._error_monitor and node is not None: return self._error_monitor.process_error( node, restart_count, error_data, level ) @@ -1125,7 +1159,7 @@ def handle_training_failure( self, node_type, node_id, restart_count=-1, error_data="", level="" ): """Process the training failure reported by the node.""" - node = self._job_nodes[node_type][node_id] + node = self._job_context.job_node(node_type, node_id) if node.is_released: logger.info(f"The node {node.name} has been released.") return @@ -1144,28 +1178,33 @@ def get_opt_strategy(self) -> ParallelConfig: return strategy def update_node_paral_config(self, node_type, node_id, paral_config): - node = self._job_nodes[node_type][node_id] + node = self._job_context.job_node(node_type, node_id) + if node is None: + logger.warning(f"not found Node[{node_type}][{node_id}]") + return node.update_paral_config(paral_config) + self._job_context.update_job_node(node) def verify_restarting_worker_training(self, node_type, node_id): if node_type != NodeType.WORKER: return False return self._worker_manager.verify_restarting_training(node_id) - def collect_node_heart_beat(self, node_type, node_id, timestamp): + def collect_node_heart_beat( + self, node_type, node_id, timestamp + ) -> DiagnosisAction: with self._lock: - if ( - node_type not in self._job_nodes - or node_id not in self._job_nodes[node_type] - 
): - return - node = self._job_nodes[node_type][node_id] + node = self._job_context.job_node(node_type, node_id) + if node is None: + return NoAction() if node.heartbeat_time == 0: logger.info( f"Start receiving heartbeat from node {node_id}" f"-{node.name}" ) node.heartbeat_time = timestamp + self._job_context.update_job_node(node) + return self._job_context.next_action(instance=node_id) def update_node_required_info_callback(self): self._worker_manager.update_node_required_info(self._nodes_required) @@ -1184,20 +1223,18 @@ def process_reported_node_event(self, node_event: NodeEvent): node_id = node.id with self._lock: - if ( - node_type in self._job_nodes - and node_id in self._job_nodes[node_type] - ): + target_node = self._job_context.job_node(node_type, node_id) + if target_node: logger.info( f"Node {node_id}({node_type}) reported " f"status to {event_type}." ) if event_type == NodeEventType.SUCCEEDED: - self._job_nodes[node_type][node_id].set_as_succeeded() + target_node.set_as_succeeded() elif node_event.is_node_check_event(): - self._job_nodes[node_type][ - node_id - ].update_node_check_result(event_type) + target_node.update_node_check_result(event_type) + + self._job_context.update_job_node(target_node) def create_job_manager(args: JobArgs, speed_monitor) -> DistributedJobManager: diff --git a/dlrover/python/master/node/job_auto_scaler.py b/dlrover/python/master/node/job_auto_scaler.py index 1a5f04b55..28b2aa7bd 100644 --- a/dlrover/python/master/node/job_auto_scaler.py +++ b/dlrover/python/master/node/job_auto_scaler.py @@ -23,8 +23,9 @@ ) from dlrover.python.common.global_context import Context from dlrover.python.common.log import default_logger as logger -from dlrover.python.common.node import Node, NodeResource +from dlrover.python.common.node import NodeResource from dlrover.python.master.monitor.speed_monitor import SpeedMonitor +from dlrover.python.master.node.job_context import get_job_context from dlrover.python.master.node.ps import ParameterServerManager from dlrover.python.master.node.worker import WorkerManager from dlrover.python.master.resource.job import ( @@ -40,7 +41,6 @@ def new_job_auto_scaler( job_strategy, job_resource: JobResource, - job_nodes: Dict[str, Dict[int, Node]], job_optimizer: JobResourceOptimizer, speed_monitor: SpeedMonitor, ps_manager: ParameterServerManager, @@ -50,7 +50,6 @@ def new_job_auto_scaler( if job_strategy == DistributionStrategy.PS: return PSTrainingAutoScaler( job_resource, - job_nodes, job_optimizer, speed_monitor, ps_manager, @@ -60,7 +59,6 @@ def new_job_auto_scaler( elif job_strategy == DistributionStrategy.ALLREDUCE: return AllreduceTrainingAutoScaler( job_resource, - job_nodes, job_optimizer, speed_monitor, worker_manager, @@ -76,18 +74,17 @@ class JobAutoScaler(metaclass=ABCMeta): def __init__( self, job_resource: JobResource, - job_nodes: Dict[str, Dict[int, Node]], job_optimizer: JobResourceOptimizer, speed_monitor: SpeedMonitor, node_scaler: Scaler, scale_interval: int, ): self._job_resource = job_resource - self._job_nodes = job_nodes self._job_optimizer = job_optimizer self._speed_monitor = speed_monitor self._scaler = node_scaler self._scale_interval = scale_interval + self._job_context = get_job_context() self._suggested_stop = False self._autoscaling_started = False @@ -95,6 +92,11 @@ def __init__( def suggested_stop(self): return self._suggested_stop + def get_job_nodes(self, node_type=""): + if node_type == "": + return self._job_context.job_nodes() + return self._job_context.job_nodes_by_type(node_type) + 
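For reference, a minimal sketch of the read-copy / write-back pattern the node managers now follow against the shared JobContext singleton (the worker node here is hypothetical; the API is the one defined in job_context.py below):

    from dlrover.python.common.constants import NodeStatus, NodeType
    from dlrover.python.common.node import Node
    from dlrover.python.master.node.job_context import get_job_context

    job_context = get_job_context()
    # Register a (hypothetical) worker node in the shared context.
    job_context.update_job_node(
        Node(
            name="worker-0",
            node_type=NodeType.WORKER,
            node_id=0,
            status=NodeStatus.RUNNING,
        )
    )

    # job_node() hands back a deep copy, so mutations must be written back.
    node = job_context.job_node(NodeType.WORKER, 0)
    if node is not None:
        node.relaunchable = False
        job_context.update_job_node(node)

    workers = job_context.job_nodes_by_type(NodeType.WORKER)  # copied dict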
@abstractmethod def start_auto_scaling(self): """Start auto-scaling nodes of a job""" @@ -118,7 +120,6 @@ class PSTrainingAutoScaler(JobAutoScaler): def __init__( self, job_resource: JobResource, - job_nodes: Dict[str, Dict[int, Node]], job_optimizer: JobResourceOptimizer, speed_monitor: SpeedMonitor, ps_manager: ParameterServerManager, @@ -127,7 +128,6 @@ def __init__( ) -> None: super().__init__( job_resource, - job_nodes, job_optimizer, speed_monitor, node_scaler, @@ -224,7 +224,8 @@ def execute_job_optimization_plan(self, plan: ResourcePlan): scale_plan.merge(ps_plan) self._speed_monitor.reset_running_speed_monitor() elif node_type == NodeType.WORKER: - chief_num = len(self._job_nodes.get(NodeType.CHIEF, [])) + chief_nodes = self.get_job_nodes(NodeType.CHIEF) + chief_num = len(chief_nodes) worker_num = chief_num + group.count self._speed_monitor.set_target_worker_num(worker_num) worker_plan = self._worker_manager.adjust_worker(group) @@ -260,7 +261,7 @@ def _migrate_nodes(self, node_resources: Dict[str, NodeResource]): return scale_plan def _reduce_timeout_pending_node_resource(self): - """Cut down CPU cores of pending pod at the job starts""" + """Cut down CPU cores of pending pod when job starts""" scale_plan = ScalePlan() plan = self._ps_manager.reduce_pending_node_resource() scale_plan.merge(plan) @@ -278,7 +279,6 @@ class AllreduceTrainingAutoScaler(JobAutoScaler): def __init__( self, job_resource: JobResource, - job_nodes: Dict[str, Dict[int, Node]], job_optimizer: JobResourceOptimizer, speed_monitor: SpeedMonitor, worker_manager: WorkerManager, @@ -286,14 +286,12 @@ def __init__( ) -> None: super().__init__( job_resource, - job_nodes, job_optimizer, speed_monitor, node_scaler, 1800, ) self._worker_manager = worker_manager - self._workers = job_nodes[NodeType.WORKER] def start_auto_scaling(self): """Start auto-scaling nodes of a job""" @@ -334,7 +332,8 @@ def _periodic_adjust_worker(self): def _get_alive_worker_num(self): worker_num = 0 - for _, worker in self._workers.items(): + workers = self.get_job_nodes(NodeType.WORKER) + for _, worker in workers.items(): if worker.status in [ NodeStatus.RUNNING, NodeStatus.PENDING, diff --git a/dlrover/python/master/node/job_context.py b/dlrover/python/master/node/job_context.py new file mode 100644 index 000000000..b6110a156 --- /dev/null +++ b/dlrover/python/master/node/job_context.py @@ -0,0 +1,114 @@ +# Copyright 2024 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import threading +from typing import Dict, Optional + +from dlrover.python.common.constants import NodeType +from dlrover.python.common.node import Node +from dlrover.python.common.singleton import Singleton +from dlrover.python.diagnosis.common.constants import DiagnosisConstant +from dlrover.python.diagnosis.common.diagnosis_action import ( + DiagnosisActionQueue, +) + + +class JobContext(Singleton): + """ + JobContext includes critical states of the training job that + will be shared across multiple components. 
+ """ + + def __init__(self): + self._action_queue = DiagnosisActionQueue() + self._job_nodes: Dict[str, Dict[int, Node]] = {} + self._locker = threading.Lock() + + def enqueue_actions(self, actions): + for action in actions: + self.enqueue_action(action) + + def enqueue_action(self, action): + self._action_queue.add_action(action) + + def next_action( + self, + instance=DiagnosisConstant.LOCAL_INSTANCE, + ): + return self._action_queue.next_action(instance=instance) + + @property + def ps_nodes(self) -> Dict[int, Node]: + with self._locker: + if NodeType.PS in self._job_nodes: + return self._job_nodes[NodeType.PS] + return {} + + @property + def workers(self) -> Dict[int, Node]: + with self._locker: + if NodeType.WORKER in self._job_nodes: + return self._job_nodes[NodeType.WORKER] + return {} + + def job_nodes(self) -> Dict[str, Dict[int, Node]]: + """ + return a copy of job nodes + """ + with self._locker: + return copy.deepcopy(self._job_nodes) + + def job_node(self, node_type: str, node_id: int) -> Optional[Node]: + with self._locker: + node_type = self._preprocess(node_type) + if ( + node_type not in self._job_nodes + or node_id not in self._job_nodes[node_type] + ): + return None + return copy.deepcopy(self._job_nodes[node_type][node_id]) + + def job_nodes_by_type(self, node_type: str) -> Dict[int, Node]: + with self._locker: + node_type = self._preprocess(node_type) + if node_type not in self._job_nodes: + return {} + return copy.deepcopy(self._job_nodes[node_type]) + + def _preprocess(self, node_type: str) -> str: + if node_type == NodeType.CHIEF and node_type not in self._job_nodes: + return NodeType.MASTER + return node_type + + def update_job_nodes(self, job_nodes: Dict[str, Dict[int, Node]]): + with self._locker: + self._job_nodes = copy.deepcopy(job_nodes) + + def update_job_node(self, node: Node): + with self._locker: + if self._job_nodes is None: + self._job_nodes = {} + if node.type not in self._job_nodes: + self._job_nodes[node.type] = {} + + self._job_nodes[node.type][node.id] = copy.deepcopy(node) + + def clear_job_nodes(self): + with self._locker: + self._job_nodes = {} + + +def get_job_context() -> JobContext: + job_context = JobContext.singleton_instance() + return job_context diff --git a/dlrover/python/master/node/job_manager.py b/dlrover/python/master/node/job_manager.py index c0124b8b0..a03f991ba 100644 --- a/dlrover/python/master/node/job_manager.py +++ b/dlrover/python/master/node/job_manager.py @@ -15,12 +15,13 @@ from typing import Dict from dlrover.python.common.log import default_logger as logger -from dlrover.python.common.node import Node +from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction from dlrover.python.master.hyperparams.simple_strategy_generator import ( SimpleStrategyGenerator, ) from dlrover.python.master.monitor.error_monitor import ErrorMonitor from dlrover.python.master.monitor.speed_monitor import SpeedMonitor +from dlrover.python.master.node.job_context import get_job_context from dlrover.python.master.node.training_node import ( SyncNodeTrainingPorts, TrainingNodeConfig, @@ -53,11 +54,10 @@ def __init__( self._stopped = False self._speed_monitor: SpeedMonitor = speed_monitor self._error_monitor: ErrorMonitor = error_monitor - - self._job_nodes: Dict[str, Dict[int, Node]] = {} self._nodes_required = (0, 0, 0) self._training_node_config = TrainingNodeConfig(external_config) + self._job_context = get_job_context() @abstractmethod def start(self): @@ -196,10 +196,17 @@ def handle_training_failure( pass @abstractmethod 
- def collect_node_heart_beat(self, node_type, node_id, timestamp): + def collect_node_heart_beat( + self, node_type, node_id, timestamp + ) -> DiagnosisAction: """Collect the heart beat message of nodes.""" pass + def get_job_nodes(self, node_type=""): + if node_type == "": + return self._job_context.job_nodes() + return self._job_context.job_nodes_by_type(node_type) + def sync_node_training_port(self, node_id, port) -> SyncNodeTrainingPorts: return self._training_node_config.sync_node_training_port( node_id, port diff --git a/dlrover/python/master/node/local_job_manager.py b/dlrover/python/master/node/local_job_manager.py index d7e82a28b..c0a911b04 100644 --- a/dlrover/python/master/node/local_job_manager.py +++ b/dlrover/python/master/node/local_job_manager.py @@ -14,6 +14,10 @@ from dlrover.python.common.constants import NodeStatus, NodeType from dlrover.python.common.grpc import ParallelConfig from dlrover.python.common.node import Node +from dlrover.python.diagnosis.common.diagnosis_action import ( + DiagnosisAction, + NoAction, +) from dlrover.python.master.monitor.error_monitor import SimpleErrorMonitor from dlrover.python.master.node.job_manager import JobManager from dlrover.python.scheduler.job import JobArgs @@ -36,16 +40,17 @@ def __init__( self._job_resource_optimizer = None def start(self): - self._job_nodes[NodeType.WORKER] = {} + workers = {} worker = self._job_args.node_args[NodeType.WORKER].group_resource self._training_node_config.set_node_num(worker.count) for i in range(worker.count): - self._job_nodes[NodeType.WORKER][i] = Node( + workers[i] = Node( name=NodeType.WORKER + f"-{i}", node_type=NodeType.WORKER, node_id=i, status=NodeStatus.RUNNING, ) + self._job_context.update_job_node(workers[i]) def should_early_stop(self): return False @@ -56,21 +61,32 @@ def add_node_event_callback(self, node_event_callback): def update_node_resource_usage( self, node_type, node_id, cpu, memory, gpu_stats=[] ): - node = self._job_nodes[node_type][node_id] + node = self._job_context.job_node(node_type, node_id) + if node is None: + return node.update_resource_usage(cpu, memory, gpu_stats) + self._job_context.update_job_node(node) def handle_training_failure( self, node_type, node_id, restart_count=-1, error_data="", level="" ): """Process the training failure reported by the node.""" - node = self._job_nodes[node_type][node_id] + node = self._job_context.job_node(node_type, node_id) + if node is None: + return self._error_monitor.process_error( node, restart_count, error_data, level ) - def collect_node_heart_beat(self, node_type, node_id, timestamp): - node = self._job_nodes[node_type][node_id] + def collect_node_heart_beat( + self, node_type, node_id, timestamp + ) -> DiagnosisAction: + node = self._job_context.job_node(node_type, node_id) + if node is None: + return NoAction() node.heartbeat_time = timestamp + self._job_context.update_job_node(node) + return NoAction() def close_job(self): pass @@ -91,12 +107,11 @@ def remove_worker(self, worker_id): pass def get_running_nodes(self): - nodes = list(self._job_nodes[NodeType.WORKER].values()) - return nodes + nodes = self._job_context.job_nodes_by_type(NodeType.WORKER) + return nodes.values() def get_running_workers(self): - workers = list(self._job_nodes[NodeType.WORKER].values()) - return workers + return self._job_context.job_nodes_by_type(NodeType.WORKER) def post_ps_ready(self): pass @@ -146,8 +161,11 @@ def get_opt_strategy(self) -> ParallelConfig: return strategy def update_node_paral_config(self, node_type, node_id, 
paral_config): - node = self._job_nodes[node_type][node_id] + node = self._job_context.job_node(node_type, node_id) + if node is None: + return node.update_paral_config(paral_config) + self._job_context.update_job_node(node) def create_job_manager(args: JobArgs, speed_monitor) -> LocalJobManager: diff --git a/dlrover/python/master/node/ps.py b/dlrover/python/master/node/ps.py index 880eecc14..3f32ce080 100644 --- a/dlrover/python/master/node/ps.py +++ b/dlrover/python/master/node/ps.py @@ -39,7 +39,6 @@ class ParameterServerManager(TrainingNodeManager): def __init__( self, - ps_nodes: Dict[int, Node], job_resource: JobResource, max_relaunch_num, new_service_fn, @@ -47,8 +46,6 @@ def __init__( ): """ Args: - ps_nodes: A dictionary where the key is the index of PS pod - and the value is the PodInfo instance of PS pod. job_resource: the resource configuration of a job. max_relaunch_num: The maximum relaunch number of PS. new_service_fn: A callable function to generate a server name of @@ -57,7 +54,7 @@ def __init__( PS. """ super(ParameterServerManager, self).__init__( - ps_nodes, new_node_name_fn + NodeType.PS, new_node_name_fn ) self._max_relaunch_num = max_relaunch_num self._job_resource = job_resource @@ -71,8 +68,11 @@ def __init__( self._node_id_iter = itertools.count(self._job_resource.ps_num) self._init_training_ps_cluster() + def _ps_nodes(self): + return self._job_context.ps_nodes + def _init_training_ps_cluster(self): - for node in self._nodes.values(): + for node in self._ps_nodes().values(): alive = node.status in [ NodeStatus.INITIAL, NodeStatus.PENDING, @@ -92,10 +92,11 @@ def relaunch_node(self, node: Node, remove_exited_node=False): with self._lock: node.is_released = True new_id = next(self._node_id_iter) - self._nodes[new_id] = node.get_relaunch_node_info(new_id) + new_node = node.get_relaunch_node_info(new_id) + self._job_context.update_job_node(new_node) if node in self._training_ps_cluster: i = self._training_ps_cluster.index(node) - self._training_ps_cluster[i] = self._nodes[new_id] + self._training_ps_cluster[i] = self._ps_nodes()[new_node.id] logger.info("Relaunch node %s to %s", node.name, new_id) plan.launch_nodes.append( Node( @@ -154,7 +155,7 @@ def _scale_up_ps(self, up_num): critical=True, service_addr=service_addr, ) - self._nodes[ps_id] = ps + self._job_context.update_job_node(ps) new_ps.append(ps) logger.info("Create PS %s", ps) return new_ps @@ -191,6 +192,7 @@ def process_after_ps_cluster_ready(self): node.critical = False node.relaunchable = False node.is_released = True + self._job_context.update_job_node(node) if node.id in self._migrated_ps_nodes: self._migrated_ps_nodes.pop(node.id) plan.remove_nodes.append(node) @@ -199,7 +201,7 @@ def process_after_ps_cluster_ready(self): def _get_alive_ps(self) -> List[Node]: """Get all running PS pods""" alive_ps = [] - for node in self._nodes.values(): + for node in self._ps_nodes().values(): if node.status == NodeStatus.RUNNING and not node.is_released: alive_ps.append(node) return alive_ps @@ -213,7 +215,7 @@ def get_next_training_ps_cluster(self): return self._next_training_ps_cluster all_new_ps_ready = True - for node in self._nodes.values(): + for node in self._ps_nodes().values(): if self._wait_ps_node(node): all_new_ps_ready = False break @@ -234,7 +236,7 @@ def has_ps_failure(self): Check whether there is PS failure and the master does not relaunch the failed PS node. 
""" - for node in self._nodes.values(): + for node in self._ps_nodes().values(): if node.timeout(_dlrover_ctx.seconds_to_wait_failed_ps): return True return False @@ -265,6 +267,7 @@ def _pre_drop_migrated_ps(self, alive_ps: List[Node]): ): if node not in self._pre_dropped_ps: node.migrated = True + self._job_context.update_job_node(node) self._pre_dropped_ps.append(node) def get_total_request_cpu(self): @@ -290,7 +293,8 @@ def get_ready_for_new_ps_cluster(self): def get_ps_addrs(self): """Get the address list of ps services""" ps_addrs = {} - for ps in list(self._nodes.values()): + nodes = self._ps_nodes() + for ps in list(nodes.values()): if ( ps.id not in self._migrated_ps_nodes and not ps.is_released @@ -305,7 +309,8 @@ def get_ps_addrs(self): def delete_running_ps(self): """Delete all running ps pods""" plan = ScalePlan() - for node in list(self._nodes.values()): + nodes = self._ps_nodes() + for node in list(nodes.values()): if ( node.status in [NodeStatus.RUNNING, NodeStatus.PENDING] and not node.is_released @@ -319,6 +324,8 @@ def delete_running_ps(self): ) node.is_released = True node.status = NodeStatus.DELETED + self._job_context.update_job_node(node) + plan.remove_nodes.append(node) return plan @@ -335,9 +342,13 @@ def migrate_parameter_servers(self, ps_nodes: Dict[str, NodeResource]): def _migrate_parameter_server(self, name: str, cpu=0, memory=0): """Migrate the parameter server node into a new pod""" old_ps_id = int(name.split("-")[-1]) - original_pod = self._nodes[old_ps_id] if old_ps_id in self._migrated_ps_nodes: return + nodes = self._ps_nodes() + if old_ps_id not in nodes: + logger.error(f"not found PS-{old_ps_id} in job") + return + original_pod = nodes[old_ps_id] resource = copy.deepcopy(original_pod.config_resource) with self._lock: @@ -360,8 +371,8 @@ def _migrate_parameter_server(self, name: str, cpu=0, memory=0): service_addr=service_addr, name=self._new_node_name_fn(NodeType.PS, new_ps_id), ) - self._nodes[new_ps_id] = new_node - self._migrated_ps_nodes[old_ps_id] = new_node + self._job_context.update_job_node(new_node) + self._migrated_ps_nodes[old_ps_id] = self._ps_nodes()[new_node.id] logger.info("Migrated PS %s to PS %s", old_ps_id, new_ps_id) return new_node @@ -369,9 +380,10 @@ def exist_migrated_ps_nodes(self): return len(self._migrated_ps_nodes) > 0 def is_all_running(self): + nodes = self._job_context.job_nodes_by_type(self._node_type) running_ps = [ pod_info.id - for pod_info in self._nodes.values() + for pod_info in nodes.values() if pod_info.status == NodeStatus.RUNNING ] return len(running_ps) == self._job_resource.ps_num @@ -417,7 +429,7 @@ def is_training_hang_by_pending(self, total_node_num, job_type) -> bool: return False # collect pending and running nodes - cur_nodes = list(self._nodes.values()) + cur_nodes = list(self._ps_nodes().values()) pending_ps: List[Node] = [] running_ps: List[Node] = [] for node in cur_nodes: diff --git a/dlrover/python/master/node/training_node.py b/dlrover/python/master/node/training_node.py index 0cd0c4dca..06a941cb8 100644 --- a/dlrover/python/master/node/training_node.py +++ b/dlrover/python/master/node/training_node.py @@ -33,6 +33,7 @@ from dlrover.python.common.global_context import Context from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node +from dlrover.python.master.node.job_context import get_job_context from dlrover.python.master.scaler.base_scaler import ScalePlan from dlrover.python.scheduler.job import JobArgs @@ -184,19 +185,21 @@ def 
is_all_nodes_pending_judgement(strategy) -> bool: class TrainingNodeManager(object): def __init__( self, - nodes: Dict[int, Node], + node_type: str, new_node_name_fn, ): """ Args: - nodes: training nodes + node_type: node type new_node_name_fn: new node name function """ - self._nodes = nodes + self._job_context = get_job_context() + nodes = self._job_context.job_nodes_by_type(node_type) + self._node_type = node_type self._new_node_name_fn = new_node_name_fn self._lock = threading.Lock() - self._node_id_iter = itertools.count(len(self._nodes)) - self._node_rank_iter = itertools.count(len(self._nodes)) + self._node_id_iter = itertools.count(len(nodes)) + self._node_rank_iter = itertools.count(len(nodes)) self._pending_nodes: List[Node] = [] @property @@ -214,34 +217,38 @@ def first_pending_node(self): @property def cur_nodes(self): - cur_nodes = [node.name for node in self._nodes.values()] + nodes = self._job_context.job_nodes_by_type(self._node_type) + cur_nodes = [node.name for node in nodes.values()] return cur_nodes - def update_nodes(self, nodes): - self._nodes = nodes - self._node_id_iter = itertools.count(len(self._nodes)) - self._node_rank_iter = itertools.count(len(self._nodes)) + def update_nodes(self): + nodes = self._job_context.job_nodes_by_type(self._node_type) + self._node_id_iter = itertools.count(len(nodes)) + self._node_rank_iter = itertools.count(len(nodes)) def remove_node(self, node_id): plan = ScalePlan() - if node_id not in self._nodes: - logger.info("Delete non-existed worker %s", node_id) - return plan - worker = self._nodes[node_id] with self._lock: + worker = self._job_context.job_node(self._node_type, node_id) + if worker is None: + logger.info("Delete non-existed worker %s", node_id) + return plan if worker.status in [NodeStatus.DELETED, NodeStatus.INITIAL]: logger.error("Unknown deletable worker id: %s" % node_id) return worker.is_released = True + self._job_context.update_job_node(worker) plan.remove_nodes.append(worker) return plan def relaunch_node(self, node: Node, remove_exited_node=False): plan = ScalePlan() + nodes = self._job_context.job_nodes_by_type(self._node_type) with self._lock: new_id = next(self._node_id_iter) relaunch_node = node.get_relaunch_node_info(new_id) - self._nodes[new_id] = relaunch_node + nodes[new_id] = relaunch_node + self._job_context.update_job_node(relaunch_node) logger.info("Relaunch node %s to %s", node.name, new_id) plan.launch_nodes.append( Node( @@ -257,6 +264,7 @@ def relaunch_node(self, node: Node, remove_exited_node=False): ) if remove_exited_node and not node.is_released and node.exited(): node.is_released = True + self._job_context.update_job_node(node) plan.remove_nodes.append(node) return plan @@ -265,19 +273,22 @@ def reduce_pending_node_resource(self): plan = ScalePlan() # Avoid dictionary changed size during iteration. 
- cur_nodes = list(self._nodes.values()) + nodes = self._job_context.job_nodes_by_type(self._node_type) + cur_nodes = list(nodes.values()) for node in cur_nodes: if node.status == NodeStatus.PENDING: reduced = reduce_timeout_pending_node_resource(node) if reduced: node.relaunchable = False + self._job_context.update_job_node(node) node_plan = self.relaunch_node(node) plan.remove_nodes.append(node) plan.merge(node_plan) return plan def get_pending_timeout_oom_recovered_node(self): - cur_nodes = list(self._nodes.values()) + nodes = self._job_context.job_nodes_by_type(self._node_type) + cur_nodes = list(nodes.values()) now = time.time() nodes = [] for node in cur_nodes: @@ -301,13 +312,17 @@ def get_running_nodes(self): """TensorFlow Chief nodes""" nodes = [] with self._lock: - for node in self._nodes.values(): + training_nodes = self._job_context.job_nodes_by_type( + self._node_type + ) + for node in training_nodes.values(): if node.status == NodeStatus.RUNNING: nodes.append(node) return nodes def all_nodes_exited(self): - if len(self._nodes) == 0: + nodes = self._job_context.job_nodes_by_type(self._node_type) + if len(nodes) == 0: return True counter = self._get_node_counter() @@ -320,7 +335,7 @@ def all_nodes_exited(self): running_workers = [] pending_high_workers = [] pending_low_workers = [] - for worker_id, worker in self._nodes.items(): + for worker_id, worker in nodes.items(): if worker.is_released: continue if worker.config_resource.priority == PriorityClass.LOW: @@ -372,7 +387,8 @@ def all_nodes_failed(self): def running_nodes_hanged(self) -> List[bool]: cur_time = time.time() node_hang = [] - nodes = list(self._nodes.values()) # Avoid dictionary changed size. + nodes_dict = self._job_context.job_nodes_by_type(self._node_type) + nodes = list(nodes_dict.values()) # Avoid dictionary changed size. for node in nodes: if node.status == NodeStatus.RUNNING: timeout = NodeResourceLimit.MAX_HANG_TIMEOUT_SECS @@ -388,12 +404,14 @@ def running_nodes_hanged(self) -> List[bool]: f"{timeout} from {date_time}!!!" ) node.hang = hang + self._job_context.update_job_node(node) node_hang.append(hang) return node_hang def _get_node_counter(self): with self._lock: - return Counter([node.status for node in self._nodes.values()]) + nodes = self._job_context.job_nodes_by_type(self._node_type) + return Counter([node.status for node in nodes.values()]) def update_critical_node(self, critical_node_restarts): """Update critical node by a dict. @@ -402,7 +420,8 @@ def update_critical_node(self, critical_node_restarts): and values are the relaunchable number of nodes """ logger.info("Update critical worker {}".format(critical_node_restarts)) - for id, node in self._nodes.items(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for id, node in nodes.items(): if id in critical_node_restarts: node.critical = True node.max_relaunch_count = critical_node_restarts[id] diff --git a/dlrover/python/master/node/worker.py b/dlrover/python/master/node/worker.py index 5cb4be2e7..fcc7ec5a2 100644 --- a/dlrover/python/master/node/worker.py +++ b/dlrover/python/master/node/worker.py @@ -41,7 +41,6 @@ class ChiefManager(TrainingNodeManager): def __init__( self, - chief_nodes: Dict[int, Node], job_resource: JobResource, max_relaunch_num, new_service_fn, @@ -49,8 +48,6 @@ def __init__( ): """ Args: - chief_nodes: A dictionary where the key is the index of - chief and the value is a Node instance. job_resource: the resource configuration of a job. max_relaunch_num: The maximum relaunch number of a chief. 
new_service_fn: A callable function to generate a server name of @@ -58,7 +55,7 @@ def __init__( new_node_name_fn: A callable function to generate a node name of chief. """ - super(ChiefManager, self).__init__(chief_nodes, new_node_name_fn) + super(ChiefManager, self).__init__(NodeType.CHIEF, new_node_name_fn) self._job_resource = job_resource self._max_relaunch_num = max_relaunch_num self._new_service_fn = new_service_fn @@ -66,7 +63,8 @@ def __init__( def is_chief_running(self): """The chief worker with id=0 is responsible to initialize variables in TensorFlow 1.x PS strategy""" - for node in self._nodes.values(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for node in nodes.values(): if node.status == NodeStatus.RUNNING: return True return False @@ -75,7 +73,6 @@ def is_chief_running(self): class EvaluatorManager(TrainingNodeManager): def __init__( self, - evaluator_nodes: Dict[int, Node], job_resource: JobResource, max_relaunch_num, new_service_fn, @@ -83,8 +80,6 @@ def __init__( ): """ Args: - evaluator_nodes: A dictionary where the key is the index of - evaluator and the value is a Node instance. job_resource: the resource configuration of a job. max_relaunch_num: The maximum relaunch number of an evaluator. new_service_fn: A callable function to generate a server name of @@ -93,7 +88,7 @@ def __init__( evaluator. """ super(EvaluatorManager, self).__init__( - evaluator_nodes, new_node_name_fn + NodeType.EVALUATOR, new_node_name_fn ) self._job_resource = job_resource self._max_relaunch_num = max_relaunch_num @@ -102,7 +97,8 @@ def __init__( def is_chief_running(self): """The chief worker with id=0 is responsible to initialize variables in TensorFlow 1.x PS strategy""" - for node in self._nodes.values(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for node in nodes.values(): if node.status == NodeStatus.RUNNING: return True return False @@ -111,7 +107,6 @@ def is_chief_running(self): class WorkerManager(TrainingNodeManager): def __init__( self, - worker_nodes: Dict[int, Node], job_resource: JobResource, max_relaunch_num, new_service_fn, @@ -119,8 +114,6 @@ def __init__( ): """ Args: - worker_nodes: A dictionary where the key is the index of worker - and the value is a Node instance. job_resource: the resource configuration of a job. max_relaunch_num: The maximum relaunch number of worker. new_service_fn: A callable function to generate a server name of @@ -128,7 +121,7 @@ def __init__( new_node_name_fn: A callable function to generate a node name of worker. 
""" - super(WorkerManager, self).__init__(worker_nodes, new_node_name_fn) + super(WorkerManager, self).__init__(NodeType.WORKER, new_node_name_fn) self._job_resource = job_resource self._max_relaunch_num = max_relaunch_num self._new_service_fn = new_service_fn @@ -147,7 +140,8 @@ def adjust_worker(self, worker_resource: NodeGroupResource): ) ) alive_workers = [] - for worker in self._nodes.values(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for worker in nodes.values(): if worker.status in ALIVE_STATUS: alive_workers.append(worker) alive_num = len(alive_workers) @@ -177,8 +171,8 @@ def _scale_up_workers(self, up_num): config_resource=copy.deepcopy(worker_resource), service_addr=service_addr, ) - self._nodes[worker_id] = new_node - logger.info("Create worker %s", self._nodes[worker_id]) + self._job_context.update_job_node(new_node) + logger.info("Create worker %s", new_node) plan.launch_nodes.append(new_node) return plan @@ -198,8 +192,9 @@ def _scale_down_workers(self, down_num, running_workers: List[Node]): def delete_exited_workers(self): """Delete failed, succeed, finished workers.""" plan = ScalePlan() + nodes = self._job_context.job_nodes_by_type(self._node_type) with self._lock: - for worker in self._nodes.values(): + for worker in nodes.values(): if ( worker.status in [ @@ -215,7 +210,8 @@ def delete_exited_workers(self): def delete_running_workers(self): plan = ScalePlan() - for worker in self._nodes.values(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for worker in nodes.values(): if not worker.critical and worker.status in [ NodeStatus.RUNNING, NodeStatus.PENDING, @@ -231,7 +227,11 @@ def delete_running_workers(self): return plan def remove_noncritical_worker(self, worker_id): - if self._nodes[worker_id].critical: + node = self._job_context.job_node(self._node_type, worker_id) + if node is None: + logger.error(f"not found node[{self._node_type}][{worker_id}]") + return + if node.critical: logger.info("Skip the critical worker %s", worker_id) else: return self.remove_node(worker_id) @@ -239,9 +239,10 @@ def remove_noncritical_worker(self, worker_id): def migrate_workers(self, workers: Dict[str, NodeResource]): """Migrate workers with the new resource""" plan = ScalePlan() + nodes = self._job_context.job_nodes_by_type(self._node_type) for name, resource in workers.items(): old_node_id = int(name.split("-")[-1]) - old_node = self._nodes[old_node_id] + old_node = nodes[old_node_id] if old_node.critical: continue old_node.migrated = True @@ -257,7 +258,7 @@ def migrate_workers(self, workers: Dict[str, NodeResource]): rank_index=task_id, name=self._new_node_name_fn(NodeType.WORKER, node_id), ) - self._nodes[node_id] = new_node + self._job_context.update_job_node(new_node) plan.launch_nodes.append(new_node) plan.remove_nodes.append(old_node) return plan @@ -268,17 +269,19 @@ def remove_not_joined_rdzv_workers(self, worker_ranks: List[int]): worker_ranks: The rank of worker which does not join rendezvous. 
""" plan = ScalePlan() - for node_id, node in self._nodes.items(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for node_id, node in nodes.items(): if node.rank_index in worker_ranks: p = self.remove_node(node.id) - self._nodes[node_id].relaunchable = False + nodes[node_id].relaunchable = False if p: plan.merge(p) return plan def has_exited_worker(self): """Check whether there is exited worker except evicted workers.""" - for worker in self._nodes.values(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for worker in nodes.values(): if ( worker.exit_reason == NodeExitReason.FATAL_ERROR or worker.status == NodeStatus.SUCCEEDED @@ -288,7 +291,8 @@ def has_exited_worker(self): def wait_worker_restart(self): """Check whether there are workers tha have remaining retries.""" - for worker in self._nodes.values(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for worker in nodes.values(): if ( worker.exit_reason == NodeExitReason.KILLED and worker.relaunch_count < worker.max_relaunch_count @@ -311,11 +315,15 @@ def verify_restarting_training(self, node_id): bool """ restart = False - worker = self._nodes[node_id] + worker = self._job_context.job_node(self._node_type, node_id) + if worker is None: + logger.error(f"not found worker-{node_id}") + return False if not worker.is_released: restart = worker.restart_training # Set False to avoid restart repeatedly. worker.restart_training = False + self._job_context.update_job_node(worker) return restart def is_training_hang_by_pending(self, total_node_num, job_type) -> bool: @@ -359,7 +367,7 @@ def is_training_hang_by_pending(self, total_node_num, job_type) -> bool: return False # collect pending and running nodes - cur_nodes = list(self._nodes.values()) + cur_nodes = list(self._job_context.workers.values()) pending_workers: List[Node] = [] running_workers: List[Node] = [] for node in cur_nodes: @@ -487,7 +495,9 @@ def is_training_hang_by_insufficient_worker(self) -> bool: logger.debug( f"Is training hang by insufficient worker with timeout: {timeout}." 
) - cur_nodes = list(self._nodes.values()) + + nodes = self._job_context.job_nodes_by_type(self._node_type) + cur_nodes = list(nodes.values()) # collect available nodes available_nodes: List[Node] = [] diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index 3b106a0f6..2e2be385a 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -140,6 +140,8 @@ def get(self, request, _): elif isinstance(req_message, grpc.ElasticRunConfigRequest): configs = self._job_manager.get_elastic_run_configs() message = grpc.ElasticRunConfig(configs=configs) + elif isinstance(req_message, grpc.HeartBeat): + message = self._report_heartbeat(node_type, node_id, req_message) if message: response.data = message.serialize() @@ -353,8 +355,6 @@ def report(self, request, _): success = self._kv_store_set(message) elif isinstance(message, grpc.ParallelConfig): success = self._report_paral_config(node_type, node_id, message) - elif isinstance(message, grpc.HeartBeat): - success = self._report_heartbeat(node_type, node_id, message) elif isinstance(message, grpc.NodeCheckpointState): success = self._sync_checkpoint(node_type, node_id, message) elif isinstance(message, grpc.DiagnosisReportData): @@ -606,14 +606,6 @@ def _report_paral_config( ) return True - def _report_heartbeat(self, node_type, node_id, message: grpc.HeartBeat): - self._job_manager.collect_node_heart_beat( - node_type, - node_id, - message.timestamp, - ) - return True - def _sync_checkpoint( self, node_type, node_id, message: grpc.NodeCheckpointState ): @@ -649,6 +641,19 @@ def _sync_training_ports( port=sync_ports.training_port, newport=sync_ports.next_check_port ) + def _report_heartbeat( + self, node_type, node_id, message: grpc.HeartBeat + ) -> grpc.HeartbeatResponse: + action = self._job_manager.collect_node_heart_beat( + node_type, node_id, message.timestamp + ) + grpc_action = grpc.DiagnosisAction( + action.__class__.__name__, + action.to_json(), + ) + + return grpc.HeartbeatResponse(action=grpc_action) + def create_master_service( port, diff --git a/dlrover/python/tests/test_common_util.py b/dlrover/python/tests/test_common_util.py index c39eeeeba..895eebbc9 100644 --- a/dlrover/python/tests/test_common_util.py +++ b/dlrover/python/tests/test_common_util.py @@ -19,4 +19,4 @@ class CommonUtilTest(unittest.TestCase): def test_get_dlrover_version(self): self.assertIsNotNone(cu.get_dlrover_version()) - self.assertNotEquals(cu.get_dlrover_version(), "Unknown") + self.assertNotEqual(cu.get_dlrover_version(), "Unknown") diff --git a/dlrover/python/tests/test_diagnosis.py b/dlrover/python/tests/test_diagnosis.py index 501dd1d08..876dc3070 100644 --- a/dlrover/python/tests/test_diagnosis.py +++ b/dlrover/python/tests/test_diagnosis.py @@ -14,7 +14,18 @@ import time import unittest -from dlrover.python.diagnosis.common.constants import DiagnosisDataType +from dlrover.python.common.constants import NodeStatus +from dlrover.python.diagnosis.common.constants import ( + DiagnosisActionType, + DiagnosisConstant, + DiagnosisDataType, +) +from dlrover.python.diagnosis.common.diagnosis_action import ( + DiagnosisAction, + DiagnosisActionQueue, + EventAction, + NodeAction, +) from dlrover.python.diagnosis.common.diagnosis_data import TrainingLog from dlrover.python.master.diagnosis.diagnosis import DiagnosisDataManager @@ -43,6 +54,68 @@ def test_data_manager(self): logs = mgr.get_data(DiagnosisDataType.TRAINING_LOG) self.assertEqual(len(logs), 1) + def test_action_basic(self): + basic_action = 
DiagnosisAction() + self.assertEqual(basic_action.action_type, DiagnosisActionType.NONE) + self.assertEqual(basic_action._instance, DiagnosisConstant.MASTER) + + event_action = EventAction( + "info", "job", "test", "test123", {"k1": "v1"} + ) + self.assertEqual(event_action.action_type, DiagnosisActionType.EVENT) + self.assertEqual(event_action._instance, DiagnosisConstant.MASTER) + self.assertEqual(event_action.event_type, "info") + self.assertEqual(event_action.event_instance, "job") + self.assertEqual(event_action.event_action, "test") + self.assertEqual(event_action.event_msg, "test123") + self.assertEqual(event_action.event_labels, {"k1": "v1"}) + + node_relaunch_action = NodeAction(1, NodeStatus.FAILED, "hang") + self.assertEqual( + node_relaunch_action.action_type, + DiagnosisActionType.MASTER_RELAUNCH_WORKER, + ) + self.assertEqual(node_relaunch_action._instance, 1) + self.assertEqual(node_relaunch_action.node_id, 1) + self.assertEqual(node_relaunch_action.node_status, NodeStatus.FAILED) + self.assertEqual(node_relaunch_action.reason, "hang") + + node_relaunch_action = NodeAction( + 1, NodeStatus.FAILED, "hang", DiagnosisActionType.RESTART_WORKER + ) + self.assertEqual( + node_relaunch_action.action_type, + DiagnosisActionType.RESTART_WORKER, + ) + + def test_action_queue(self): + action_queue = DiagnosisActionQueue() + action0 = EventAction("test0", expired_time_period=100000) + action1 = EventAction("test1", expired_time_period=1) + action2 = EventAction("test2", expired_time_period=100000) + + action_queue.add_action(action0) + action_queue.add_action(action1) + action_queue.add_action(action2) + + time.sleep(1) + self.assertEqual( + action_queue.next_action(instance=1).action_type, + DiagnosisActionType.NONE, + ) + self.assertEqual( + action_queue.next_action(instance=-1).action_type, + DiagnosisActionType.EVENT, + ) + self.assertEqual( + action_queue.next_action(instance=-1).action_type, + DiagnosisActionType.EVENT, + ) + self.assertEqual( + action_queue.next_action(instance=1).action_type, + DiagnosisActionType.NONE, + ) + if __name__ == "__main__": unittest.main() diff --git a/dlrover/python/tests/test_diagnosis_agent.py b/dlrover/python/tests/test_diagnosis_agent.py index c6770f677..9d624c0ca 100644 --- a/dlrover/python/tests/test_diagnosis_agent.py +++ b/dlrover/python/tests/test_diagnosis_agent.py @@ -19,9 +19,9 @@ from dlrover.python.common import env_utils from dlrover.python.common.constants import RendezvousName -from dlrover.python.common.worker import WorkerContext -from dlrover.python.diagnosis.common.constants import DiagnosisAction +from dlrover.python.diagnosis.common.constants import DiagnosisActionType from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric +from dlrover.python.elastic_agent.context import AgentContext from dlrover.python.elastic_agent.diagnosis.diagnosis_agent import ( DiagnosisAgent, ) @@ -58,6 +58,7 @@ def test_diagnose_training(self): file_path = os.path.join(path, file) errors = "error code is 11111" + agent = DiagnosisAgent.singleton_instance(file_path, errors) spec = _create_worker_spec( @@ -74,29 +75,30 @@ def test_diagnose_training(self): ), failures={}, ) - wc = WorkerContext( + + context = AgentContext( worker_spec=spec, remaining_failovers=2, restart_count=3, run_result=run_result, ) - action = agent.diagnose_training_failure(wc) - self.assertEqual(action, DiagnosisAction.RESTART_WORKER) + action = agent.diagnose_training_failure(context) + self.assertEqual(action, DiagnosisActionType.RESTART_WORKER) 
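Editor's note: the test above pins down the queue semantics: every action carries a target instance and an expiry, and next_action() hands back a no-op (DiagnosisActionType.NONE) when nothing relevant is pending. The snippet below is a minimal usage sketch based only on the behaviour asserted in the test; the surrounding scenario is hypothetical.

# Minimal sketch of the DiagnosisActionQueue behaviour asserted above.
from dlrover.python.diagnosis.common.constants import (
    DiagnosisActionType,
    DiagnosisConstant,
)
from dlrover.python.diagnosis.common.diagnosis_action import (
    DiagnosisActionQueue,
    EventAction,
)

queue = DiagnosisActionQueue()
# EventAction targets the master instance by default; a long expiry is used
# here, mirroring the value used in the test above.
queue.add_action(EventAction("warning", expired_time_period=100000))

action = queue.next_action(instance=DiagnosisConstant.MASTER)
if action.action_type != DiagnosisActionType.NONE:
    print("pending diagnosis action:", action.action_type)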
agent._errors = "error code is 507035" - action = agent.diagnose_training_failure(wc) - self.assertEqual(action, DiagnosisAction.RELAUNCH_WORKER) + action = agent.diagnose_training_failure(context) + self.assertEqual(action, DiagnosisActionType.RELAUNCH_WORKER) agent._errors = "error code is 11111" - wc.remaining_failovers = 0 - action = agent.diagnose_training_failure(wc) - self.assertEqual(action, DiagnosisAction.RELAUNCH_WORKER) + context.remaining_failovers = 0 + action = agent.diagnose_training_failure(context) + self.assertEqual(action, DiagnosisActionType.RELAUNCH_WORKER) agent._errors = " #" - wc.remaining_failovers = 2 - action = agent.diagnose_training_failure(wc) - self.assertEqual(action, DiagnosisAction.RESTART_WORKER) + context.remaining_failovers = 2 + action = agent.diagnose_training_failure(context) + self.assertEqual(action, DiagnosisActionType.RESTART_WORKER) def test_worker_training_metric(self): test = WorkerTrainingMetric( diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index 00404c55d..2a25696f5 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -111,6 +111,7 @@ def setUp(self) -> None: def tearDown(self): self._master.stop() + os.environ.clear() def test_node_unit(self): node_unit = int(self.rdzv_handler._rdzv_params.get("node_unit", "1")) diff --git a/dlrover/python/tests/test_job_auto_scaler.py b/dlrover/python/tests/test_job_auto_scaler.py index e8da71378..b9a974392 100644 --- a/dlrover/python/tests/test_job_auto_scaler.py +++ b/dlrover/python/tests/test_job_auto_scaler.py @@ -26,6 +26,7 @@ AllreduceTrainingAutoScaler, PSTrainingAutoScaler, ) +from dlrover.python.master.node.job_context import get_job_context from dlrover.python.master.resource.optimizer import ResourcePlan from dlrover.python.tests.test_utils import ( MockK8sAllreduceJobArgs, @@ -39,6 +40,10 @@ class JobAutoScalerTest(unittest.TestCase): def setUp(self) -> None: mock_k8s_client() + self.job_context = get_job_context() + + def tearDown(self) -> None: + self.job_context.clear_job_nodes() def test_execute_job_optimization_plan(self): params = MockK8sPSJobArgs() @@ -50,7 +55,6 @@ def test_execute_job_optimization_plan(self): auto_scaler = PSTrainingAutoScaler( manager._job_resource, - manager._job_nodes, manager._job_optimizer, manager._speed_monitor, manager._ps_manager, @@ -64,11 +68,23 @@ def test_execute_job_optimization_plan(self): plan.node_resources["test-edljob-worker-0"] = NodeResource(8, 8192) plan.node_resources["test-edljob-worker-1"] = NodeResource(8, 8192) plan.node_resources["test-edljob-ps-1"] = NodeResource(8, 8192) - auto_scaler._ps_manager._nodes[1].status = NodeStatus.RUNNING - auto_scaler._worker_manager._nodes[0].critical = True + + ps_nodes = self.job_context.job_nodes_by_type(NodeType.PS) + ps_node = ps_nodes[1] + ps_node.type = NodeType.PS + ps_node.status = NodeStatus.RUNNING + self.job_context.update_job_node(ps_node) + worker_nodes = self.job_context.job_nodes_by_type(NodeType.WORKER) + worker_node = worker_nodes[0] + worker_node.type = NodeType.WORKER + worker_node.critical = True + self.job_context.update_job_node(worker_node) scale_plan = auto_scaler.execute_job_optimization_plan(plan) - self.assertEqual(len(manager._ps_manager._nodes), 4) - self.assertEqual(len(manager._worker_manager._nodes), 7) + + ps_nodes = self.job_context.job_nodes_by_type(NodeType.PS) + self.assertEqual(len(ps_nodes), 4) + worker_nodes = 
self.job_context.job_nodes_by_type(NodeType.WORKER) + self.assertEqual(len(worker_nodes), 7) self.assertEqual(len(scale_plan.remove_nodes), 1) self.assertEqual(len(scale_plan.launch_nodes), 5) remove_node = scale_plan.remove_nodes[0] @@ -113,7 +129,6 @@ def test_reduce_timeout_pending_node_resource(self): auto_scaler = PSTrainingAutoScaler( manager._job_resource, - manager._job_nodes, manager._job_optimizer, manager._speed_monitor, manager._ps_manager, @@ -121,10 +136,14 @@ def test_reduce_timeout_pending_node_resource(self): manager._scaler, ) auto_scaler._autoscaling_started = True - ps0 = manager._ps_manager._nodes[0] + + ps_nodes = self.job_context.job_nodes_by_type(NodeType.PS) + ps0 = ps_nodes[0] + ps0.type = NodeType.PS ps0.config_resource.cpu = 16 ps0.status = NodeStatus.PENDING ps0.create_time = datetime.now() + timedelta(days=-1) + self.job_context.update_job_node(ps0) plan = auto_scaler._reduce_timeout_pending_node_resource() self.assertEqual( plan.ps_addrs, @@ -139,6 +158,10 @@ def test_reduce_timeout_pending_node_resource(self): class AllreduceAutoScalerTest(unittest.TestCase): def setUp(self) -> None: mock_k8s_client() + self.job_context = get_job_context() + + def tearDown(self) -> None: + self.job_context.clear_job_nodes() def test_execute_job_optimization_plan(self): params = MockK8sAllreduceJobArgs() @@ -146,14 +169,16 @@ def test_execute_job_optimization_plan(self): manager = create_job_manager(params, SpeedMonitor()) manager._init_nodes() - for worker in manager._job_nodes[NodeType.WORKER].values(): + worker_nodes = self.job_context.job_nodes_by_type(NodeType.WORKER) + + for worker in worker_nodes.values(): worker.status = NodeStatus.RUNNING + self.job_context.update_job_node(worker) manager._scaler.scale = mock.MagicMock(return_value=True) auto_scaler = AllreduceTrainingAutoScaler( manager._job_resource, - manager._job_nodes, manager._job_optimizer, manager._speed_monitor, manager._worker_manager, diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py index 50fed89c7..1ba4ecc85 100644 --- a/dlrover/python/tests/test_job_manager.py +++ b/dlrover/python/tests/test_job_manager.py @@ -51,6 +51,7 @@ TaskRescheduleCallback, TFPSNodeHandlingCallback, ) +from dlrover.python.master.node.job_context import get_job_context from dlrover.python.master.node.local_job_manager import LocalJobManager from dlrover.python.master.node.status_flow import ( NODE_STATE_FLOWS, @@ -121,6 +122,10 @@ def test_get_node_state_flow(self): class DistributedJobManagerTest(unittest.TestCase): def setUp(self) -> None: mock_k8s_client() + self.job_context = get_job_context() + + def tearDown(self): + self.job_context.clear_job_nodes() def test_job_resource(self): job = JobResource() @@ -222,8 +227,10 @@ def test_relaunch_node(self): self.assertEqual(manager._ps_relaunch_max_num, 1) manager.start() self.assertEqual(manager._job_args.job_uuid, _MOCK_JOB_UUID) - self.assertEqual(len(manager._job_nodes), 4) - self.assertTrue(manager._job_nodes[NodeType.PS][0].critical) + + job_nodes = self.job_context.job_nodes() + self.assertEqual(len(job_nodes), 4) + self.assertTrue(job_nodes[NodeType.PS][0].critical) node = Node( node_type=NodeType.WORKER, @@ -245,21 +252,21 @@ def test_relaunch_node(self): manager.update_node_resource_usage( NodeType.WORKER, 0, 0.7, 2048, gpu_stats ) # noqa + job_nodes = self.job_context.job_nodes() + self.assertEqual(job_nodes[NodeType.WORKER][0].used_resource.cpu, 0.7) self.assertEqual( - 
manager._job_nodes[NodeType.WORKER][0].used_resource.cpu, 0.7 - ) - self.assertEqual( - manager._job_nodes[NodeType.WORKER][0].used_resource.memory, 2048 + job_nodes[NodeType.WORKER][0].used_resource.memory, 2048 ) self.assertEqual( - manager._job_nodes[NodeType.WORKER][0].used_resource.gpu_stats, + job_nodes[NodeType.WORKER][0].used_resource.gpu_stats, gpu_stats, # noqa ) node_event: NodeEvent = NodeEvent(NodeEventType.MODIFIED, node) manager._process_event(node_event) + job_nodes = self.job_context.job_nodes() self.assertEqual( - manager._job_nodes[NodeType.WORKER][1].status, NodeStatus.RUNNING + job_nodes[NodeType.WORKER][1].status, NodeStatus.RUNNING ) should_relaunch = manager._should_relaunch(node, NODE_STATE_FLOWS[5]) self.assertFalse(should_relaunch) @@ -352,16 +359,18 @@ def test_get_dead_node_event(self): manager.start() ts = int(time.time()) manager.collect_node_heart_beat(NodeType.WORKER, 0, ts) - worker0 = manager._job_nodes[NodeType.WORKER][0] + + job_nodes = self.job_context.job_nodes() + worker0 = job_nodes[NodeType.WORKER][0] self.assertEqual(worker0.heartbeat_time, ts) - for node in manager._job_nodes[NodeType.WORKER].values(): + for node in job_nodes[NodeType.WORKER].values(): node.status = NodeStatus.RUNNING + self.job_context.update_job_node(node) events = manager._get_dead_node_event() self.assertEqual(len(events), 0) - for index, node in enumerate( - manager._job_nodes[NodeType.WORKER].values() - ): + job_nodes = self.job_context.job_nodes() + for index, node in enumerate(job_nodes[NodeType.WORKER].values()): node.status = NodeStatus.RUNNING now = datetime.now() node.heartbeat_time = (now - timedelta(seconds=1000)).timestamp() @@ -371,6 +380,7 @@ def test_get_dead_node_event(self): else: node.create_time = now - timedelta(seconds=1400) node.start_time = now - timedelta(seconds=1200) + self.job_context.update_job_node(node) events = manager._get_dead_node_event() self.assertEqual(len(events), 2) @@ -378,9 +388,8 @@ def test_get_dead_node_event(self): self.assertIsNotNone(nodes_time_info) self.assertEqual(len(nodes_time_info), 3) - for index, node in enumerate( - manager._job_nodes[NodeType.WORKER].values() - ): + job_nodes = self.job_context.job_nodes() + for index, node in enumerate(job_nodes[NodeType.WORKER].values()): node.status = NodeStatus.RUNNING now = datetime.now() node.heartbeat_time = (now - timedelta(seconds=1000)).timestamp() @@ -392,7 +401,7 @@ def test_get_dead_node_event(self): node.reported_status = 0 node.create_time = now - timedelta(seconds=1400) node.start_time = now - timedelta(seconds=1200) - + self.job_context.update_job_node(node) events = manager._get_dead_node_event() self.assertEqual(len(events), 1) @@ -407,7 +416,7 @@ def test_relaunch_training_master(self): manager._init_nodes() master = Node(NodeType.MASTER, 0, NodeResource(1, 256)) - manager._job_nodes[NodeType.MASTER][0] = master + self.job_context.update_job_node(master) plan = manager._chief_manager.relaunch_node(master) self.assertEqual(plan.launch_nodes[0].id, 1) @@ -416,9 +425,11 @@ def test_process_list_nodes(self): params.initilize() manager = create_job_manager(params, SpeedMonitor()) manager._init_nodes() - self.assertFalse(4 in manager._job_nodes[NodeType.WORKER]) - for node in manager._job_nodes[NodeType.PS].values(): + job_nodes = self.job_context.job_nodes() + self.assertFalse(4 in job_nodes[NodeType.WORKER]) + for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.PENDING + self.job_context.update_job_node(node) nodes = [] for i in range(2): node = Node( @@ 
-439,16 +450,18 @@ def test_process_list_nodes(self): ) ) manager._process_list_nodes(nodes) - ps_ids = list(manager._job_nodes[NodeType.PS].keys()) + + job_nodes = self.job_context.job_nodes() + ps_ids = list(job_nodes[NodeType.PS].keys()) self.assertListEqual(ps_ids, [0, 1, 2]) - self.assertTrue(4 in manager._job_nodes[NodeType.WORKER]) + self.assertTrue(4 in self.job_context.job_nodes()[NodeType.WORKER]) @patch.object(DistributedJobManager, "_process_event") def test_process_list_nodes_for_empty_case(self, mock_method): params = MockK8sPSJobArgs() params.initilize() manager = create_job_manager(params, SpeedMonitor()) - manager._job_nodes = { + job_nodes = { NodeType.PS: { 0: Node( node_type=NodeType.PS, @@ -468,6 +481,7 @@ def test_process_list_nodes_for_empty_case(self, mock_method): ) }, } + self.job_context.update_job_nodes(job_nodes) manager._process_list_nodes([]) self.assertEqual(mock_method.call_count, 2) @@ -483,9 +497,12 @@ def test_create_allreduce_job_manager(self): manager._adjust_worker_for_estimator() manager._init_nodes() manager._init_job_auto_scaler() - self.assertEqual(len(manager._job_nodes[NodeType.WORKER]), 3) + + job_nodes = self.job_context.job_nodes() + self.assertEqual(len(job_nodes[NodeType.WORKER]), 3) manager.start_auto_scaling() - self.assertEqual(len(manager._job_nodes[NodeType.WORKER]), 3) + job_nodes = self.job_context.job_nodes() + self.assertEqual(len(job_nodes[NodeType.WORKER]), 3) def test_recover_tasks_for_failed_workers(self): ds_name_0 = "test-0" @@ -552,27 +569,38 @@ def test_check_worker_status(self): manager._init_nodes() self.assertFalse(manager.all_workers_exited()) - for worker in manager._job_nodes[NodeType.WORKER].values(): + job_nodes = self.job_context.job_nodes() + + for worker in job_nodes[NodeType.WORKER].values(): worker.status = NodeStatus.FINISHED - for worker in manager._job_nodes[NodeType.CHIEF].values(): + for worker in job_nodes[NodeType.CHIEF].values(): worker.status = NodeStatus.FINISHED - for worker in manager._job_nodes[NodeType.EVALUATOR].values(): + for worker in job_nodes[NodeType.EVALUATOR].values(): worker.status = NodeStatus.FINISHED + self.job_context.update_job_nodes(job_nodes) self.assertTrue(manager.all_workers_exited()) - for worker in manager._job_nodes[NodeType.WORKER].values(): + job_nodes = self.job_context.job_nodes() + for worker in job_nodes[NodeType.WORKER].values(): worker.status = NodeStatus.FAILED - for worker in manager._job_nodes[NodeType.CHIEF].values(): + self.job_context.update_job_node(worker) + for worker in job_nodes[NodeType.CHIEF].values(): worker.status = NodeStatus.FAILED - for worker in manager._job_nodes[NodeType.EVALUATOR].values(): + self.job_context.update_job_node(worker) + for worker in job_nodes[NodeType.EVALUATOR].values(): worker.status = NodeStatus.FAILED + self.job_context.update_job_node(worker) self.assertTrue(manager.all_workers_failed()) - for worker in manager._job_nodes[NodeType.PS].values(): + job_nodes = self.job_context.job_nodes() + for worker in job_nodes[NodeType.PS].values(): worker.status = NodeStatus.FINISHED - manager._job_nodes[NodeType.WORKER][0].status = NodeStatus.RUNNING + self.job_context.update_job_node(worker) + job_nodes[NodeType.WORKER][0].status = NodeStatus.RUNNING + self.job_context.update_job_node(job_nodes[NodeType.WORKER][0]) self.assertFalse(manager.all_critical_node_completed()) - manager._job_nodes[NodeType.WORKER][0].status = NodeStatus.FINISHED + job_nodes[NodeType.WORKER][0].status = NodeStatus.FINISHED + 
self.job_context.update_job_node(job_nodes[NodeType.WORKER][0]) self.assertTrue(manager.all_critical_node_completed()) def test_tf_ps_node_handling(self): @@ -615,10 +643,12 @@ def test_all_running_node_hang(self): hang = manager.all_running_node_hanged() self.assertFalse(hang) - for _, nodes in manager._job_nodes.items(): + job_nodes = self.job_context.job_nodes() + for _, nodes in job_nodes.items(): for _, node in nodes.items(): node.start_hang_time = time.time() - 3600 * 4 node.status = NodeStatus.RUNNING + self.job_context.update_job_node(node) manager.update_node_resource_usage(NodeType.WORKER, 0, 0.01, 256) hang = manager.all_running_node_hanged() self.assertTrue(hang) @@ -631,33 +661,42 @@ def test_early_stop_part1(self): params.initilize() manager = create_job_manager(params, SpeedMonitor()) manager._init_nodes() - for node in manager._job_nodes[NodeType.PS].values(): + + job_nodes = self.job_context.job_nodes() + for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.PENDING node.is_recovered_oom = True node.create_time = datetime.now() + self.job_context.update_job_node(node) result, reason, msg = manager.should_early_stop() self.assertFalse(result) self.assertFalse(reason) self.assertFalse(msg) manager._remove_exited_node = True - manager._job_nodes[NodeType.WORKER][0].status = NodeStatus.FAILED + job_nodes = self.job_context.job_nodes() + job_nodes[NodeType.WORKER][0].status = NodeStatus.FAILED + self.job_context.update_job_node(job_nodes[NodeType.WORKER][0]) manager.clear_exited_nodes() - self.assertTrue(manager._job_nodes[NodeType.WORKER][0].is_released) + job_nodes = self.job_context.job_nodes() + self.assertTrue(job_nodes[NodeType.WORKER][0].is_released) - for node in manager._job_nodes[NodeType.PS].values(): + for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.PENDING node.create_time = datetime.now() + timedelta(days=-1) node.is_recovered_oom = True + self.job_context.update_job_node(node) result, reason, msg = manager.should_early_stop() self.assertTrue(result) self.assertTrue(reason) self.assertTrue(msg) - for node in manager._job_nodes[NodeType.PS].values(): + job_nodes = self.job_context.job_nodes() + for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.RUNNING node.create_time = datetime.now() + timedelta(days=-1) node.is_recovered_oom = True + self.job_context.update_job_node(node) result, reason, msg = manager.should_early_stop() self.assertFalse(result) self.assertFalse(reason) @@ -734,7 +773,9 @@ def test_when_node_not_init(self): params = MockK8sPSJobArgs() params.initilize() manager = create_job_manager(params, SpeedMonitor()) - self.assertTrue(not manager._job_nodes) + job_context = get_job_context() + job_nodes = job_context.job_nodes() + self.assertTrue(len(job_nodes) == 0) manager.update_node_resource_usage(NodeType.WORKER, 0, 10, 10240, None) @@ -746,7 +787,7 @@ def test_start_and_stop(self): manager.start() active_threads_name = [t.name for t in threading.enumerate()] self.assertIn("node_monitor", active_threads_name) - self.assertIn("node_heart_beat_monitor", active_threads_name) + self.assertIn("diagnose_job", active_threads_name) manager.stop() def test_concurrency_heart_beat_collecting(self): @@ -756,8 +797,9 @@ def test_concurrency_heart_beat_collecting(self): manager = create_job_manager(params, SpeedMonitor()) manager.start() - self.assertEqual(len(manager._job_nodes[NodeType.WORKER]), worker_size) - for i, node in manager._job_nodes[NodeType.WORKER].items(): + job_nodes = 
self.job_context.job_nodes() + self.assertEqual(len(job_nodes[NodeType.WORKER]), worker_size) + for i, node in job_nodes[NodeType.WORKER].items(): self.assertEqual(node.id, i) self.assertEqual(node.heartbeat_time, 0) futures = [] @@ -773,14 +815,15 @@ def test_concurrency_heart_beat_collecting(self): future.result() self.assertEqual(len(futures), worker_size) - for i, node in manager._job_nodes[NodeType.WORKER].items(): + job_nodes = self.job_context.job_nodes() + for i, node in job_nodes[NodeType.WORKER].items(): self.assertEqual(node.id, i) self.assertEqual(node.heartbeat_time, i) manager.stop() + self.job_context.clear_job_nodes() # test when job manager not init - manager._job_nodes = {} try: manager.collect_node_heart_beat("worker", 1, 111) except Exception: @@ -802,7 +845,10 @@ def test_local_job_manager(self): args.node_args[NodeType.WORKER].group_resource.count = 4 job_manager = LocalJobManager(args, error_monitor=SimpleErrorMonitor()) job_manager.start() - self.assertEqual(len(job_manager._job_nodes[NodeType.WORKER]), 4) + + job_context = get_job_context() + job_nodes = job_context.job_nodes() + self.assertEqual(len(job_nodes[NodeType.WORKER]), 4) gpu_stats: list[GPUStats] = [ GPUStats( index=0, @@ -815,7 +861,8 @@ def test_local_job_manager(self): NodeType.WORKER, 0, 10, 10240, gpu_stats ) - worker = job_manager._job_nodes[NodeType.WORKER][0] + job_nodes = job_context.job_nodes() + worker = job_nodes[NodeType.WORKER][0] self.assertEqual(worker.used_resource.cpu, 10) self.assertEqual(worker.used_resource.memory, 10240) self.assertEqual(worker.used_resource.gpu_stats, gpu_stats) @@ -824,6 +871,8 @@ def test_local_job_manager(self): optimizer_config = OptimizerConfig(1, "test_optimizer", 2) paral_config = ParallelConfig(dataloader_config, optimizer_config) job_manager.update_node_paral_config(NodeType.WORKER, 0, paral_config) - worker = job_manager._job_nodes[NodeType.WORKER][0] + + job_nodes = job_context.job_nodes() + worker = job_nodes[NodeType.WORKER][0] self.assertEqual(worker.paral_config, paral_config) job_manager.handle_training_failure(NodeType.WORKER, 3) diff --git a/dlrover/python/tests/test_log.py b/dlrover/python/tests/test_log.py index 3d90f69e9..16a5c5666 100644 --- a/dlrover/python/tests/test_log.py +++ b/dlrover/python/tests/test_log.py @@ -25,6 +25,7 @@ def test_default_log_level(self): from dlrover.python.common.log import default_logger as logger from dlrover.python.common.log import get_log_level + os.environ[BasicClass.LOG_LEVEL_ENV] = "INFO" self.assertEqual(get_log_level(), "INFO") logger.info("test123") diff --git a/dlrover/python/tests/test_master.py b/dlrover/python/tests/test_master.py index eef86ef4a..e9e72adb7 100644 --- a/dlrover/python/tests/test_master.py +++ b/dlrover/python/tests/test_master.py @@ -28,6 +28,7 @@ _create_master_service_on_k8s, ) from dlrover.python.master.main import update_context +from dlrover.python.master.node.job_context import get_job_context from dlrover.python.master.shard.dataset_splitter import new_dataset_splitter from dlrover.python.tests.test_utils import ( MockK8sPSJobArgs, @@ -44,23 +45,28 @@ def setUp(self) -> None: params = MockK8sPSJobArgs() params.initilize() self.master = DistributedJobMaster(2222, params) + self.job_context = get_job_context() + + def tearDown(self): + self.job_context.clear_job_nodes() def test_exit_by_workers(self): self.master.job_manager._init_nodes() - job_nodes = self.master.job_manager._job_nodes + job_nodes = self.job_context.job_nodes() for node in job_nodes[NodeType.WORKER].values(): 
node.status = NodeStatus.FINISHED for node in job_nodes[NodeType.EVALUATOR].values(): node.status = NodeStatus.FINISHED for node in job_nodes[NodeType.CHIEF].values(): node.status = NodeStatus.FINISHED + self.job_context.update_job_nodes(job_nodes) self.master.run() self.assertEqual(self.master._exit_code, 0) self.assertEqual(self.master._exit_reason, JobExitReason.SUCCEEDED) def test_exit_by_tasks(self): self.master.job_manager._init_nodes() - job_nodes = self.master.job_manager._job_nodes + job_nodes = self.job_context.job_nodes() for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.FINISHED for node in job_nodes[NodeType.EVALUATOR].values(): @@ -69,6 +75,7 @@ def test_exit_by_tasks(self): node.status = NodeStatus.FINISHED job_nodes[NodeType.WORKER][0].status = NodeStatus.FINISHED + self.job_context.update_job_nodes(job_nodes) splitter = new_dataset_splitter( False, @@ -91,11 +98,14 @@ def test_exit_by_tasks(self): def test_early_stop(self): self.master.job_manager._init_nodes() - job_nodes = self.master.job_manager._job_nodes + job_nodes = self.job_context.job_nodes() + for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.PENDING node.is_recovered_oom = True node.create_time = datetime.now() + timedelta(days=-1) + self.job_context.update_job_nodes(job_nodes) + exit_code = self.master.run() self.master.job_manager.clear_all_nodes() self.assertEqual(exit_code, 1) diff --git a/dlrover/python/tests/test_node.py b/dlrover/python/tests/test_node.py index bf320df95..95c855d6a 100644 --- a/dlrover/python/tests/test_node.py +++ b/dlrover/python/tests/test_node.py @@ -64,3 +64,7 @@ def test_is_unrecoverable_failure(self): node.update_node_check_result(NodeEventType.NODE_CHECK_FAILED) self.assertTrue(node.is_succeeded()) + + node.update_from_node(node) + node.id = 100 + node.update_from_node(node) diff --git a/dlrover/python/tests/test_pod_scaler.py b/dlrover/python/tests/test_pod_scaler.py index efec29bdf..431be9e11 100644 --- a/dlrover/python/tests/test_pod_scaler.py +++ b/dlrover/python/tests/test_pod_scaler.py @@ -40,6 +40,9 @@ def setUp(self) -> None: os.environ["POD_IP"] = "127.0.0.1" mock_k8s_client() + def tearDown(self) -> None: + os.environ.clear() + def test_init_pod_template(self): error_monitor = SimpleErrorMonitor() scaler = PodScaler("elasticjob-sample", "default", error_monitor) diff --git a/dlrover/python/tests/test_ps_manager.py b/dlrover/python/tests/test_ps_manager.py index d213510af..ed278029e 100644 --- a/dlrover/python/tests/test_ps_manager.py +++ b/dlrover/python/tests/test_ps_manager.py @@ -22,6 +22,7 @@ ) from dlrover.python.common.global_context import Context from dlrover.python.common.node import Node, NodeGroupResource, NodeResource +from dlrover.python.master.node.job_context import get_job_context from dlrover.python.master.node.ps import ParameterServerManager from dlrover.python.master.resource.job import JobResource from dlrover.python.scheduler.factory import new_elastic_job @@ -40,20 +41,24 @@ def setUp(self) -> None: self._elastic_job = new_elastic_job( PlatformType.KUBERNETES, "test", "default" ) - self._job_nodes = self._job_resource.init_job_node_meta( + self._job_context = get_job_context() + job_nodes = self._job_resource.init_job_node_meta( 1, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + self._job_context.update_job_nodes(job_nodes) self._ps_manager = ParameterServerManager( - self._job_nodes[NodeType.PS], self._job_resource, 3, self._elastic_job.get_node_service_addr, 
self._elastic_job.get_node_name, ) + def tearDown(self) -> None: + self._job_context.clear_job_nodes() + def test_get_training_ps_cluster(self): ps_nodes = self._ps_manager.get_training_ps_cluster() self.assertEqual(len(ps_nodes), 2) @@ -67,9 +72,11 @@ def test_get_training_ps_cluster(self): ) def test_cut_pending_ps_cpu(self): - for _, node in self._ps_manager._nodes.items(): + nodes = self._job_context.job_nodes_by_type(NodeType.PS) + for _, node in nodes.items(): node.status = NodeStatus.PENDING node.create_time = datetime.now() + timedelta(days=-1) + self._job_context.update_job_node(node) plan = self._ps_manager.reduce_pending_node_resource() self.assertEqual(len(plan.launch_nodes), 2) @@ -80,8 +87,11 @@ def test_scale_up_ps(self): self._ps_manager._scale_up_ps(2) training_ps = self._ps_manager.get_next_training_ps_cluster() self.assertEqual(len(training_ps), 2) - for node in self._ps_manager._nodes.values(): + + nodes = self._job_context.job_nodes_by_type(NodeType.PS) + for node in nodes.values(): node.status = NodeStatus.RUNNING + self._job_context.update_job_node(node) training_ps = self._ps_manager.get_next_training_ps_cluster() self.assertEqual(len(training_ps), 4) @@ -91,15 +101,17 @@ def test_scale_down_ps(self): self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + self._job_context.update_job_nodes(job_nodes) ps_manager = ParameterServerManager( - job_nodes[NodeType.PS], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - for node in ps_manager._nodes.values(): + nodes = self._job_context.job_nodes_by_type(NodeType.PS) + for node in nodes.values(): node.status = NodeStatus.RUNNING + self._job_context.update_job_node(node) ps_manager._scale_down_ps(1) self.assertEqual(len(ps_manager._pre_dropped_ps), 1) self.assertEqual(ps_manager._pre_dropped_ps[0].id, 1) @@ -113,17 +125,20 @@ def test_delete_running_ps(self): self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + self._job_context.update_job_nodes(job_nodes) ps_manager = ParameterServerManager( - job_nodes[NodeType.PS], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - for node in ps_manager._nodes.values(): + nodes = self._job_context.job_nodes_by_type(NodeType.PS) + for node in nodes.values(): node.status = NodeStatus.RUNNING + self._job_context.update_job_node(node) plan = ps_manager.delete_running_ps() + job_nodes = self._job_context.job_nodes() self.assertEqual(len(plan.remove_nodes), 2) self.assertTrue(job_nodes[NodeType.PS][0].is_released) self.assertTrue(job_nodes[NodeType.PS][1].is_released) @@ -134,15 +149,20 @@ def test_migrate_parameter_servers(self): self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + self._job_context.update_job_nodes(job_nodes) ps_manager = ParameterServerManager( - job_nodes[NodeType.PS], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - for node in ps_manager._nodes.values(): + + nodes = self._job_context.ps_nodes + for node in nodes.values(): node.status = NodeStatus.RUNNING + self._job_context.update_job_node(node) + + job_nodes = self._job_context.job_nodes() node_name = job_nodes[NodeType.PS][0].name nodes = {node_name: NodeResource(20, 2048)} plan = ps_manager.migrate_parameter_servers(nodes) @@ -150,11 +170,14 @@ def test_migrate_parameter_servers(self): self.assertEqual(ps_manager._migrated_ps_nodes[0].id, 2) self.assertTrue(ps_manager.exist_migrated_ps_nodes()) - 
ps_manager._pre_drop_migrated_ps(list(ps_manager._nodes.values())) + nodes = self._job_context.ps_nodes + ps_manager._pre_drop_migrated_ps(list(nodes.values())) self.assertEqual(len(ps_manager._pre_dropped_ps), 0) - for node in ps_manager._nodes.values(): + for node in nodes.values(): node.status = NodeStatus.RUNNING - ps_manager._pre_drop_migrated_ps(list(ps_manager._nodes.values())) + self._job_context.update_job_node(node) + nodes = self._job_context.ps_nodes + ps_manager._pre_drop_migrated_ps(list(nodes.values())) self.assertEqual(len(ps_manager._pre_dropped_ps), 1) training_ps = ps_manager.get_next_training_ps_cluster() @@ -169,22 +192,25 @@ def test_parameter_server_failure(self): self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + self._job_context.update_job_nodes(job_nodes) ps_manager = ParameterServerManager( - job_nodes[NodeType.PS], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - for node in ps_manager._nodes.values(): + nodes = self._job_context.job_nodes_by_type(NodeType.PS) + for node in nodes.values(): node.status = NodeStatus.RUNNING + self._job_context.update_job_node(node) ps_failure = ps_manager.has_ps_failure() self.assertFalse(ps_failure) - latest_ps_index = len(ps_manager._nodes) - 1 - ps = ps_manager._nodes[latest_ps_index] + latest_ps_index = len(nodes) - 1 + ps = nodes[latest_ps_index] ps_manager._ps_cluster_changed = True ps.status = NodeStatus.INITIAL ps.init_time -= 600 + self._job_context.update_job_node(ps) ps_failure = ps_manager.has_ps_failure() self.assertTrue(ps_failure) cluster = ps_manager.get_next_training_ps_cluster() @@ -197,7 +223,6 @@ def test_parameter_server_failure(self): def test_is_training_hang_by_pending_ps(self): _dlrover_ctx.pending_fail_strategy = 1 ps_manager = ParameterServerManager( - self._job_nodes[NodeType.PS], self._job_resource, 3, self._elastic_job.get_node_service_addr, @@ -234,7 +259,7 @@ def test_is_training_hang_by_pending_ps(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - ps_manager._nodes = mock_nodes + self._job_context.update_job_node(mock_node) self.assertFalse( ps_manager.is_training_hang_by_pending( ps_num, DistributionStrategy.ALLREDUCE @@ -246,6 +271,7 @@ def test_is_training_hang_by_pending_ps(self): ) ) mock_nodes.clear() + self._job_context.clear_job_nodes() # mock with 3 running + 1 pending long time for index in range(4): @@ -262,7 +288,7 @@ def test_is_training_hang_by_pending_ps(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - ps_manager._nodes = mock_nodes + self._job_context.update_job_node(mock_node) self.assertFalse( ps_manager.is_training_hang_by_pending( ps_num, DistributionStrategy.ALLREDUCE @@ -274,6 +300,7 @@ def test_is_training_hang_by_pending_ps(self): ) ) mock_nodes.clear() + self._job_context.clear_job_nodes() # mock with 4 running for index in range(4): @@ -285,7 +312,7 @@ def test_is_training_hang_by_pending_ps(self): NodeStatus.RUNNING, ) mock_nodes[index] = mock_node - ps_manager._nodes = mock_nodes + self._job_context.update_job_node(mock_node) self.assertFalse( ps_manager.is_training_hang_by_pending( ps_num, DistributionStrategy.ALLREDUCE @@ -297,3 +324,4 @@ def test_is_training_hang_by_pending_ps(self): ) ) mock_nodes.clear() + self._job_context.clear_job_nodes() diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py index 5b54afd0e..cb9a02363 100644 --- 
a/dlrover/python/tests/test_servicer.py +++ b/dlrover/python/tests/test_servicer.py @@ -37,6 +37,7 @@ from dlrover.python.master.elastic_training.sync_service import SyncService from dlrover.python.master.monitor.speed_monitor import SpeedMonitor from dlrover.python.master.node.dist_job_manager import create_job_manager +from dlrover.python.master.node.job_context import get_job_context from dlrover.python.master.servicer import MasterServicer from dlrover.python.master.shard.task_manager import TaskManager from dlrover.python.master.stats.job_collector import JobMetricCollector @@ -60,11 +61,16 @@ def setUp(self) -> None: worker_resource.node_resource.gpu_type = "a100" speed_monitor = SpeedMonitor() self.task_manager = TaskManager(False, speed_monitor) + self.job_manager = create_job_manager(params, speed_monitor) + self.job_context = get_job_context() + self.job_manager._init_nodes() self.job_manager._init_job_auto_scaler() - for node in self.job_manager._job_nodes[NodeType.WORKER].values(): + job_nodes = self.job_context.job_nodes_by_type(NodeType.WORKER) + for node in job_nodes.values(): node.status = NodeStatus.RUNNING + self.job_context.update_job_node(node) self.job_metric_collector = JobMetricCollector( "1", "default", "local", "dlrover" ) @@ -88,6 +94,7 @@ def setUp(self) -> None: def tearDown(self) -> None: os.environ.clear() + self.job_context.clear_job_nodes() def test_query_running_nodes(self): request = elastic_training_pb2.Message() @@ -168,10 +175,15 @@ def test_metric_service(self): reporter._runtime_stats = [] self.assertEqual(reporter._model_info.op_stats.flops, 10000) - worker0 = self.job_manager._job_nodes[NodeType.WORKER][0] + job_nodes = self.job_context.job_nodes() + worker0 = job_nodes[NodeType.WORKER][0] worker0.status = NodeStatus.RUNNING - ps0 = self.job_manager._job_nodes[NodeType.PS][0] + self.job_context.update_job_node(worker0) + + ps0 = job_nodes[NodeType.PS][0] ps0.status = NodeStatus.RUNNING + self.job_context.update_job_node(ps0) + request = grpc.GlobalStep() self.task_manager._speed_monitor.add_running_worker(NodeType.WORKER, 0) self.task_manager._speed_monitor.set_target_worker_num(1) @@ -206,8 +218,10 @@ def test_metric_service(self): def test_query_ps_nodes(self): self.job_manager._init_nodes() - for node in self.job_manager._job_nodes[NodeType.PS].values(): + nodes = self.job_context.job_nodes_by_type(NodeType.PS) + for node in nodes.values(): node.status = NodeStatus.RUNNING + self.job_context.update_job_node(node) res = self.servicer._query_ps_nodes() self.assertEqual(len(res.nodes), 3) self.assertEqual( @@ -395,8 +409,9 @@ def test_report_heartbeat(self): request.data = message.serialize() request.node_type = NodeType.WORKER request.node_id = 0 - self.servicer.report(request, None) - worker0 = self.servicer._job_manager._job_nodes[NodeType.WORKER][0] + self.servicer.get(request, None) + + worker0 = self.job_context.job_node(NodeType.WORKER, 0) self.assertEqual(worker0.heartbeat_time, ts) def test_sync_checkpoint(self): @@ -435,28 +450,32 @@ def test_deal_with_reported_node_event(self): request.message = "OOM" self.assertTrue(self.servicer._deal_with_reported_node_event(request)) self.assertFalse( - self.job_manager._job_nodes[task_type][task_id].is_succeeded() + self.job_manager._job_context.job_node( + task_type, task_id + ).is_succeeded() ) request.event_type = NodeEventType.NODE_CHECK_FAILED request.message = "" self.assertTrue(self.servicer._deal_with_reported_node_event(request)) self.assertTrue( - self.job_manager._job_nodes[task_type][ - 
task_id - ].is_node_check_failed() + self.job_manager._job_context.job_node( + task_type, task_id + ).is_node_check_failed() ) request.event_type = NodeEventType.SUCCEEDED request.message = "" self.assertTrue(self.servicer._deal_with_reported_node_event(request)) self.assertTrue( - self.job_manager._job_nodes[task_type][task_id].is_succeeded() + self.job_manager._job_context.job_node( + task_type, task_id + ).is_succeeded() ) self.assertFalse( - self.job_manager._job_nodes[task_type][ - task_id - ].is_node_check_failed() + self.job_manager._job_context.job_node( + task_type, task_id + ).is_node_check_failed() ) @@ -481,6 +500,10 @@ def setUp(self) -> None: job_metric_collector=self.job_metric_collector, elastic_ps_service=self.elastic_ps_service, ) + self.job_context = get_job_context() + + def tearDown(self) -> None: + self.job_context.clear_job_nodes() def test_update_node_addr(self): request = grpc.NodeMeta() @@ -492,11 +515,12 @@ def test_update_node_addr(self): request.addr = "localhost:5001" self.job_manager._init_nodes() self.servicer._update_node_address(request) - self.assertEqual( - self.job_manager._job_nodes[task_type][task_id].service_addr, addr - ) - for node in self.job_manager._job_nodes[NodeType.PS].values(): + node = self.job_context.job_node(task_type, task_id) + self.assertEqual(node.service_addr, addr) + ps_nodes = self.job_context.job_nodes_by_type(NodeType.PS) + for node in ps_nodes.values(): node.status = NodeStatus.RUNNING + self.job_context.update_job_node(node) res = self.servicer._query_ps_nodes() self.assertEqual(addr, res.nodes[task_id].addr) self.assertEqual("", res.nodes[0].addr) diff --git a/dlrover/python/tests/test_sync_service.py b/dlrover/python/tests/test_sync_service.py index bbc8140a1..7a5e4455c 100644 --- a/dlrover/python/tests/test_sync_service.py +++ b/dlrover/python/tests/test_sync_service.py @@ -17,6 +17,7 @@ from dlrover.python.master.elastic_training.sync_service import SyncService from dlrover.python.master.monitor.speed_monitor import SpeedMonitor from dlrover.python.master.node.dist_job_manager import create_job_manager +from dlrover.python.master.node.job_context import get_job_context from dlrover.python.tests.test_utils import MockK8sPSJobArgs, mock_k8s_client @@ -28,22 +29,30 @@ def setUp(self) -> None: params.initilize() self.job_manager = create_job_manager(params, SpeedMonitor()) self.job_manager._init_nodes() + self.job_context = get_job_context() + + def tearDown(self): + self.job_context.clear_job_nodes() def test_sync(self): sync_service = SyncService(self.job_manager) - for node in self.job_manager._job_nodes[NodeType.CHIEF].values(): + job_nodes = self.job_context.job_nodes() + for node in job_nodes[NodeType.CHIEF].values(): node.status = NodeStatus.RUNNING + self.job_context.update_job_node(node) - for node in self.job_manager._job_nodes[NodeType.WORKER].values(): + for node in job_nodes[NodeType.WORKER].values(): node.status = NodeStatus.RUNNING + self.job_context.update_job_node(node) sync_name = "sync-0" - for node in self.job_manager._job_nodes[NodeType.CHIEF].values(): + job_nodes = self.job_context.job_nodes() + for node in job_nodes[NodeType.CHIEF].values(): sync_service.join_sync(sync_name, node.type, node.id) finished = sync_service.sync_finished(sync_name) self.assertFalse(finished) - for node in self.job_manager._job_nodes[NodeType.WORKER].values(): + for node in job_nodes[NodeType.WORKER].values(): sync_service.join_sync(sync_name, node.type, node.id) finished = sync_service.sync_finished(sync_name) 
self.assertTrue(finished) diff --git a/dlrover/python/tests/test_time_util.py b/dlrover/python/tests/test_time_util.py new file mode 100644 index 000000000..ea5f5afcb --- /dev/null +++ b/dlrover/python/tests/test_time_util.py @@ -0,0 +1,31 @@ +# Copyright 2024 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from datetime import datetime, timedelta + +import dlrover.python.util.time_util as tu + + +class TimeUtilTest(unittest.TestCase): + def test_has_expired(self): + self.assertFalse( + tu.has_expired( + (datetime.now() + timedelta(seconds=1)).timestamp(), 500 + ) + ) + self.assertTrue( + tu.has_expired( + (datetime.now() - timedelta(seconds=1)).timestamp(), 500 + ) + ) diff --git a/dlrover/python/tests/test_worker_manager.py b/dlrover/python/tests/test_worker_manager.py index bdd9d7c16..cdca564b9 100644 --- a/dlrover/python/tests/test_worker_manager.py +++ b/dlrover/python/tests/test_worker_manager.py @@ -25,6 +25,7 @@ ) from dlrover.python.common.global_context import Context from dlrover.python.common.node import Node, NodeGroupResource, NodeResource +from dlrover.python.master.node.job_context import get_job_context from dlrover.python.master.node.worker import ChiefManager, WorkerManager from dlrover.python.master.resource.job import JobResource from dlrover.python.scheduler.factory import new_elastic_job @@ -43,27 +44,33 @@ def setUp(self) -> None: self._elastic_job = new_elastic_job( PlatformType.KUBERNETES, "test", "default" ) - self._job_nodes = self._job_resource.init_job_node_meta( + + self.job_context = get_job_context() + job_nodes = self._job_resource.init_job_node_meta( 1, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + self.job_context.update_job_nodes(job_nodes) self._worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + def tearDown(self) -> None: + self.job_context.clear_job_nodes() + def test_scale_up_workers(self): self._worker_manager._scale_up_workers(3) - self.assertEqual(len(self._worker_manager._nodes), 8) - self.assertEqual(self._worker_manager._nodes[7].id, 7) + workers = self.job_context.workers + self.assertEqual(len(workers), 8) + self.assertEqual(workers[7].id, 7) def test_scale_down_workers(self): - workers = list(self._worker_manager._nodes.values()) + workers = list(self.job_context.workers.values()) self._worker_manager._scale_down_workers(2, workers) released_workers = [] for worker in workers: @@ -72,8 +79,11 @@ def test_scale_down_workers(self): self.assertEqual(len(released_workers), 2) def test_delete_exited_workers(self): - self._worker_manager._nodes[3].status = NodeStatus.FINISHED - self._worker_manager._nodes[4].status = NodeStatus.FAILED + workers = self.job_context.workers + workers[3].status = NodeStatus.FINISHED + self.job_context.update_job_node(workers[3]) + workers[4].status = NodeStatus.FAILED + self.job_context.update_job_node(workers[4]) 
plan = self._worker_manager.delete_exited_workers() node_names = [node.name for node in plan.remove_nodes] @@ -83,8 +93,9 @@ def test_delete_exited_workers(self): ) def test_delete_running_workers(self): - for node in self._worker_manager._nodes.values(): + for node in self.job_context.workers.values(): node.status = NodeStatus.RUNNING + self.job_context.update_job_node(node) plan = self._worker_manager.delete_running_workers() node_names = [node.name for node in plan.remove_nodes] self.assertListEqual( @@ -100,21 +111,21 @@ def test_delete_running_workers(self): def test_relaunch_node(self): worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - failed_worker = self._job_nodes[NodeType.WORKER][4] + failed_worker = self.job_context.workers[4] failed_worker.status = NodeStatus.FAILED failed_worker.max_relaunch_count = 3 + self.job_context.update_job_node(failed_worker) plan = worker_manager.relaunch_node( failed_worker, remove_exited_node=True ) self.assertEqual(plan.launch_nodes[0].config_resource.cpu, 16) + self.assertEqual(self.job_context.workers[5].id, 5) self.assertEqual(plan.launch_nodes[0].max_relaunch_count, 3) - self.assertEqual(worker_manager._nodes[5].id, 5) self.assertEqual(plan.remove_nodes[0].config_resource.cpu, 16) def test_relaunch_chief_node(self): @@ -123,71 +134,78 @@ def test_relaunch_chief_node(self): node_id=0, config_resource=NodeResource(cpu=16, memory=10240), ) + job_nodes = { + NodeType.MASTER: {0: tf_master_node}, + } + self.job_context.update_job_nodes(job_nodes) manager = ChiefManager( - {0: tf_master_node}, self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) plan = manager.relaunch_node(tf_master_node) + nodes = self.job_context.job_nodes_by_type(NodeType.CHIEF) self.assertEqual(plan.launch_nodes[0].config_resource.cpu, 16) - self.assertEqual(manager._nodes[1].id, 1) + self.assertEqual(nodes[1].id, 1) def test_reduce_pending_node_resource(self): worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - for node in worker_manager._nodes.values(): + for node in self.job_context.workers.values(): node.status = NodeStatus.PENDING node.create_time = datetime.now() + timedelta(days=-1) + self.job_context.update_job_node(node) plan = worker_manager.reduce_pending_node_resource() self.assertEqual(len(plan.launch_nodes), 5) - for node in worker_manager._nodes.values(): + for node in self.job_context.workers.values(): node.config_resource.gpu_num = 1 + self.job_context.update_job_node(node) plan = worker_manager.reduce_pending_node_resource() self.assertTrue(plan.empty()) def test_pending_without_workers(self): worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - for node in worker_manager._nodes.values(): + for node in self.job_context.workers.values(): node.status = NodeStatus.FAILED node.exit_reason = NodeExitReason.FATAL_ERROR + self.job_context.update_job_node(node) exited = worker_manager.has_exited_worker() self.assertTrue(exited) - for node in worker_manager._nodes.values(): + for node in self.job_context.workers.values(): node.exit_reason = NodeExitReason.KILLED + self.job_context.update_job_node(node) exited = worker_manager.has_exited_worker() self.assertFalse(exited) - 
worker_manager._nodes[0].status = NodeStatus.SUCCEEDED + self.job_context.workers[0].status = NodeStatus.SUCCEEDED + self.job_context.update_job_node(self.job_context.workers[0]) exited = worker_manager.has_exited_worker() self.assertTrue(exited) wait = worker_manager.wait_worker_restart() self.assertTrue(wait) - for node in worker_manager._nodes.values(): + for node in self.job_context.workers.values(): node.relaunch_count = node.max_relaunch_count + self.job_context.update_job_node(node) wait = worker_manager.wait_worker_restart() self.assertFalse(wait) def test_verify_restarting_training(self): worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, @@ -195,17 +213,18 @@ def test_verify_restarting_training(self): ) reset = worker_manager.verify_restarting_training(0) self.assertFalse(reset) - worker_manager._nodes[0].restart_training = True + self.job_context.workers[0].restart_training = True + self.job_context.update_job_node(self.job_context.workers[0]) reset = worker_manager.verify_restarting_training(0) self.assertTrue(reset) - worker_manager._nodes[0].is_released = True + self.job_context.workers[0].is_released = True + self.job_context.update_job_node(self.job_context.workers[0]) reset = worker_manager.verify_restarting_training(0) self.assertFalse(reset) def test_is_training_hang_by_pending_workers(self): _dlrover_ctx.pending_fail_strategy = 2 worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, @@ -256,7 +275,7 @@ def test_is_training_hang_by_pending_workers(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + self.job_context.update_job_node(mock_node) self.assertFalse( worker_manager.is_training_hang_by_pending( worker_num, DistributionStrategy.ALLREDUCE @@ -268,6 +287,7 @@ def test_is_training_hang_by_pending_workers(self): ) ) mock_nodes.clear() + self.job_context.clear_job_nodes() # mock with 3 running + 1 pending long time for index in range(4): @@ -284,7 +304,8 @@ def test_is_training_hang_by_pending_workers(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + self.job_context.update_job_node(mock_node) + self.assertTrue( worker_manager.is_training_hang_by_pending( worker_num, DistributionStrategy.ALLREDUCE @@ -296,6 +317,7 @@ def test_is_training_hang_by_pending_workers(self): ) ) mock_nodes.clear() + self.job_context.clear_job_nodes() # mock with 4 running + 1 pending long time worker_num = 5 @@ -313,7 +335,7 @@ def test_is_training_hang_by_pending_workers(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + self.job_context.update_job_node(mock_node) self.assertFalse( worker_manager.is_training_hang_by_pending( worker_num, DistributionStrategy.ALLREDUCE @@ -325,6 +347,7 @@ def test_is_training_hang_by_pending_workers(self): ) ) mock_nodes.clear() + self.job_context.clear_job_nodes() # mock with 3 running + 1 initial long time worker_num = 4 @@ -342,7 +365,7 @@ def test_is_training_hang_by_pending_workers(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + self.job_context.update_job_node(mock_node) self.assertTrue( 
worker_manager.is_training_hang_by_pending( worker_num, DistributionStrategy.ALLREDUCE @@ -375,7 +398,7 @@ def test_is_training_hang_by_pending_workers(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-10) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + self.job_context.update_job_node(mock_node) self.assertFalse( worker_manager.is_training_hang_by_pending( worker_num, DistributionStrategy.ALLREDUCE @@ -398,7 +421,8 @@ def test_is_training_hang_by_pending_workers(self): ) mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + self.job_context.update_job_node(mock_node) + self.assertTrue( worker_manager.is_training_hang_by_pending( worker_num, DistributionStrategy.ALLREDUCE @@ -422,7 +446,7 @@ def test_is_training_hang_by_pending_workers(self): ) mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + self.job_context.update_job_node(mock_node) self.assertTrue( worker_manager.is_training_hang_by_pending( worker_num, DistributionStrategy.ALLREDUCE @@ -448,7 +472,7 @@ def test_is_training_hang_by_pending_workers(self): mock_node.status = NodeStatus.RUNNING mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + self.job_context.update_job_node(mock_node) self.assertFalse( worker_manager.is_training_hang_by_pending( worker_num, DistributionStrategy.ALLREDUCE @@ -472,7 +496,7 @@ def test_is_training_hang_by_pending_workers(self): ) mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + self.job_context.update_job_node(mock_node) self.assertFalse( worker_manager.is_training_hang_by_pending( worker_num, DistributionStrategy.ALLREDUCE @@ -516,7 +540,6 @@ def test_is_training_hang_by_pending_workers(self): def test_is_training_hang_by_insufficient_worker(self): worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, @@ -549,7 +572,7 @@ def test_is_training_hang_by_insufficient_worker(self): if index == 0: mock_node.status = NodeStatus.PENDING mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + self.job_context.update_job_node(mock_node) for _ in range(5): if worker_manager.is_training_hang_by_insufficient_worker(): is_insufficient += 1 @@ -557,6 +580,7 @@ def test_is_training_hang_by_insufficient_worker(self): self.assertEqual(is_insufficient, 0) mock_nodes.clear() is_insufficient = 0 + self.job_context.clear_job_nodes() # mock with 3 running for index in range(3): @@ -568,7 +592,7 @@ def test_is_training_hang_by_insufficient_worker(self): NodeStatus.RUNNING, ) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + self.job_context.update_job_node(mock_node) for _ in range(5): if worker_manager.is_training_hang_by_insufficient_worker(): is_insufficient += 1 @@ -576,6 +600,7 @@ def test_is_training_hang_by_insufficient_worker(self): self.assertTrue(is_insufficient >= 2) mock_nodes.clear() is_insufficient = 0 + self.job_context.clear_job_nodes() # mock with 3 running + 1 released for index in range(4): @@ -589,8 +614,8 @@ def test_is_training_hang_by_insufficient_worker(self): if index == 0: mock_node.status = NodeStatus.DELETED mock_node.is_released = True + self.job_context.update_job_node(mock_node) mock_nodes[index] = mock_node - 
worker_manager._nodes = mock_nodes for _ in range(5): if worker_manager.is_training_hang_by_insufficient_worker(): is_insufficient += 1 diff --git a/dlrover/python/diagnosis/common/diagnose_action.py b/dlrover/python/util/time_util.py similarity index 71% rename from dlrover/python/diagnosis/common/diagnose_action.py rename to dlrover/python/util/time_util.py index ea96de464..04c6b9c19 100644 --- a/dlrover/python/diagnosis/common/diagnose_action.py +++ b/dlrover/python/util/time_util.py @@ -11,12 +11,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from datetime import datetime, timedelta -class DiagnoseAction: - def __init__(self): - self._actions: List[str] = [] - - def add_action(self, action: str): - self._actions.append(action) +def has_expired(timestamp: float, time_period: int) -> bool: + dt = datetime.fromtimestamp(timestamp) + expired_dt = dt + timedelta(milliseconds=time_period) + return expired_dt < datetime.now()
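
Note on the renamed helper above: has_expired() treats its second argument as a period in milliseconds (timedelta(milliseconds=time_period)), which is why test_time_util.py pairs a 500 ms window with timestamps one second in the past or future. A small usage sketch, assuming the patched dlrover.python.util.time_util module is importable; the variable names are illustrative only:

import time

from dlrover.python.util.time_util import has_expired

start = time.time()
time.sleep(0.2)  # let roughly 200 ms elapse

# A 100 ms validity window starting at `start` has already passed.
print(has_expired(start, 100))   # True
# A 5000 ms window starting at `start` is still open.
print(has_expired(start, 5000))  # False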
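
Taken together, the test diffs above migrate from reaching into job_manager._job_nodes to going through the shared JobContext. Below is a minimal sketch of that pattern, assuming only the APIs exercised in the diffs (get_job_context, update_job_nodes, update_job_node, job_node, job_nodes_by_type, clear_job_nodes) and assuming NodeType/NodeStatus come from dlrover.python.common.constants as elsewhere in the test suite; the test class and node names are illustrative, not part of this patch:

import unittest

from dlrover.python.common.constants import NodeStatus, NodeType
from dlrover.python.common.node import Node, NodeResource
from dlrover.python.master.node.job_context import get_job_context


class JobContextUsageSketch(unittest.TestCase):
    def setUp(self) -> None:
        # The context is a process-wide singleton shared with the managers.
        self.job_context = get_job_context()
        worker = Node(
            NodeType.WORKER,
            0,
            NodeResource(0, 0),
            "test-worker-0",
            NodeStatus.PENDING,
        )
        # Seed the context with a full type -> {id: node} mapping.
        self.job_context.update_job_nodes({NodeType.WORKER: {0: worker}})

    def tearDown(self) -> None:
        # Clear the shared context so later test cases start from scratch.
        self.job_context.clear_job_nodes()

    def test_update_and_read_back(self):
        # Fetch a single node by (type, id) instead of dict indexing ...
        worker0 = self.job_context.job_node(NodeType.WORKER, 0)
        worker0.status = NodeStatus.RUNNING
        # ... and write every mutation back explicitly.
        self.job_context.update_job_node(worker0)

        workers = self.job_context.job_nodes_by_type(NodeType.WORKER)
        self.assertEqual(workers[0].status, NodeStatus.RUNNING)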