From 223c17ab9082dcd0513cd0faeecfe2334c985112 Mon Sep 17 00:00:00 2001 From: bsang Date: Tue, 15 Oct 2024 16:06:36 +0800 Subject: [PATCH 01/10] worker agent pre-check --- dlrover/python/common/worker.py | 9 +- .../elastic_agent/config/launch_config.py | 98 +++++++++++++++++++ .../diagnosis/diagnosis_agent.py | 33 ++++--- .../python/elastic_agent/torch/training.py | 96 ++++-------------- dlrover/python/tests/test_diagnosis_agent.py | 2 +- .../tests/test_elastic_training_agent.py | 2 +- dlrover/trainer/torch/elastic_run.py | 2 +- 7 files changed, 144 insertions(+), 98 deletions(-) create mode 100644 dlrover/python/elastic_agent/config/launch_config.py diff --git a/dlrover/python/common/worker.py b/dlrover/python/common/worker.py index c09c4fdaa..9fcba69aa 100644 --- a/dlrover/python/common/worker.py +++ b/dlrover/python/common/worker.py @@ -12,15 +12,16 @@ # limitations under the License. from torch.distributed.elastic.agent.server.api import RunResult, WorkerSpec +from dlrover.python.elastic_agent.config.launch_config import ElasticLaunchConfig class WorkerContext: def __init__( self, - worker_spec: WorkerSpec, - remaining_failovers: int, - restart_count: int, - run_result: RunResult, + worker_spec: WorkerSpec = None, + remaining_failovers: int = 0, + restart_count: int = 0, + run_result: RunResult = None, ): self._worker_spec: WorkerSpec = worker_spec self.remaining_failovers = remaining_failovers diff --git a/dlrover/python/elastic_agent/config/launch_config.py b/dlrover/python/elastic_agent/config/launch_config.py new file mode 100644 index 000000000..959c9a9f1 --- /dev/null +++ b/dlrover/python/elastic_agent/config/launch_config.py @@ -0,0 +1,98 @@ +from dataclasses import dataclass +from torch.distributed.launcher.api import LaunchConfig +from dlrover.python.common.constants import ( + Accelerators, + AscendConstants, + ConfigPath, + JobConstant, + NodeEnv, + NodeErrorMessage, + NodeStatus, + RendezvousName, + TrainingExceptionLevel, +) +from typing import ( + Dict, + List, + Optional, + Tuple, + Union, +) +from torch.distributed.elastic.multiprocessing import ( + PContext, + SignalException, + Std, +) +import os +import torch +from dlrover.python.common.log import default_logger as logger + + +@dataclass +class ElasticLaunchConfig(LaunchConfig): + """ + Creates a rendezvous config of elastic training. + + Args: + network_check: whether to check the network available before training. + comm_perf_test: whether to test the communication performance. + node_unit: the number of unit of nodes. The number of nodes must be + a multiple of node_unit. + auto_config: indicate if automatically configure the nnodes and + nproc_per_node. + auto_tunning: whether to auto-tune the parallelism configuration. + exclude_straggler: The node will exit if it is a straggler in network + check and exclude_straggler is True. + save_at_breakpoint: indicate if save the checkpoint from the shared + memory into the disk after a failure occurs. + accelerator: the type of accelerator processor like nvidia.com/gpu, + ascend-npu. 
+ training_log_file: the training log file of this training job + failure_node_errors: the error information that indicate the node + is a failure node + """ + + network_check: bool = False + comm_perf_test: bool = False + node_unit: int = 1 + training_port: int = AscendConstants.HCCL_PORT_START_DEFAULT + auto_config: bool = False + auto_tunning: bool = False + exclude_straggler: bool = False + save_at_breakpoint: bool = False + accelerator: str = "" + log_dir: Optional[str] = None # Keep Compatibility with PyTorch>=2.3.0 + redirects: Union[Std, Dict[int, Std]] = Std.NONE + tee: Union[Std, Dict[int, Std]] = Std.NONE + training_log_file: str = "" + failure_node_errors: str = "" + + def set_node_unit(self, node_unit): + """Set the number unit of nodes.""" + self.node_unit = node_unit + self.rdzv_configs["node_unit"] = node_unit + + def auto_configure_params(self): + self.training_log_file = os.getenv(NodeEnv.TRAINING_LOG_FILE, "") + self.failure_node_errors = os.getenv(NodeEnv.FAILURE_NODE_ERRORS, "") + if len(self.failure_node_errors) > 0: + errors = self.failure_node_errors.strip() + if errors[0] != "#" or errors[-1] != "#": + logger.warning("invalid failure node errors: %s", errors) + self.failure_node_errors = "" + + device = "" + if torch.cuda.is_available(): + device = torch.cuda.get_device_name() + if "Ascend" in device: + self.accelerator = Accelerators.ASCEND_NPU + if not self.auto_config: + return + + if NodeEnv.NODE_NUM in os.environ: + self.min_nodes = int(os.environ[NodeEnv.NODE_NUM]) + self.max_nodes = int(os.environ[NodeEnv.NODE_NUM]) + if torch.cuda.is_available(): + self.nproc_per_node = torch.cuda.device_count() + if self.min_nodes >= 4: + self.network_check = True \ No newline at end of file diff --git a/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py b/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py index 7b1619829..d472461d7 100644 --- a/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py +++ b/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py @@ -91,6 +91,20 @@ def start(self): def stop(self): self._stopped = True + def diagnose_problems( + self, problems: List[Inference] + ) -> DiagnoseAction: + conclusions: List[Inference] = [] + for problem in problems: + ic = InferenceChain([problem], self._diagnosis_operators) + try: + infs = ic.infer() + if len(infs) > 0: + conclusions = combine_inferences(conclusions, infs) + except Exception as e: + logger.error(f"fail to diagnose observation {problem}: {e}") + return coordinate_inferences(conclusions) + def _observe(self) -> List[Inference]: observations: List[Inference] = [] for problem in self._observe_problems: @@ -103,20 +117,6 @@ def _observe(self) -> List[Inference]: logger.error(f"fail to observe problem {problem}: {e}") return observations - def _diagnose_observations( - self, observations: List[Inference] - ) -> DiagnoseAction: - conclusions: List[Inference] = [] - for ob in observations: - ic = InferenceChain([ob], self._diagnosis_operators) - try: - infs = ic.infer() - if len(infs) > 0: - conclusions = combine_inferences(conclusions, infs) - except Exception as e: - logger.error(f"fail to diagnose observation {ob}: {e}") - return coordinate_inferences(conclusions) - def _periodically_diagnosis(self): logger.info("Start periodically diagnosis...") while True: @@ -127,7 +127,7 @@ def _periodically_diagnosis(self): observations = self._observe() if len(observations) > 0: logger.info(f"Observed problems: {observations}") - self._diagnose_observations(observations) + 
self.diagnose_problems(observations) time.sleep( DiagnosisConstant.AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS @@ -197,3 +197,6 @@ def _report_failure_to_master( def _report_metric_to_master(self, agent_metric: WorkerTrainingMetric): self._client.report_diagnosis_agent_metrics(agent_metric) + + + diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index 29eebd7d4..0fa228bd2 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -23,7 +23,6 @@ from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from contextlib import closing -from dataclasses import dataclass from datetime import datetime, timedelta from typing import ( Any, @@ -36,7 +35,6 @@ Union, ) -import torch import torch.distributed.elastic.timer as timer from torch.distributed import PrefixStore, Store from torch.distributed.elastic import events, metrics @@ -58,7 +56,6 @@ from torch.distributed.elastic.multiprocessing import ( PContext, SignalException, - Std, ) from torch.distributed.elastic.multiprocessing.errors import ( ChildFailedError, @@ -74,7 +71,6 @@ AscendConstants, ConfigPath, JobConstant, - NodeEnv, NodeErrorMessage, NodeStatus, RendezvousName, @@ -103,6 +99,13 @@ version_less_than_230, version_less_than_240, ) +from dlrover.python.diagnosis.common.inference_chain import ( + Inference, + InferenceName, + InferenceAttribute, + InferenceDescription, +) +from dlrover.python.elastic_agent.config.launch_config import ElasticLaunchConfig try: from torch_npu.contrib import transfer_to_npu # noqa: F401 @@ -136,76 +139,6 @@ class RendezvousOutSyncError(Exception): pass -@dataclass -class ElasticLaunchConfig(LaunchConfig): - """ - Creates a rendezvous config of elastic training. - - Args: - network_check: whether to check the network available before training. - comm_perf_test: whether to test the communication performance. - node_unit: the number of unit of nodes. The number of nodes must be - a multiple of node_unit. - auto_config: indicate if automatically configure the nnodes and - nproc_per_node. - auto_tunning: whether to auto-tune the parallelism configuration. - exclude_straggler: The node will exit if it is a straggler in network - check and exclude_straggler is True. - save_at_breakpoint: indicate if save the checkpoint from the shared - memory into the disk after a failure occurs. - accelerator: the type of accelerator processor like nvidia.com/gpu, - ascend-npu. 
- training_log_file: the training log file of this training job - failure_node_errors: the error information that indicate the node - is a failure node - """ - - network_check: bool = False - comm_perf_test: bool = False - node_unit: int = 1 - training_port: int = AscendConstants.HCCL_PORT_START_DEFAULT - auto_config: bool = False - auto_tunning: bool = False - exclude_straggler: bool = False - save_at_breakpoint: bool = False - accelerator: str = "" - log_dir: Optional[str] = None # Keep Compatibility with PyTorch>=2.3.0 - redirects: Union[Std, Dict[int, Std]] = Std.NONE - tee: Union[Std, Dict[int, Std]] = Std.NONE - training_log_file: str = "" - failure_node_errors: str = "" - - def set_node_unit(self, node_unit): - """Set the number unit of nodes.""" - self.node_unit = node_unit - self.rdzv_configs["node_unit"] = node_unit - - def auto_configure_params(self): - self.training_log_file = os.getenv(NodeEnv.TRAINING_LOG_FILE, "") - self.failure_node_errors = os.getenv(NodeEnv.FAILURE_NODE_ERRORS, "") - if len(self.failure_node_errors) > 0: - errors = self.failure_node_errors.strip() - if errors[0] != "#" or errors[-1] != "#": - logger.warning("invalid failure node errors: %s", errors) - self.failure_node_errors = "" - - device = "" - if torch.cuda.is_available(): - device = torch.cuda.get_device_name() - if "Ascend" in device: - self.accelerator = Accelerators.ASCEND_NPU - if not self.auto_config: - return - - if NodeEnv.NODE_NUM in os.environ: - self.min_nodes = int(os.environ[NodeEnv.NODE_NUM]) - self.max_nodes = int(os.environ[NodeEnv.NODE_NUM]) - if torch.cuda.is_available(): - self.nproc_per_node = torch.cuda.device_count() - if self.min_nodes >= 4: - self.network_check = True - - class MasterRendezvousHandler(RendezvousHandler): """The rendezvous handler completes rendezvous by connecting with the ElasticJob master. The master will collect all nodes @@ -740,6 +673,18 @@ def _assign_worker_ranks( workers.append(worker) return workers + def _pre_check(self): + problems: List[Inference] = [] + if self._config.network_check: + succ = run_network_check(self._config, self._entrypoint) + if not succ: + problems.append(Inference( + name=InferenceName.NODE, + attribution=InferenceAttribute.IS, + description=InferenceDescription.FAILURE, + )) + self._diagnose_agent.diagnose_problems(problems) + def _initialize_workers(self, worker_group): logger.info("Start initializing training workers.") start_pending = 0 @@ -748,8 +693,7 @@ def _initialize_workers(self, worker_group): ) while True: try: - if self._config.network_check: - run_network_check(self._config, self._entrypoint) + self._pre_check() super()._initialize_workers(worker_group) # We need to register handler after starting workers because # the PContext start_worker will overwrite the handler. 
diff --git a/dlrover/python/tests/test_diagnosis_agent.py b/dlrover/python/tests/test_diagnosis_agent.py index c6770f677..c1ed97c43 100644 --- a/dlrover/python/tests/test_diagnosis_agent.py +++ b/dlrover/python/tests/test_diagnosis_agent.py @@ -30,10 +30,10 @@ build_master_client, ) from dlrover.python.elastic_agent.torch.training import ( - ElasticLaunchConfig, _create_worker_spec, ) from dlrover.python.tests.test_utils import start_local_master +from dlrover.python.elastic_agent.config.launch_config import ElasticLaunchConfig class TestDiagnosisAgent(unittest.TestCase): diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index afcd03ffb..75b723f01 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -46,7 +46,6 @@ DdpCheckpointSaver, ) from dlrover.python.elastic_agent.torch.training import ( - ElasticLaunchConfig, ElasticTrainingAgent, MasterRendezvousHandler, NodeCheckElasticAgent, @@ -59,6 +58,7 @@ node_health_check, ) from dlrover.python.tests.test_utils import start_local_master +from dlrover.python.elastic_agent.config.launch_config import ElasticLaunchConfig class ElasticTrainingAgentTest(unittest.TestCase): diff --git a/dlrover/trainer/torch/elastic_run.py b/dlrover/trainer/torch/elastic_run.py index e6d057203..39dd56f7a 100644 --- a/dlrover/trainer/torch/elastic_run.py +++ b/dlrover/trainer/torch/elastic_run.py @@ -116,10 +116,10 @@ from dlrover.python.common.log import default_logger as logger from dlrover.python.elastic_agent.master_client import MasterClient from dlrover.python.elastic_agent.torch.training import ( - ElasticLaunchConfig, launch_agent, ) from dlrover.trainer.torch.utils import version_less_than_230 +from dlrover.python.elastic_agent.config.launch_config import ElasticLaunchConfig def parse_args(args): From a57028c32c5e4dd76e421ed86d1fed2234e36045 Mon Sep 17 00:00:00 2001 From: bsang Date: Wed, 16 Oct 2024 09:53:44 +0800 Subject: [PATCH 02/10] update --- dlrover/python/common/grpc.py | 11 +++++++++++ dlrover/python/master/diagnosis/diagnosis.py | 14 +++++++------- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/dlrover/python/common/grpc.py b/dlrover/python/common/grpc.py index a0c78a756..c420808c8 100644 --- a/dlrover/python/common/grpc.py +++ b/dlrover/python/common/grpc.py @@ -514,3 +514,14 @@ class ElasticRunConfig(Message): @dataclass class SucceededRequest(Message): pass + + +@dataclass +class DiagnosisAction(Message): + instance: str = "" + action: str = "" + + +@dataclass +class HeartbeatResponse(Message): + diagnosis_action: DiagnosisAction = DiagnosisAction() diff --git a/dlrover/python/master/diagnosis/diagnosis.py b/dlrover/python/master/diagnosis/diagnosis.py index 17dd073ea..48692775a 100644 --- a/dlrover/python/master/diagnosis/diagnosis.py +++ b/dlrover/python/master/diagnosis/diagnosis.py @@ -72,7 +72,7 @@ def start_observing(self): try: thread = threading.Thread( - target=self._diagnose_failures(), + target=self._diagnose(), name="diagnose_failures", daemon=True, ) @@ -88,7 +88,7 @@ def stop_observing(self): logger.info("Stop Diagnosis Manager to observing training.") self._is_observing_started = False - def _diagnose_failures(self): + def _diagnose(self): logger.info("Start to diagnose failures for observing.") while True: if not self._is_observing_started: @@ -96,6 +96,8 @@ def _diagnose_failures(self): break observed_problems = self._diagnostician.observe_training() + + solutions: 
List[Inference] = [] for problem in observed_problems: logger.info(f"observed problems: {problem}") root_causes = self._diagnostician.diagnose_failure(problem) @@ -144,9 +146,7 @@ def __init__(self, data_manager): self._data_manager = data_manager self._pre_checks: List[Inference] = [] self._training_problems: List[Inference] = [] - - def get_pre_check_operators(self) -> List[InferenceOperator]: - return [] + self._observing_operators = def get_observing_operators(self) -> List[InferenceOperator]: return [CheckTrainingHangOperator(self._data_manager)] @@ -167,5 +167,5 @@ def observe_training(self) -> List[Inference]: ) return ic.infer() - def diagnose_failure(self, inference: Inference) -> List[Inference]: - pass + def diagnose_problem(self, inference: Inference) -> List[Inference]: + return [] From 2449f620a7b567c7933f9089de8f8792f61be0c4 Mon Sep 17 00:00:00 2001 From: bsang Date: Fri, 18 Oct 2024 16:25:38 +0800 Subject: [PATCH 03/10] complete impl --- dlrover/python/common/grpc.py | 6 +- dlrover/python/common/time.py | 7 ++ dlrover/python/diagnosis/common/constants.py | 3 + .../diagnosis/common/diagnose_action.py | 74 ++++++++++++++- ...oordinator.py => coordinate_inferences.py} | 4 +- .../inferenceoperator/operator.py | 11 +++ .../python/elastic_agent/common/__init__.py | 0 .../{ => elastic_agent}/common/worker.py | 41 ++++++--- .../diagnosis/diagnosis_agent.py | 66 ++++++++++---- dlrover/python/elastic_agent/master_client.py | 17 +++- .../python/elastic_agent/monitor/training.py | 8 -- .../python/elastic_agent/torch/training.py | 26 ++++-- .../diagnosis/diagnosis_data_manager.py | 37 ++++++++ .../{diagnosis.py => diagnosis_manager.py} | 91 +++++++------------ dlrover/python/master/dist_master.py | 8 +- .../python/master/node/dist_job_manager.py | 45 ++++++++- dlrover/python/master/node/job.py | 16 ++++ dlrover/python/master/node/job_manager.py | 12 ++- .../python/master/node/local_job_manager.py | 5 +- dlrover/python/master/servicer.py | 27 +++--- dlrover/python/tests/test_diagnosis.py | 2 +- dlrover/python/tests/test_diagnosis_agent.py | 2 +- dlrover/python/tests/test_job_manager.py | 2 +- dlrover/python/tests/test_servicer.py | 2 +- 24 files changed, 367 insertions(+), 145 deletions(-) create mode 100644 dlrover/python/common/time.py rename dlrover/python/diagnosis/inferencechain/{coordinator.py => coordinate_inferences.py} (88%) create mode 100644 dlrover/python/elastic_agent/common/__init__.py rename dlrover/python/{ => elastic_agent}/common/worker.py (61%) create mode 100644 dlrover/python/master/diagnosis/diagnosis_data_manager.py rename dlrover/python/master/diagnosis/{diagnosis.py => diagnosis_manager.py} (62%) create mode 100644 dlrover/python/master/node/job.py diff --git a/dlrover/python/common/grpc.py b/dlrover/python/common/grpc.py index c420808c8..aeb9bc958 100644 --- a/dlrover/python/common/grpc.py +++ b/dlrover/python/common/grpc.py @@ -518,10 +518,12 @@ class SucceededRequest(Message): @dataclass class DiagnosisAction(Message): - instance: str = "" + rank: int = -1 + timestamp: float = 0.0 + expired_time_period: int = 0 action: str = "" @dataclass class HeartbeatResponse(Message): - diagnosis_action: DiagnosisAction = DiagnosisAction() + diagnosis_actions: List[DiagnosisAction] = field(default_factory=list) diff --git a/dlrover/python/common/time.py b/dlrover/python/common/time.py new file mode 100644 index 000000000..657f64cac --- /dev/null +++ b/dlrover/python/common/time.py @@ -0,0 +1,7 @@ +from datetime import datetime, timedelta + + +def has_expired(timestamp: 
float, time_period: int) -> bool: + dt = datetime.fromtimestamp(timestamp) + expired_dt = dt + timedelta(seconds=time_period) + return expired_dt < datetime.now() diff --git a/dlrover/python/diagnosis/common/constants.py b/dlrover/python/diagnosis/common/constants.py index 601bb730d..82f10f393 100644 --- a/dlrover/python/diagnosis/common/constants.py +++ b/dlrover/python/diagnosis/common/constants.py @@ -24,6 +24,8 @@ class InferenceConfigKey(object): class DiagnosisConstant(object): MASTER_DIAGNOSIS_OBSERVING_INTERVAL_SECS = 180 AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS = 60 + MASTER_RANK = -1 + ANY_RANK = -2 class DiagnosisDataType(object): @@ -36,3 +38,4 @@ class DiagnosisAction(object): NO_ACTION = "no_action" RESTART_WORKER = "restart_worker" RELAUNCH_WORKER = "relaunch_worker" + ACTION_EXPIRED_TIME_PERIOD = 60 * 5 diff --git a/dlrover/python/diagnosis/common/diagnose_action.py b/dlrover/python/diagnosis/common/diagnose_action.py index ea96de464..33e03c0ee 100644 --- a/dlrover/python/diagnosis/common/diagnose_action.py +++ b/dlrover/python/diagnosis/common/diagnose_action.py @@ -10,13 +10,79 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import threading +from datetime import datetime +from dlrover.python.common.time import has_expired from typing import List +from dlrover.python.common.log import default_logger as logger +from dlrover.python.diagnosis.common.constants import ( + DiagnosisConstant, + DiagnosisAction as DiagnosisActionConstants, +) class DiagnoseAction: + def __init__(self, timestamp=0, expired_time_period=0, action="", rank=DiagnosisConstant.ANY_RANK): + self.action = action + if timestamp == 0: + self.timestamp = int(round(datetime.now().timestamp())) + else: + self.timestamp= timestamp + + if expired_time_period == 0: + self.expired_time_period = DiagnosisActionConstants.ACTION_EXPIRED_TIME_PERIOD + else: + self.expired_time_period = expired_time_period + # rank indicates the rank (worker) which is to + # execute this action. 
Rank = -1 indicates the + # master is to execute the action + self.rank = rank + + def set_action(self, action: str): + self.action = action + + def has_expired(self) -> bool: + return has_expired(self.timestamp, self.expired_time_period) + + +def is_same_action(action1: DiagnoseAction, action2: DiagnoseAction) -> bool: + return action1.action == action2.action and action1.rank == action2.rank + + +class DiagnoseActionQueue: def __init__(self): - self._actions: List[str] = [] + self._actions: List[DiagnoseAction] = [] + self._lock = threading.Lock() + + def add_action(self, new_action: DiagnoseAction): + with self._lock: + for action in self._actions: + if is_same_action(new_action, action): + return + logger.info(f"enqueue action {new_action.action} of {new_action.rank}") + self._actions.append(new_action) + + def _remove_expired_actions(self): + with self._lock: + actions = [] + for action in self._actions: + if not action.has_expired(): + actions.append(action) + else: + logger.info(f"Action {action} has expired") + + self._actions = actions + + def next_actions(self, rank=DiagnosisConstant.ANY_RANK) -> List[DiagnoseAction]: + self._remove_expired_actions() + with self._lock: + rank_actions = [] + remain_actions = [] + for action in self._actions: + if action.rank == rank or rank == DiagnosisConstant.ANY_RANK: + rank_actions.append(action) + else: + remain_actions.append(action) + return rank_actions + - def add_action(self, action: str): - self._actions.append(action) diff --git a/dlrover/python/diagnosis/inferencechain/coordinator.py b/dlrover/python/diagnosis/inferencechain/coordinate_inferences.py similarity index 88% rename from dlrover/python/diagnosis/inferencechain/coordinator.py rename to dlrover/python/diagnosis/inferencechain/coordinate_inferences.py index 07cb70326..c8fb2d32c 100644 --- a/dlrover/python/diagnosis/inferencechain/coordinator.py +++ b/dlrover/python/diagnosis/inferencechain/coordinate_inferences.py @@ -17,5 +17,5 @@ from dlrover.python.diagnosis.common.inference_chain import Inference -def coordinate_inferences(observations: List[Inference]) -> DiagnoseAction: - return DiagnoseAction() +def coordinate_inferences(problems: List[Inference]) -> List[DiagnoseAction]: + return [] diff --git a/dlrover/python/diagnosis/inferencechain/inferenceoperator/operator.py b/dlrover/python/diagnosis/inferencechain/inferenceoperator/operator.py index 5f213873a..5757afe65 100644 --- a/dlrover/python/diagnosis/inferencechain/inferenceoperator/operator.py +++ b/dlrover/python/diagnosis/inferencechain/inferenceoperator/operator.py @@ -17,6 +17,10 @@ from dlrover.python.diagnosis.inferencechain.inferenceoperator.metrics_collection_operator import ( # noqa: E501 MetricsCollectionOperator, ) +from dlrover.python.master.diagnosis.diagnosis_data_manager import DiagnosisDataManager +from dlrover.python.diagnosis.inferencechain.inferenceoperator.check_training_hang_operator import ( # noqa: E501 + CheckTrainingHangOperator, +) def get_training_failure_operators(): @@ -29,3 +33,10 @@ def get_worker_observe_operators(): def get_worker_diagnosis_operators(): return [] + + +def get_master_observe_operators(data_mgr: DiagnosisDataManager = None): + return [ + CheckTrainingHangOperator(data_mgr), + ] + diff --git a/dlrover/python/elastic_agent/common/__init__.py b/dlrover/python/elastic_agent/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dlrover/python/common/worker.py b/dlrover/python/elastic_agent/common/worker.py similarity index 61% rename from 
dlrover/python/common/worker.py rename to dlrover/python/elastic_agent/common/worker.py index 9fcba69aa..7b59962ec 100644 --- a/dlrover/python/common/worker.py +++ b/dlrover/python/elastic_agent/common/worker.py @@ -12,21 +12,20 @@ # limitations under the License. from torch.distributed.elastic.agent.server.api import RunResult, WorkerSpec -from dlrover.python.elastic_agent.config.launch_config import ElasticLaunchConfig +from dlrover.python.diagnosis.common.diagnose_action import ( + DiagnoseActionQueue, + DiagnoseAction, +) +from typing import Optional, List class WorkerContext: - def __init__( - self, - worker_spec: WorkerSpec = None, - remaining_failovers: int = 0, - restart_count: int = 0, - run_result: RunResult = None, - ): - self._worker_spec: WorkerSpec = worker_spec - self.remaining_failovers = remaining_failovers - self.restart_count = restart_count - self._run_result = run_result + def __init__(self): + self._worker_spec: Optional[WorkerSpec] = None + self.remaining_failovers = 0 + self.restart_count = 0 + self._run_result: Optional[RunResult] = None + self._diagnose_action_queue = DiagnoseActionQueue() @property def worker_spec(self): @@ -44,3 +43,21 @@ def to_string(self) -> str: f"restart_count: {self.restart_count}\n" f"run_result: {self._run_result}" ) + + def update_context( + self, + worker_spec: WorkerSpec = None, + remaining_failovers: int = 0, + restart_count: int = 0, + run_result: RunResult = None, + ): + self._worker_spec: WorkerSpec = worker_spec + self.remaining_failovers = remaining_failovers + self.restart_count = restart_count + self._run_result = run_result + + def enqueue_diagnose_action(self, action: DiagnoseAction): + self._diagnose_action_queue.add_action(action) + + def next_actions(self) -> List[DiagnoseAction]: + return self._diagnose_action_queue.next_actions() diff --git a/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py b/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py index d472461d7..048b44f93 100644 --- a/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py +++ b/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py @@ -23,7 +23,7 @@ from dlrover.python.common.error import ProcessError from dlrover.python.common.log import default_logger as logger from dlrover.python.common.singleton import Singleton -from dlrover.python.common.worker import WorkerContext +from dlrover.python.elastic_agent.common.worker import WorkerContext from dlrover.python.diagnosis.common.constants import ( DiagnosisAction, DiagnosisConstant, @@ -39,7 +39,7 @@ combine_inferences, is_inference_included, ) -from dlrover.python.diagnosis.inferencechain.coordinator import ( +from dlrover.python.diagnosis.inferencechain.coordinate_inferences import ( coordinate_inferences, ) from dlrover.python.diagnosis.inferencechain.inference_chain import ( @@ -54,7 +54,7 @@ class DiagnosisAgent(Singleton): - def __init__(self, training_log_file: str, errors: str): + def __init__(self, training_log_file: str, errors: str, worker_context: WorkerContext): self._client = MasterClient.singleton_instance() self._training_log_file = training_log_file self._errors = errors @@ -68,6 +68,9 @@ def __init__(self, training_log_file: str, errors: str): ] self._observe_operators = get_worker_observe_operators() self._diagnosis_operators = get_worker_diagnosis_operators() + self._worker_context = worker_context + self._diagnosis_thread = None + self._report_thread = None self.start() @@ -81,19 +84,26 @@ def start(self): self._stopped = False # start a async thread to diagnose 
periodically - thread = threading.Thread( + self._diagnosis_thread = threading.Thread( target=self._periodically_diagnosis, name="periodically_diagnosis", daemon=True, ) - thread.start() + self._diagnosis_thread.start() + + self._report_thread = threading.Thread( + target=self._periodically_report, + name="diagnosis_reporter", + daemon=True, + ) + self._report_thread.start() def stop(self): self._stopped = True def diagnose_problems( self, problems: List[Inference] - ) -> DiagnoseAction: + ) -> List[DiagnoseAction]: conclusions: List[Inference] = [] for problem in problems: ic = InferenceChain([problem], self._diagnosis_operators) @@ -133,9 +143,9 @@ def _periodically_diagnosis(self): DiagnosisConstant.AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS ) - def diagnose_training_failure(self, worker_context: WorkerContext) -> str: + def diagnose_training_failure(self) -> DiagnoseAction: self._report_failure_to_master( - worker_context.run_result.failures, worker_context.restart_count + self._worker_context.run_result.failures, self._worker_context.restart_count ) # check if the node is failed inference = Inference( @@ -156,25 +166,29 @@ def diagnose_training_failure(self, worker_context: WorkerContext) -> str: ) failure_node = is_inference_included(infer_results, failure_inf) - if worker_context.remaining_failovers > 0 and not failure_node: + if self._worker_context.remaining_failovers > 0 and not failure_node: logger.info( - f"[{worker_context.worker_spec.role}] Worker group " - f"{worker_context.run_result.state.name}, " + f"[{self._worker_context.worker_spec.role}] Worker group " + f"{self._worker_context.run_result.state.name}, " f"is failure node: {failure_node}," - f"{worker_context.remaining_failovers}/" - f"{worker_context.worker_spec.max_restarts} " + f"{self._worker_context.remaining_failovers}/" + f"{self._worker_context.worker_spec.max_restarts} " f"attempts left; will restart worker group." ) - return DiagnosisAction.RESTART_WORKER + return DiagnoseAction( + action=DiagnosisAction.RESTART_WORKER, + ) else: logger.info( - f"[{worker_context.worker_spec.role}] Worker group " - f"{worker_context.run_result.state.name}, " + f"[{self._worker_context.worker_spec.role}] Worker group " + f"{self._worker_context.run_result.state.name}, " f"is failure node: {failure_node}, " - f"no attempts({worker_context.worker_spec.max_restarts}) " + f"no attempts({self._worker_context.worker_spec.max_restarts}) " "left; will relaunch." 
) - return DiagnosisAction.RELAUNCH_WORKER + return DiagnoseAction( + action=DiagnosisAction.RELAUNCH_WORKER, + ) def _report_failure_to_master( self, failures: Dict[int, ProcessFailure], restart_count: int @@ -198,5 +212,17 @@ def _report_failure_to_master( def _report_metric_to_master(self, agent_metric: WorkerTrainingMetric): self._client.report_diagnosis_agent_metrics(agent_metric) - - + def send_heartbeat(self): + try: + ts = int(time.time()) + actions = self._client.report_heart_beat(ts) + for action in actions: + self._worker_context.enqueue_diagnose_action(action) + except Exception as e: + logger.warning(f"fail to report a heartbeat: {e}") + + def _periodically_report(self): + logger.info("Start diagnosis agent reporter.") + while True: + self.send_heartbeat() + time.sleep(15) diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index e646da113..c2ddb67d1 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -16,7 +16,7 @@ import threading import time from contextlib import closing -from typing import Dict +from typing import Dict, List from dlrover.proto import elastic_training_pb2, elastic_training_pb2_grpc from dlrover.python.common import env_utils, grpc @@ -24,6 +24,7 @@ from dlrover.python.common.log import default_logger as logger from dlrover.python.common.singleton import Singleton from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData +from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction def retry_grpc_request(func): @@ -231,9 +232,19 @@ def report_global_step( ) return self._report(message) - def report_heart_beat(self, timestamp): + def report_heart_beat(self, timestamp) -> List[DiagnoseAction]: message = grpc.HeartBeat(timestamp=timestamp) - return self._report(message) + response: grpc.HeartbeatResponse = self._get(message) + actions: List[DiagnoseAction] = [] + for grpc_action in response.diagnosis_actions: + action = DiagnoseAction( + rank=grpc_action.rank, + timestamp=grpc_action.timestamp, + expired_time_period=grpc_action.expired_time_period, + action=grpc_action.action, + ) + actions.append(action) + return actions def get_cluster_version(self, version_type, task_type, task_id): request = grpc.ClusterVersionRequest( diff --git a/dlrover/python/elastic_agent/monitor/training.py b/dlrover/python/elastic_agent/monitor/training.py index d60396639..49591ff36 100644 --- a/dlrover/python/elastic_agent/monitor/training.py +++ b/dlrover/python/elastic_agent/monitor/training.py @@ -123,17 +123,9 @@ def report_resource_with_step(self): except Exception as e: logger.warning(e) - def send_heartbeat(self): - try: - ts = int(time.time()) - self._master_client.report_heart_beat(ts) - except Exception: - logger.warning("Fail to report a heartbeat.") - def _periodically_report(self): logger.info("Start training agent reporter.") while True: if self._group_rank == 0: self.report_resource_with_step() - self.send_heartbeat() time.sleep(15) diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index 0fa228bd2..1289f472a 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -63,7 +63,7 @@ ) from torch.distributed.elastic.rendezvous import RendezvousParameters from torch.distributed.elastic.rendezvous.api import RendezvousHandler -from torch.distributed.launcher.api import LaunchConfig, _get_entrypoint_name +from 
torch.distributed.launcher.api import _get_entrypoint_name from dlrover.python.common import env_utils from dlrover.python.common.constants import ( @@ -83,7 +83,7 @@ find_free_port_in_set, ) from dlrover.python.common.log import default_logger as logger -from dlrover.python.common.worker import WorkerContext +from dlrover.python.elastic_agent.common.worker import WorkerContext from dlrover.python.diagnosis.common.constants import DiagnosisAction from dlrover.python.elastic_agent.config.paral_config_tuner import ( ParalConfigTuner, @@ -106,6 +106,7 @@ InferenceDescription, ) from dlrover.python.elastic_agent.config.launch_config import ElasticLaunchConfig +from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction try: from torch_npu.contrib import transfer_to_npu # noqa: F401 @@ -376,8 +377,11 @@ def __init__( self._save_ckpt_executor = ThreadPoolExecutor(max_workers=1) self._save_ckpt_future = None + self._worker_context = WorkerContext() self._diagnose_agent = DiagnosisAgent( - training_log_file, failure_node_errors + training_log_file=training_log_file, + errors=failure_node_errors, + worker_context=self._worker_context, ) @prof @@ -779,6 +783,10 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: ) put_metric(f"workers.{role}.{state.name.lower()}", 1) + actions = self._worker_context.next_actions() + for action in actions: + self._process_diagnose_action(action) + if state == WorkerState.SUCCEEDED: logger.info( f"[{role}] worker group successfully finished." @@ -802,16 +810,14 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: logger.error(f"The worker fails with {run_result.failures}") self._save_ckpt_to_storage() - worker_context = WorkerContext( + self._worker_context.update_context( worker_spec=self._worker_group.spec, remaining_failovers=self._remaining_failovers, restart_count=self._restart_count, run_result=run_result, ) try: - action = self._diagnose_agent.diagnose_training_failure( - worker_context - ) + action = self._diagnose_agent.diagnose_training_failure() except Exception as e: logger.warning(f"Failed to diagnose errors: {e}") if self._remaining_failovers > 0: @@ -829,11 +835,11 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: else: raise Exception(f"[{role}] worker group in {state.name} state") - def _process_diagnose_action(self, action: str): - if action == DiagnosisAction.RESTART_WORKER: + def _process_diagnose_action(self, action: DiagnoseAction): + if action.action == DiagnosisAction.RESTART_WORKER: self._remaining_failovers -= 1 self._restart_workers(self._worker_group) - elif action == DiagnosisAction.RELAUNCH_WORKER: + elif action.action == DiagnosisAction.RELAUNCH_WORKER: self._stop_workers(self._worker_group) self._worker_group.state = WorkerState.FAILED diff --git a/dlrover/python/master/diagnosis/diagnosis_data_manager.py b/dlrover/python/master/diagnosis/diagnosis_data_manager.py new file mode 100644 index 000000000..d4ef761b6 --- /dev/null +++ b/dlrover/python/master/diagnosis/diagnosis_data_manager.py @@ -0,0 +1,37 @@ +from typing import Dict, List +from dlrover.python.common.time import has_expired +from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData +from dlrover.python.common.log import default_logger as logger + + +class DiagnosisDataManager: + def __init__(self, expire_time_period): + self.diagnosis_data: Dict[str, List[DiagnosisData]] = {} + self.expire_time_period = expire_time_period + + def store_data(self, data: DiagnosisData): + data_type = data.data_type + if 
data_type not in self.diagnosis_data: + logger.debug(f"{data_type} is not found in the store") + self.diagnosis_data[data_type] = [] + self.diagnosis_data[data_type].append(data) + self._clean_diagnosis_data(data_type) + + def get_data(self, data_type: str) -> List[DiagnosisData]: + if data_type not in self.diagnosis_data: + return [] + return self.diagnosis_data[data_type] + + def _clean_diagnosis_data(self, data_type: str): + if data_type not in self.diagnosis_data: + return + + data = self.diagnosis_data[data_type] + n = 0 + for d in data: + if has_expired(d.timestamp, self.expire_time_period): + n = n + 1 + else: + break + + self.diagnosis_data[data_type] = data[n:] diff --git a/dlrover/python/master/diagnosis/diagnosis.py b/dlrover/python/master/diagnosis/diagnosis_manager.py similarity index 62% rename from dlrover/python/master/diagnosis/diagnosis.py rename to dlrover/python/master/diagnosis/diagnosis_manager.py index 48692775a..731bc2a82 100644 --- a/dlrover/python/master/diagnosis/diagnosis.py +++ b/dlrover/python/master/diagnosis/diagnosis_manager.py @@ -13,8 +13,8 @@ import threading import time -from datetime import datetime, timedelta -from typing import Dict, List + +from typing import List from dlrover.python.common.log import default_logger as logger from dlrover.python.diagnosis.common.constants import DiagnosisConstant @@ -24,27 +24,31 @@ InferenceDescription, InferenceName, ) +from dlrover.python.diagnosis.common.diagnose_action import ( + DiagnoseAction, +) from dlrover.python.diagnosis.inferencechain.inference_chain import ( Inference, InferenceChain, - InferenceOperator, -) -from dlrover.python.diagnosis.inferencechain.inferenceoperator.check_training_hang_operator import ( # noqa: E501 - CheckTrainingHangOperator, ) - -def has_expired(timestamp: float, time_period: int) -> bool: - dt = datetime.fromtimestamp(timestamp) - expired_dt = dt + timedelta(seconds=time_period) - return expired_dt < datetime.now() +from dlrover.python.master.diagnosis.diagnosis_data_manager import DiagnosisDataManager +from dlrover.python.diagnosis.inferencechain.inferenceoperator.operator import ( + get_master_observe_operators, +) +from dlrover.python.diagnosis.inferencechain.inference_chain import combine_inferences +from dlrover.python.diagnosis.inferencechain.coordinate_inferences import ( + coordinate_inferences, +) +from dlrover.python.master.node.job import JobContext class DiagnosisManager: - def __init__(self): + def __init__(self, job_context: JobContext = None): self._is_observing_started = False self._data_manager: DiagnosisDataManager = DiagnosisDataManager(600) self._diagnostician: Diagnostician = Diagnostician(self._data_manager) + self._job_context = job_context def collect_diagnosis_data(self, data: DiagnosisData): self._data_manager.store_data(data) @@ -68,7 +72,7 @@ def start_observing(self): InferenceDescription.HANG, ) ] - self._diagnostician.register_problems(problems) + self._diagnostician.register_training_problems(problems) try: thread = threading.Thread( @@ -100,45 +104,20 @@ def _diagnose(self): solutions: List[Inference] = [] for problem in observed_problems: logger.info(f"observed problems: {problem}") - root_causes = self._diagnostician.diagnose_failure(problem) - for root_cause in root_causes: - logger.info(f"identify root cause: {root_cause}") + infs = self._diagnostician.diagnose_problem(problem) + logger.info(f"have the solution to {problem}: {infs}") + if len(infs) > 0: + solutions = combine_inferences(solutions, infs) + + actions = 
coordinate_inferences(solutions) + self._job_context.enqueue_actions(actions) + time.sleep( DiagnosisConstant.MASTER_DIAGNOSIS_OBSERVING_INTERVAL_SECS ) - -class DiagnosisDataManager: - def __init__(self, expire_time_period): - self.diagnosis_data: Dict[str, List[DiagnosisData]] = {} - self.expire_time_period = expire_time_period - - def store_data(self, data: DiagnosisData): - data_type = data.data_type - if data_type not in self.diagnosis_data: - logger.debug(f"{data_type} is not found in the store") - self.diagnosis_data[data_type] = [] - self.diagnosis_data[data_type].append(data) - self._clean_diagnosis_data(data_type) - - def get_data(self, data_type: str) -> List[DiagnosisData]: - if data_type not in self.diagnosis_data: - return [] - return self.diagnosis_data[data_type] - - def _clean_diagnosis_data(self, data_type: str): - if data_type not in self.diagnosis_data: - return - - data = self.diagnosis_data[data_type] - n = 0 - for d in data: - if has_expired(d.timestamp, self.expire_time_period): - n = n + 1 - else: - break - - self.diagnosis_data[data_type] = data[n:] + def next_actions(self, rank) -> List[DiagnoseAction]: + return self._action_queue.next_actions(rank) class Diagnostician: @@ -146,24 +125,20 @@ def __init__(self, data_manager): self._data_manager = data_manager self._pre_checks: List[Inference] = [] self._training_problems: List[Inference] = [] - self._observing_operators = - - def get_observing_operators(self) -> List[InferenceOperator]: - return [CheckTrainingHangOperator(self._data_manager)] + self._observing_operators = get_master_observe_operators(data_mgr=data_manager) def register_pre_check(self, pre_checks: List[Inference]): self._pre_checks = pre_checks - def register_problems(self, problems: List[Inference]): + def register_training_problems(self, problems: List[Inference]): self._training_problems = problems - def check_training(self) -> List[Inference]: - ic = InferenceChain(self._pre_checks, self.get_pre_check_operators()) - return ic.infer() - def observe_training(self) -> List[Inference]: + if len(self._training_problems) == 0: + logger.warning("No training problem is registered.") + return [] ic = InferenceChain( - self._training_problems, self.get_observing_operators() + self._training_problems, self._observing_operators ) return ic.infer() diff --git a/dlrover/python/master/dist_master.py b/dlrover/python/master/dist_master.py index c68942e2c..65eefcf62 100644 --- a/dlrover/python/master/dist_master.py +++ b/dlrover/python/master/dist_master.py @@ -26,7 +26,7 @@ ReporterType, ) from dlrover.python.common.log import default_logger as logger -from dlrover.python.master.diagnosis.diagnosis import DiagnosisManager +from dlrover.python.master.diagnosis.diagnosis_manager import DiagnosisManager from dlrover.python.master.elastic_training.elastic_ps import ElasticPsService from dlrover.python.master.elastic_training.rdzv_manager import ( ElasticTrainingRendezvousManager, @@ -47,6 +47,7 @@ from dlrover.python.master.shard.task_manager import TaskManager from dlrover.python.master.stats.job_collector import JobMetricCollector from dlrover.python.scheduler.job import JobArgs +from dlrover.python.master.node.job import JobContext def _create_elastic_ps_service_if_needed(params: JobArgs): @@ -123,8 +124,9 @@ def __init__( ) self.speed_monitor = SpeedMonitor() + self.job_context = JobContext() self.job_manager = ( - create_job_manager(args, self.speed_monitor) + create_job_manager(args, self.speed_monitor, self.job_context) if args.enable_elastic_scheduling 
else None ) @@ -143,7 +145,7 @@ def __init__( error_monitor ), } - self.diagnosis_manager = DiagnosisManager() + self.diagnosis_manager = DiagnosisManager(self.job_context) self.job_metric_collector = self._create_metric_collector_if_needed( args ) diff --git a/dlrover/python/master/node/dist_job_manager.py b/dlrover/python/master/node/dist_job_manager.py index cb7151ace..eb0134fa4 100644 --- a/dlrover/python/master/node/dist_job_manager.py +++ b/dlrover/python/master/node/dist_job_manager.py @@ -31,11 +31,13 @@ NodeStatus, NodeType, TrainingExceptionLevel, + MasterConstants, ) from dlrover.python.common.global_context import Context from dlrover.python.common.grpc import ParallelConfig from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node, NodeGroupResource +from dlrover.python.diagnosis.common.constants import DiagnosisAction from dlrover.python.master.monitor.error_monitor import K8sJobErrorMonitor from dlrover.python.master.node.event_callback import ( ClusterContext, @@ -101,12 +103,14 @@ def __init__( job_scaler=None, error_monitor=None, external_config=None, + job_context=None, ): super().__init__( job_args=job_args, speed_monitor=speed_monitor, error_monitor=error_monitor, external_config=external_config, + job_context=job_context, ) self._remove_exited_node = job_args.remove_exited_node node_restart_count: Dict[str, int] = {} @@ -210,8 +214,8 @@ def start(self): target=self._monitor_nodes, name="node_monitor", daemon=True ).start() threading.Thread( - target=self._monitor_node_heart_beat, - name="node_heart_beat_monitor", + target=self._diagnose_job(), + name="diagnose_job", daemon=True, ).start() if os.getenv("KUBERNETES_SERVICE_HOST"): @@ -458,6 +462,32 @@ def _monitor_node_heart_beat(self): logger.warning(detail_trace_back) time.sleep(15) + def _diagnose_job(self): + logger.info("Start diagnosing the job.") + while True: + if self._stopped: + logger.info("Stop diagnosing job.") + break + with self._lock: + try: + events = self._get_dead_node_event() + except Exception as e: + logger.warning(e) + events = [] + + for event in events: + try: + self._process_event(event) + except Exception as e: + logger.warning(e) + detail_trace_back = traceback.format_exc() + logger.warning(detail_trace_back) + + actions = self._job_context.next_actions(MasterConstants.MASTER_RANK) + for action in actions: + self._process_diagnosis_action(action) + time.sleep(15) + def _get_dead_node_event(self, window_interval=900) -> List[NodeEvent]: now = time.time() dead_events: List[NodeEvent] = [] @@ -609,6 +639,9 @@ def _get_pod_unique_labels(self, node: Node): ElasticJobLabel.RANK_INDEX_KEY: node.rank_index, } + def _process_diagnosis_action(self, action: DiagnosisAction): + pass + def _process_event(self, event: NodeEvent): node_type = event.node.type node_status = event.node.status @@ -1118,13 +1151,13 @@ def verify_restarting_worker_training(self, node_type, node_id): return False return self._worker_manager.verify_restarting_training(node_id) - def collect_node_heart_beat(self, node_type, node_id, timestamp): + def collect_node_heart_beat(self, node_type, node_id, timestamp) -> List[DiagnosisAction]: with self._lock: if ( node_type not in self._job_nodes or node_id not in self._job_nodes[node_type] ): - return + return [] node = self._job_nodes[node_type][node_id] if node.heartbeat_time == 0: logger.info( @@ -1132,6 +1165,7 @@ def collect_node_heart_beat(self, node_type, node_id, timestamp): f"-{node.name}" ) node.heartbeat_time = timestamp + return 
self._job_context.next_actions(node_id) def update_node_required_info_callback(self): self._worker_manager.update_node_required_info(self._nodes_required) @@ -1141,7 +1175,7 @@ def update_succeeded_node(self, node_id, node_type): super().update_succeeded_node(node_id, node_type) -def create_job_manager(args: JobArgs, speed_monitor) -> DistributedJobManager: +def create_job_manager(args: JobArgs, speed_monitor, job_context=None) -> DistributedJobManager: critical_worker_index = get_critical_worker_index(args) # Custom distribution strategy does not exit if there are pending nodes wait_pending_relaunch = ( @@ -1166,4 +1200,5 @@ def create_job_manager(args: JobArgs, speed_monitor) -> DistributedJobManager: node_watcher=node_watcher, job_scaler=job_scaler, error_monitor=node_error_monitor, + job_context=job_context, ) diff --git a/dlrover/python/master/node/job.py b/dlrover/python/master/node/job.py new file mode 100644 index 000000000..f2b10bfa7 --- /dev/null +++ b/dlrover/python/master/node/job.py @@ -0,0 +1,16 @@ +from dlrover.python.diagnosis.common.diagnose_action import DiagnoseActionQueue + + +class JobContext: + def __init__(self): + self._action_queue = DiagnoseActionQueue() + + def enqueue_actions(self, actions): + for action in actions: + self._action_queue.add_action(action) + + def next_actions(self, rank): + return self._action_queue.next_actions(rank) + + def update_context(self): + pass diff --git a/dlrover/python/master/node/job_manager.py b/dlrover/python/master/node/job_manager.py index acec71396..3f1b8d2bd 100644 --- a/dlrover/python/master/node/job_manager.py +++ b/dlrover/python/master/node/job_manager.py @@ -12,7 +12,7 @@ # limitations under the License. from abc import ABCMeta, abstractmethod -from typing import Dict +from typing import Dict, List from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node @@ -28,6 +28,8 @@ from dlrover.python.master.resource.job import JobResource from dlrover.python.scheduler.job import JobArgs from dlrover.python.scheduler.kubernetes import k8sClient +from dlrover.python.master.node.job import JobContext +from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction class JobManager(metaclass=ABCMeta): @@ -41,6 +43,7 @@ def __init__( speed_monitor=None, error_monitor=None, external_config=None, + job_context=None, ): self._job_resource = JobResource() self._job_args = job_args @@ -58,6 +61,11 @@ def __init__( self._training_node_config = TrainingNodeConfig(external_config) + if job_context is None: + self._job_context = JobContext() + else: + self._job_context = job_context + @abstractmethod def start(self): pass @@ -195,7 +203,7 @@ def handle_training_failure( pass @abstractmethod - def collect_node_heart_beat(self, node_type, node_id, timestamp): + def collect_node_heart_beat(self, node_type, node_id, timestamp) -> List[DiagnoseAction]: """Collect the heart beat message of nodes.""" pass diff --git a/dlrover/python/master/node/local_job_manager.py b/dlrover/python/master/node/local_job_manager.py index d7e82a28b..f13574c2d 100644 --- a/dlrover/python/master/node/local_job_manager.py +++ b/dlrover/python/master/node/local_job_manager.py @@ -17,6 +17,8 @@ from dlrover.python.master.monitor.error_monitor import SimpleErrorMonitor from dlrover.python.master.node.job_manager import JobManager from dlrover.python.scheduler.job import JobArgs +from typing import List +from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction class LocalJobManager(JobManager): @@ 
-68,9 +70,10 @@ def handle_training_failure( node, restart_count, error_data, level ) - def collect_node_heart_beat(self, node_type, node_id, timestamp): + def collect_node_heart_beat(self, node_type, node_id, timestamp) -> List[DiagnoseAction]: node = self._job_nodes[node_type][node_id] node.heartbeat_time = timestamp + return [] def close_job(self): pass diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index e8250dc17..d1163d704 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -33,7 +33,7 @@ from dlrover.python.common.global_context import Context from dlrover.python.common.log import default_logger as logger from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData -from dlrover.python.master.diagnosis.diagnosis import DiagnosisManager +from dlrover.python.master.diagnosis.diagnosis_manager import DiagnosisManager from dlrover.python.master.elastic_training.kv_store_service import ( KVStoreService, ) @@ -140,6 +140,8 @@ def get(self, request, _): elif isinstance(req_message, grpc.ElasticRunConfigRequest): configs = self._job_manager.get_elastic_run_configs() message = grpc.ElasticRunConfig(configs=configs) + elif isinstance(req_message, grpc.HeartBeat): + message = self._report_heartbeat(node_type, node_id, req_message) if message: response.data = message.serialize() @@ -355,8 +357,6 @@ def report(self, request, _): success = self._kv_store_set(message) elif isinstance(message, grpc.ParallelConfig): success = self._report_paral_config(node_type, node_id, message) - elif isinstance(message, grpc.HeartBeat): - success = self._report_heartbeat(node_type, node_id, message) elif isinstance(message, grpc.NodeCheckpointState): success = self._sync_checkpoint(node_type, node_id, message) elif isinstance(message, grpc.DiagnosisReportData): @@ -602,14 +602,6 @@ def _report_paral_config( ) return True - def _report_heartbeat(self, node_type, node_id, message: grpc.HeartBeat): - self._job_manager.collect_node_heart_beat( - node_type, - node_id, - message.timestamp, - ) - return True - def _sync_checkpoint( self, node_type, node_id, message: grpc.NodeCheckpointState ): @@ -649,6 +641,19 @@ def _sync_training_ports( port=sync_ports.training_port, newport=sync_ports.next_check_port ) + def _report_heartbeat(self, node_type, node_id, message: grpc.HeartBeat) -> grpc.HeartbeatResponse: + actions = self._job_manager.collect_node_heart_beat(node_type, node_id, message.timestamp) + grpc_actions: List[grpc.DiagnosisAction] = [] + for action in actions: + grpc_action = grpc.DiagnosisAction( + rank=action.rank, + timestamp=action.timestamp, + action=action.action, + expired_time_period=action.expired_time_period, + ) + grpc_actions.append(grpc_action) + return grpc.HeartbeatResponse(diagnosis_actions=grpc_actions) + def create_master_service( port, diff --git a/dlrover/python/tests/test_diagnosis.py b/dlrover/python/tests/test_diagnosis.py index 501dd1d08..42826505f 100644 --- a/dlrover/python/tests/test_diagnosis.py +++ b/dlrover/python/tests/test_diagnosis.py @@ -16,7 +16,7 @@ from dlrover.python.diagnosis.common.constants import DiagnosisDataType from dlrover.python.diagnosis.common.diagnosis_data import TrainingLog -from dlrover.python.master.diagnosis.diagnosis import DiagnosisDataManager +from dlrover.python.master.diagnosis.diagnosis_manager import DiagnosisDataManager class DiagnosisTest(unittest.TestCase): diff --git a/dlrover/python/tests/test_diagnosis_agent.py b/dlrover/python/tests/test_diagnosis_agent.py 
index c1ed97c43..26849ec60 100644 --- a/dlrover/python/tests/test_diagnosis_agent.py +++ b/dlrover/python/tests/test_diagnosis_agent.py @@ -19,7 +19,7 @@ from dlrover.python.common import env_utils from dlrover.python.common.constants import RendezvousName -from dlrover.python.common.worker import WorkerContext +from dlrover.python.elastic_agent.common.worker import WorkerContext from dlrover.python.diagnosis.common.constants import DiagnosisAction from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric from dlrover.python.elastic_agent.diagnosis.diagnosis_agent import ( diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py index 770ecde85..ee6711d04 100644 --- a/dlrover/python/tests/test_job_manager.py +++ b/dlrover/python/tests/test_job_manager.py @@ -673,7 +673,7 @@ def test_start_and_stop(self): manager.start() active_threads_name = [t.name for t in threading.enumerate()] self.assertIn("node_monitor", active_threads_name) - self.assertIn("node_heart_beat_monitor", active_threads_name) + self.assertIn("diagnose_job", active_threads_name) manager.stop() def test_concurrency_heart_beat_collecting(self): diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py index b30b6e999..d7d978eb0 100644 --- a/dlrover/python/tests/test_servicer.py +++ b/dlrover/python/tests/test_servicer.py @@ -27,7 +27,7 @@ ) from dlrover.python.common.grpc import GPUStats from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric -from dlrover.python.master.diagnosis.diagnosis import DiagnosisManager +from dlrover.python.master.diagnosis.diagnosis_manager import DiagnosisManager from dlrover.python.master.elastic_training.elastic_ps import ElasticPsService from dlrover.python.master.elastic_training.rdzv_manager import ( ElasticTrainingRendezvousManager, From cb117e1e37549337e3932b221a37eb520089e518 Mon Sep 17 00:00:00 2001 From: bsang Date: Wed, 23 Oct 2024 16:25:37 +0800 Subject: [PATCH 04/10] fix unit test --- .../master/diagnosis/diagnosis_manager.py | 2 +- .../python/master/node/dist_job_manager.py | 51 +++++++++---------- dlrover/python/tests/test_job_manager.py | 1 + dlrover/python/tests/test_servicer.py | 2 +- 4 files changed, 27 insertions(+), 29 deletions(-) diff --git a/dlrover/python/master/diagnosis/diagnosis_manager.py b/dlrover/python/master/diagnosis/diagnosis_manager.py index 731bc2a82..9ec767052 100644 --- a/dlrover/python/master/diagnosis/diagnosis_manager.py +++ b/dlrover/python/master/diagnosis/diagnosis_manager.py @@ -76,7 +76,7 @@ def start_observing(self): try: thread = threading.Thread( - target=self._diagnose(), + target=self._diagnose, name="diagnose_failures", daemon=True, ) diff --git a/dlrover/python/master/node/dist_job_manager.py b/dlrover/python/master/node/dist_job_manager.py index eb0134fa4..5effe062e 100644 --- a/dlrover/python/master/node/dist_job_manager.py +++ b/dlrover/python/master/node/dist_job_manager.py @@ -31,8 +31,8 @@ NodeStatus, NodeType, TrainingExceptionLevel, - MasterConstants, ) +from dlrover.python.diagnosis.common.constants import DiagnosisConstant from dlrover.python.common.global_context import Context from dlrover.python.common.grpc import ParallelConfig from dlrover.python.common.log import default_logger as logger @@ -214,7 +214,7 @@ def start(self): target=self._monitor_nodes, name="node_monitor", daemon=True ).start() threading.Thread( - target=self._diagnose_job(), + target=self._diagnose_job, name="diagnose_job", daemon=True, 
).start() @@ -440,27 +440,27 @@ def _monitor_nodes(self): time.sleep(30) time.sleep(5) - def _monitor_node_heart_beat(self): - logger.info("Start monitoring the heartbeat of nodes.") - while True: - if self._stopped: - logger.info("Stop monitoring the heartbeat of nodes.") - break - with self._lock: - try: - events = self._get_dead_node_event() - except Exception as e: - logger.warning(e) - events = [] - - for event in events: - try: - self._process_event(event) - except Exception as e: - logger.warning(e) - detail_trace_back = traceback.format_exc() - logger.warning(detail_trace_back) - time.sleep(15) + # def _monitor_node_heart_beat(self): + # logger.info("Start monitoring the heartbeat of nodes.") + # while True: + # if self._stopped: + # logger.info("Stop monitoring the heartbeat of nodes.") + # break + # with self._lock: + # try: + # events = self._get_dead_node_event() + # except Exception as e: + # logger.warning(e) + # events = [] + # + # for event in events: + # try: + # self._process_event(event) + # except Exception as e: + # logger.warning(e) + # detail_trace_back = traceback.format_exc() + # logger.warning(detail_trace_back) + # time.sleep(15) def _diagnose_job(self): logger.info("Start diagnosing the job.") @@ -474,7 +474,6 @@ def _diagnose_job(self): except Exception as e: logger.warning(e) events = [] - for event in events: try: self._process_event(event) @@ -482,8 +481,7 @@ def _diagnose_job(self): logger.warning(e) detail_trace_back = traceback.format_exc() logger.warning(detail_trace_back) - - actions = self._job_context.next_actions(MasterConstants.MASTER_RANK) + actions = self._job_context.next_actions(DiagnosisConstant.MASTER_RANK) for action in actions: self._process_diagnosis_action(action) time.sleep(15) @@ -764,7 +762,6 @@ def _process_event(self, event: NodeEvent): "exit reason": cur_node.exit_reason, }, ) - if should_relaunch: self._relaunch_node(cur_node) diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py index ee6711d04..712779066 100644 --- a/dlrover/python/tests/test_job_manager.py +++ b/dlrover/python/tests/test_job_manager.py @@ -212,6 +212,7 @@ def test_get_critical_worker_index(self): self.assertDictEqual(critical_worker, {}) def test_relaunch_node(self): + print("!!!!!!!!!!!Start test relaunch_node!!!!!!!!!!!!\n") params = MockK8sPSJobArgs() params.initilize() manager = create_job_manager(params, SpeedMonitor()) diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py index d7d978eb0..e5eba00fa 100644 --- a/dlrover/python/tests/test_servicer.py +++ b/dlrover/python/tests/test_servicer.py @@ -394,7 +394,7 @@ def test_report_heartbeat(self): request.data = message.serialize() request.node_type = NodeType.WORKER request.node_id = 0 - self.servicer.report(request, None) + self.servicer.get(request, None) worker0 = self.servicer._job_manager._job_nodes[NodeType.WORKER][0] self.assertEqual(worker0.heartbeat_time, ts) From d781ca48a765ae8810b0885d2710f2b0baed7260 Mon Sep 17 00:00:00 2001 From: bsang Date: Thu, 31 Oct 2024 11:21:16 +0800 Subject: [PATCH 05/10] update --- dlrover/python/common/grpc.py | 6 +- dlrover/python/common/node.py | 3 + dlrover/python/common/time.py | 7 - dlrover/python/diagnosis/common/constants.py | 12 +- .../diagnosis/common/diagnose_action.py | 88 ---------- .../diagnosis/common/diagnosis_action.py | 133 ++++++++++++++++ .../inferencechain/coordinate_inferences.py | 4 +- .../inferenceoperator/operator.py | 9 +- .../python/elastic_agent/common/__init__.py | 
13 ++ .../common/{worker.py => worker_context.py} | 55 +++++-- .../elastic_agent/config/launch_config.py | 39 +++-- .../diagnosis/diagnosis_agent.py | 32 ++-- dlrover/python/elastic_agent/master_client.py | 8 +- .../python/elastic_agent/torch/training.py | 80 ++++++---- .../diagnosis/diagnosis_data_manager.py | 18 ++- .../master/diagnosis/diagnosis_manager.py | 39 ++--- dlrover/python/master/dist_master.py | 7 +- .../python/master/node/dist_job_manager.py | 150 ++++++++++-------- dlrover/python/master/node/job.py | 16 -- dlrover/python/master/node/job_auto_scaler.py | 20 +-- dlrover/python/master/node/job_context.py | 138 ++++++++++++++++ dlrover/python/master/node/job_manager.py | 30 ++-- .../python/master/node/local_job_manager.py | 42 +++-- dlrover/python/master/node/ps.py | 48 ++++-- dlrover/python/master/node/training_node.py | 63 +++++--- dlrover/python/master/node/worker.py | 74 +++++---- dlrover/python/master/servicer.py | 8 +- dlrover/python/tests/test_common_util.py | 2 +- dlrover/python/tests/test_diagnosis.py | 4 +- dlrover/python/tests/test_diagnosis_agent.py | 44 ++--- .../tests/test_elastic_training_agent.py | 4 +- dlrover/python/tests/test_job_auto_scaler.py | 43 +++-- dlrover/python/tests/test_job_manager.py | 144 ++++++++++++----- dlrover/python/tests/test_log.py | 1 + dlrover/python/tests/test_master.py | 16 +- dlrover/python/tests/test_ps_manager.py | 64 ++++++-- dlrover/python/tests/test_servicer.py | 8 +- dlrover/python/util/time_util.py | 20 +++ dlrover/trainer/torch/elastic_run.py | 8 +- 39 files changed, 1001 insertions(+), 499 deletions(-) delete mode 100644 dlrover/python/common/time.py delete mode 100644 dlrover/python/diagnosis/common/diagnose_action.py create mode 100644 dlrover/python/diagnosis/common/diagnosis_action.py rename dlrover/python/elastic_agent/common/{worker.py => worker_context.py} (54%) delete mode 100644 dlrover/python/master/node/job.py create mode 100644 dlrover/python/master/node/job_context.py create mode 100644 dlrover/python/util/time_util.py diff --git a/dlrover/python/common/grpc.py b/dlrover/python/common/grpc.py index aeb9bc958..1808679cc 100644 --- a/dlrover/python/common/grpc.py +++ b/dlrover/python/common/grpc.py @@ -518,10 +518,8 @@ class SucceededRequest(Message): @dataclass class DiagnosisAction(Message): - rank: int = -1 - timestamp: float = 0.0 - expired_time_period: int = 0 - action: str = "" + action_cls: str = "" + action_content: str = "" @dataclass diff --git a/dlrover/python/common/node.py b/dlrover/python/common/node.py index 63fbc49ea..c27b7fe1b 100644 --- a/dlrover/python/common/node.py +++ b/dlrover/python/common/node.py @@ -364,3 +364,6 @@ def to_dict(self): d.pop("config_resource", None) d.pop("used_resource", None) return d + + def update_from_node(self, node): + self.__dict__.update(node.__dict__) diff --git a/dlrover/python/common/time.py b/dlrover/python/common/time.py deleted file mode 100644 index 657f64cac..000000000 --- a/dlrover/python/common/time.py +++ /dev/null @@ -1,7 +0,0 @@ -from datetime import datetime, timedelta - - -def has_expired(timestamp: float, time_period: int) -> bool: - dt = datetime.fromtimestamp(timestamp) - expired_dt = dt + timedelta(seconds=time_period) - return expired_dt < datetime.now() diff --git a/dlrover/python/diagnosis/common/constants.py b/dlrover/python/diagnosis/common/constants.py index 82f10f393..f3b132625 100644 --- a/dlrover/python/diagnosis/common/constants.py +++ b/dlrover/python/diagnosis/common/constants.py @@ -24,8 +24,9 @@ class InferenceConfigKey(object): class 
DiagnosisConstant(object): MASTER_DIAGNOSIS_OBSERVING_INTERVAL_SECS = 180 AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS = 60 - MASTER_RANK = -1 - ANY_RANK = -2 + MASTER = -1 + ANY_INSTANCE = -2 + LOCAL_INSTANCE = -3 class DiagnosisDataType(object): @@ -34,8 +35,13 @@ class DiagnosisDataType(object): XPU_TIMER_METRIC = "XPU_TIMER_METRIC" -class DiagnosisAction(object): +class DiagnosisActionConstants(object): NO_ACTION = "no_action" RESTART_WORKER = "restart_worker" RELAUNCH_WORKER = "relaunch_worker" + + ACTION_TYPE_ANY = "any" + ACTION_EXPIRED_TIME_PERIOD = 60 * 5 + + TYPE_NODE = "node" diff --git a/dlrover/python/diagnosis/common/diagnose_action.py b/dlrover/python/diagnosis/common/diagnose_action.py deleted file mode 100644 index 33e03c0ee..000000000 --- a/dlrover/python/diagnosis/common/diagnose_action.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2024 The DLRover Authors. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import threading -from datetime import datetime -from dlrover.python.common.time import has_expired -from typing import List -from dlrover.python.common.log import default_logger as logger -from dlrover.python.diagnosis.common.constants import ( - DiagnosisConstant, - DiagnosisAction as DiagnosisActionConstants, -) - - -class DiagnoseAction: - def __init__(self, timestamp=0, expired_time_period=0, action="", rank=DiagnosisConstant.ANY_RANK): - self.action = action - if timestamp == 0: - self.timestamp = int(round(datetime.now().timestamp())) - else: - self.timestamp= timestamp - - if expired_time_period == 0: - self.expired_time_period = DiagnosisActionConstants.ACTION_EXPIRED_TIME_PERIOD - else: - self.expired_time_period = expired_time_period - # rank indicates the rank (worker) which is to - # execute this action. 
Rank = -1 indicates the - # master is to execute the action - self.rank = rank - - def set_action(self, action: str): - self.action = action - - def has_expired(self) -> bool: - return has_expired(self.timestamp, self.expired_time_period) - - -def is_same_action(action1: DiagnoseAction, action2: DiagnoseAction) -> bool: - return action1.action == action2.action and action1.rank == action2.rank - - -class DiagnoseActionQueue: - def __init__(self): - self._actions: List[DiagnoseAction] = [] - self._lock = threading.Lock() - - def add_action(self, new_action: DiagnoseAction): - with self._lock: - for action in self._actions: - if is_same_action(new_action, action): - return - logger.info(f"enqueue action {new_action.action} of {new_action.rank}") - self._actions.append(new_action) - - def _remove_expired_actions(self): - with self._lock: - actions = [] - for action in self._actions: - if not action.has_expired(): - actions.append(action) - else: - logger.info(f"Action {action} has expired") - - self._actions = actions - - def next_actions(self, rank=DiagnosisConstant.ANY_RANK) -> List[DiagnoseAction]: - self._remove_expired_actions() - with self._lock: - rank_actions = [] - remain_actions = [] - for action in self._actions: - if action.rank == rank or rank == DiagnosisConstant.ANY_RANK: - rank_actions.append(action) - else: - remain_actions.append(action) - return rank_actions - - diff --git a/dlrover/python/diagnosis/common/diagnosis_action.py b/dlrover/python/diagnosis/common/diagnosis_action.py new file mode 100644 index 000000000..611feeef0 --- /dev/null +++ b/dlrover/python/diagnosis/common/diagnosis_action.py @@ -0,0 +1,133 @@ +# Copyright 2024 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import threading +from abc import ABCMeta +from datetime import datetime +from typing import List, Dict + +from dlrover.python.common.constants import NodeType +from dlrover.python.common.log import default_logger as logger +from dlrover.python.diagnosis.common.constants import ( + DiagnosisActionConstants, + DiagnosisConstant, +) +from dlrover.python.util.time_util import has_expired +import json + + +class DiagnosisAction(metaclass=ABCMeta): + def __init__(self, action_type: str, instance: int, timestamp=0, expired_time_period=0): + self.action_type = action_type + self.instance = instance + if timestamp == 0: + self.timestamp = int(round(datetime.now().timestamp())) + else: + self.timestamp = timestamp + + if expired_time_period == 0: + self.expired_time_period = ( + DiagnosisActionConstants.ACTION_EXPIRED_TIME_PERIOD + ) + else: + self.expired_time_period = expired_time_period + + def has_expired(self) -> bool: + return has_expired(self.timestamp, self.expired_time_period) + + def to_json(self): + data = {k.lstrip("_"): v for k, v in self.__dict__.items()} + return json.dumps(data) + + @classmethod + def from_json(cls, json_data): + return cls(**json.loads(json_data)) + + +class DiagnosisNodeAction(DiagnosisAction): + def __init__( + self, + timestamp=0, + expired_time_period=0, + action="", + node_type=NodeType.WORKER, + instance=DiagnosisConstant.LOCAL_INSTANCE, + ): + super().__init__( + DiagnosisActionConstants.TYPE_NODE, instance, timestamp, expired_time_period + ) + self.action = action + self.node_type = node_type + + def update_action(self, action: str): + self.action = action + + +def is_same_action(action1: DiagnosisAction, action2: DiagnosisAction) -> bool: + if ( + action1.action_type == DiagnosisActionConstants.TYPE_NODE + and action2.action_type == DiagnosisActionConstants.TYPE_NODE + and action1.action == action2.action + and action1.rank == action2.rank + ): + return True + return False + + +class DiagnosisActionQueue: + def __init__(self): + self._actions: Dict[int, List[DiagnosisAction]] = {} + self._lock = threading.Lock() + + def add_action(self, new_action: DiagnosisAction): + with self._lock: + instance = new_action.instance + if instance not in self._actions: + self._actions[instance] = [] + ins_actions = self._actions[instance] + for action in ins_actions: + if is_same_action(new_action, action): + return + logger.info(f"enqueue action {new_action}") + ins_actions.append(new_action) + + def _remove_expired_actions(self): + with self._lock: + for instance in self._actions.keys(): + action_queue = self._actions[instance] + actions = [] + for action in action_queue: + if not action.has_expired(): + actions.append(action) + else: + logger.info(f"Action {action} has expired") + self._actions[instance] = actions + + def next_actions( + self, + instance=DiagnosisConstant.LOCAL_INSTANCE, + action_type=DiagnosisActionConstants.ACTION_TYPE_ANY, + ) -> List[DiagnosisAction]: + self._remove_expired_actions() + with self._lock: + if instance not in self._actions: + return [] + deque_actions = [] + remain_actions = [] + actions = self._actions[instance] + for action in actions: + if action_type == DiagnosisActionConstants.TYPE_NODE or action_type == action.action_type: + deque_actions.append(action) + else: + remain_actions.append(action) + self._actions[instance] = remain_actions + return deque_actions diff --git a/dlrover/python/diagnosis/inferencechain/coordinate_inferences.py b/dlrover/python/diagnosis/inferencechain/coordinate_inferences.py index c8fb2d32c..131496ebb 
100644 --- a/dlrover/python/diagnosis/inferencechain/coordinate_inferences.py +++ b/dlrover/python/diagnosis/inferencechain/coordinate_inferences.py @@ -13,9 +13,9 @@ from typing import List -from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction +from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction from dlrover.python.diagnosis.common.inference_chain import Inference -def coordinate_inferences(problems: List[Inference]) -> List[DiagnoseAction]: +def coordinate_inferences(problems: List[Inference]) -> List[DiagnosisAction]: return [] diff --git a/dlrover/python/diagnosis/inferencechain/inferenceoperator/operator.py b/dlrover/python/diagnosis/inferencechain/inferenceoperator/operator.py index 5757afe65..454f380dc 100644 --- a/dlrover/python/diagnosis/inferencechain/inferenceoperator/operator.py +++ b/dlrover/python/diagnosis/inferencechain/inferenceoperator/operator.py @@ -14,12 +14,14 @@ from dlrover.python.diagnosis.inferencechain.inferenceoperator.check_failure_node_operator import ( # noqa: E501 CheckFailureNodeOperator, ) +from dlrover.python.diagnosis.inferencechain.inferenceoperator.check_training_hang_operator import ( # noqa: E501 + CheckTrainingHangOperator, +) from dlrover.python.diagnosis.inferencechain.inferenceoperator.metrics_collection_operator import ( # noqa: E501 MetricsCollectionOperator, ) -from dlrover.python.master.diagnosis.diagnosis_data_manager import DiagnosisDataManager -from dlrover.python.diagnosis.inferencechain.inferenceoperator.check_training_hang_operator import ( # noqa: E501 - CheckTrainingHangOperator, +from dlrover.python.master.diagnosis.diagnosis_data_manager import ( + DiagnosisDataManager, ) @@ -39,4 +41,3 @@ def get_master_observe_operators(data_mgr: DiagnosisDataManager = None): return [ CheckTrainingHangOperator(data_mgr), ] - diff --git a/dlrover/python/elastic_agent/common/__init__.py b/dlrover/python/elastic_agent/common/__init__.py index e69de29bb..73d1d62c2 100644 --- a/dlrover/python/elastic_agent/common/__init__.py +++ b/dlrover/python/elastic_agent/common/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/dlrover/python/elastic_agent/common/worker.py b/dlrover/python/elastic_agent/common/worker_context.py similarity index 54% rename from dlrover/python/elastic_agent/common/worker.py rename to dlrover/python/elastic_agent/common/worker_context.py index 7b59962ec..4161668d1 100644 --- a/dlrover/python/elastic_agent/common/worker.py +++ b/dlrover/python/elastic_agent/common/worker_context.py @@ -11,21 +11,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
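The new diagnosis_action.py above replaces the rank-addressed queue with instance-addressed, typed actions. A minimal usage sketch, using only the names defined in this patch (the worker instance id 1 is illustrative):

    from dlrover.python.diagnosis.common.constants import DiagnosisActionConstants
    from dlrover.python.diagnosis.common.diagnosis_action import (
        DiagnosisActionQueue,
        DiagnosisNodeAction,
    )

    queue = DiagnosisActionQueue()

    # Queue a node-level action addressed to worker instance 1.
    queue.add_action(
        DiagnosisNodeAction(
            action=DiagnosisActionConstants.RESTART_WORKER,
            instance=1,
        )
    )

    # Expired actions are pruned first, then the node-typed actions queued
    # for instance 1 are dequeued.
    pending = queue.next_actions(
        instance=1, action_type=DiagnosisActionConstants.TYPE_NODE
    )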
+from typing import List, Optional + from torch.distributed.elastic.agent.server.api import RunResult, WorkerSpec -from dlrover.python.diagnosis.common.diagnose_action import ( - DiagnoseActionQueue, - DiagnoseAction, + +from dlrover.python.diagnosis.common.diagnosis_action import ( + DiagnosisAction, + DiagnosisActionQueue, +) +from dlrover.python.common.singleton import Singleton +from dlrover.python.diagnosis.common.constants import ( + DiagnosisConstant, + DiagnosisActionConstants, ) -from typing import Optional, List -class WorkerContext: +class WorkerContext(Singleton): def __init__(self): self._worker_spec: Optional[WorkerSpec] = None self.remaining_failovers = 0 self.restart_count = 0 self._run_result: Optional[RunResult] = None - self._diagnose_action_queue = DiagnoseActionQueue() + self._diagnose_action_queue = DiagnosisActionQueue() @property def worker_spec(self): @@ -44,20 +51,38 @@ def to_string(self) -> str: f"run_result: {self._run_result}" ) - def update_context( - self, - worker_spec: WorkerSpec = None, - remaining_failovers: int = 0, - restart_count: int = 0, - run_result: RunResult = None, + def _update_context( + self, + worker_spec: WorkerSpec = None, + remaining_failovers: int = 0, + restart_count: int = 0, + run_result: RunResult = None, ): self._worker_spec: WorkerSpec = worker_spec self.remaining_failovers = remaining_failovers self.restart_count = restart_count self._run_result = run_result - def enqueue_diagnose_action(self, action: DiagnoseAction): + def enqueue_diagnose_action(self, action: DiagnosisAction): self._diagnose_action_queue.add_action(action) - def next_actions(self) -> List[DiagnoseAction]: - return self._diagnose_action_queue.next_actions() + def next_actions( + self, + instance=DiagnosisConstant.LOCAL_INSTANCE, + action_type=DiagnosisActionConstants.ACTION_TYPE_ANY, + ) -> List[DiagnosisAction]: + return self._diagnose_action_queue.next_actions( + instance=instance, action_type=action_type + ) + +def get_worker_context() -> WorkerContext: + return WorkerContext.singleton_instance() + +def update_worker_context( + worker_spec: WorkerSpec = None, + remaining_failovers: int = 0, + restart_count: int = 0, + run_result: RunResult = None, +): + worker_context = get_worker_context() + worker_context._update_context(worker_spec, remaining_failovers, restart_count, run_result) diff --git a/dlrover/python/elastic_agent/config/launch_config.py b/dlrover/python/elastic_agent/config/launch_config.py index 959c9a9f1..bcd53fe76 100644 --- a/dlrover/python/elastic_agent/config/launch_config.py +++ b/dlrover/python/elastic_agent/config/launch_config.py @@ -1,5 +1,28 @@ +# Copyright 2024 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
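With worker.py renamed to worker_context.py above, the context becomes a process-wide singleton reached through module-level helpers. A rough sketch of the intended access pattern (helper and constant names as defined in this patch; the counter values are placeholders):

    from dlrover.python.diagnosis.common.constants import DiagnosisActionConstants
    from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisNodeAction
    from dlrover.python.elastic_agent.common.worker_context import (
        get_worker_context,
        update_worker_context,
    )

    # Every caller in the agent process sees the same WorkerContext instance.
    worker_context = get_worker_context()
    update_worker_context(remaining_failovers=3, restart_count=0)

    # Actions enqueued here are drained by the agent's run loop through
    # next_actions(); the default instance is the local one.
    worker_context.enqueue_diagnose_action(
        DiagnosisNodeAction(action=DiagnosisActionConstants.RESTART_WORKER)
    )
    actions = worker_context.next_actions(
        action_type=DiagnosisActionConstants.TYPE_NODE
    )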
+ +import os from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch.distributed.elastic.multiprocessing import ( + PContext, + SignalException, + Std, +) from torch.distributed.launcher.api import LaunchConfig + from dlrover.python.common.constants import ( Accelerators, AscendConstants, @@ -11,20 +34,6 @@ RendezvousName, TrainingExceptionLevel, ) -from typing import ( - Dict, - List, - Optional, - Tuple, - Union, -) -from torch.distributed.elastic.multiprocessing import ( - PContext, - SignalException, - Std, -) -import os -import torch from dlrover.python.common.log import default_logger as logger @@ -95,4 +104,4 @@ def auto_configure_params(self): if torch.cuda.is_available(): self.nproc_per_node = torch.cuda.device_count() if self.min_nodes >= 4: - self.network_check = True \ No newline at end of file + self.network_check = True diff --git a/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py b/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py index 048b44f93..600736cad 100644 --- a/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py +++ b/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py @@ -23,13 +23,12 @@ from dlrover.python.common.error import ProcessError from dlrover.python.common.log import default_logger as logger from dlrover.python.common.singleton import Singleton -from dlrover.python.elastic_agent.common.worker import WorkerContext from dlrover.python.diagnosis.common.constants import ( - DiagnosisAction, + DiagnosisActionConstants, DiagnosisConstant, InferenceConfigKey, ) -from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction +from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction, DiagnosisNodeAction from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric from dlrover.python.diagnosis.common.inference_chain import ( Inference, @@ -50,11 +49,17 @@ get_worker_diagnosis_operators, get_worker_observe_operators, ) +from dlrover.python.elastic_agent.common.worker_context import WorkerContext from dlrover.python.elastic_agent.master_client import MasterClient class DiagnosisAgent(Singleton): - def __init__(self, training_log_file: str, errors: str, worker_context: WorkerContext): + def __init__( + self, + training_log_file: str, + errors: str, + worker_context: WorkerContext, + ): self._client = MasterClient.singleton_instance() self._training_log_file = training_log_file self._errors = errors @@ -102,8 +107,8 @@ def stop(self): self._stopped = True def diagnose_problems( - self, problems: List[Inference] - ) -> List[DiagnoseAction]: + self, problems: List[Inference] + ) -> List[DiagnosisAction]: conclusions: List[Inference] = [] for problem in problems: ic = InferenceChain([problem], self._diagnosis_operators) @@ -143,9 +148,10 @@ def _periodically_diagnosis(self): DiagnosisConstant.AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS ) - def diagnose_training_failure(self) -> DiagnoseAction: + def diagnose_training_failure(self) -> DiagnosisAction: self._report_failure_to_master( - self._worker_context.run_result.failures, self._worker_context.restart_count + self._worker_context.run_result.failures, + self._worker_context.restart_count, ) # check if the node is failed inference = Inference( @@ -175,8 +181,9 @@ def diagnose_training_failure(self) -> DiagnoseAction: f"{self._worker_context.worker_spec.max_restarts} " f"attempts left; will restart worker group." 
) - return DiagnoseAction( - action=DiagnosisAction.RESTART_WORKER, + return DiagnosisNodeAction( + action=DiagnosisActionConstants.RESTART_WORKER, + instance=DiagnosisConstant.LOCAL_INSTANCE, ) else: logger.info( @@ -186,8 +193,9 @@ def diagnose_training_failure(self) -> DiagnoseAction: f"no attempts({self._worker_context.worker_spec.max_restarts}) " "left; will relaunch." ) - return DiagnoseAction( - action=DiagnosisAction.RELAUNCH_WORKER, + return DiagnosisNodeAction( + action=DiagnosisActionConstants.RELAUNCH_WORKER, + instance=DiagnosisConstant.LOCAL_INSTANCE, ) def _report_failure_to_master( diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index c2ddb67d1..d8cfbe73d 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -23,8 +23,8 @@ from dlrover.python.common.constants import NetworkFailureReason, NodeEnv from dlrover.python.common.log import default_logger as logger from dlrover.python.common.singleton import Singleton +from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData -from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction def retry_grpc_request(func): @@ -232,12 +232,12 @@ def report_global_step( ) return self._report(message) - def report_heart_beat(self, timestamp) -> List[DiagnoseAction]: + def report_heart_beat(self, timestamp) -> List[DiagnosisAction]: message = grpc.HeartBeat(timestamp=timestamp) response: grpc.HeartbeatResponse = self._get(message) - actions: List[DiagnoseAction] = [] + actions: List[DiagnosisAction] = [] for grpc_action in response.diagnosis_actions: - action = DiagnoseAction( + action = DiagnosisAction( rank=grpc_action.rank, timestamp=grpc_action.timestamp, expired_time_period=grpc_action.expired_time_period, diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index 1289f472a..4b88a25b8 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -53,10 +53,7 @@ ) from torch.distributed.elastic.metrics import put_metric from torch.distributed.elastic.metrics.api import prof -from torch.distributed.elastic.multiprocessing import ( - PContext, - SignalException, -) +from torch.distributed.elastic.multiprocessing import PContext, SignalException from torch.distributed.elastic.multiprocessing.errors import ( ChildFailedError, ProcessFailure, @@ -83,8 +80,24 @@ find_free_port_in_set, ) from dlrover.python.common.log import default_logger as logger -from dlrover.python.elastic_agent.common.worker import WorkerContext -from dlrover.python.diagnosis.common.constants import DiagnosisAction +from dlrover.python.diagnosis.common.constants import DiagnosisActionConstants +from dlrover.python.diagnosis.common.diagnosis_action import ( + DiagnosisAction, + DiagnosisNodeAction, +) +from dlrover.python.diagnosis.common.inference_chain import ( + Inference, + InferenceAttribute, + InferenceDescription, + InferenceName, +) +from dlrover.python.elastic_agent.common.worker_context import ( + get_worker_context, + update_worker_context, +) +from dlrover.python.elastic_agent.config.launch_config import ( + ElasticLaunchConfig, +) from dlrover.python.elastic_agent.config.paral_config_tuner import ( ParalConfigTuner, ) @@ -99,14 +112,6 @@ version_less_than_230, version_less_than_240, ) -from 
dlrover.python.diagnosis.common.inference_chain import ( - Inference, - InferenceName, - InferenceAttribute, - InferenceDescription, -) -from dlrover.python.elastic_agent.config.launch_config import ElasticLaunchConfig -from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction try: from torch_npu.contrib import transfer_to_npu # noqa: F401 @@ -377,7 +382,7 @@ def __init__( self._save_ckpt_executor = ThreadPoolExecutor(max_workers=1) self._save_ckpt_future = None - self._worker_context = WorkerContext() + self._worker_context = get_worker_context() self._diagnose_agent = DiagnosisAgent( training_log_file=training_log_file, errors=failure_node_errors, @@ -682,11 +687,13 @@ def _pre_check(self): if self._config.network_check: succ = run_network_check(self._config, self._entrypoint) if not succ: - problems.append(Inference( - name=InferenceName.NODE, - attribution=InferenceAttribute.IS, - description=InferenceDescription.FAILURE, - )) + problems.append( + Inference( + name=InferenceName.NODE, + attribution=InferenceAttribute.IS, + description=InferenceDescription.FAILURE, + ) + ) self._diagnose_agent.diagnose_problems(problems) def _initialize_workers(self, worker_group): @@ -785,7 +792,7 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: actions = self._worker_context.next_actions() for action in actions: - self._process_diagnose_action(action) + self._process_diagnosis_action(action) if state == WorkerState.SUCCEEDED: logger.info( @@ -810,7 +817,7 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: logger.error(f"The worker fails with {run_result.failures}") self._save_ckpt_to_storage() - self._worker_context.update_context( + update_worker_context( worker_spec=self._worker_group.spec, remaining_failovers=self._remaining_failovers, restart_count=self._restart_count, @@ -819,12 +826,16 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: try: action = self._diagnose_agent.diagnose_training_failure() except Exception as e: - logger.warning(f"Failed to diagnose errors: {e}") + logger.warning(f"failed to diagnose errors: {e}") if self._remaining_failovers > 0: - action = DiagnosisAction.RESTART_WORKER + action = DiagnosisNodeAction( + action=DiagnosisActionConstants.RESTART_WORKER, + ) else: - action = DiagnosisAction.RELAUNCH_WORKER - self._process_diagnose_action(action) + action = DiagnosisNodeAction( + action=DiagnosisActionConstants.RELAUNCH_WORKER, + ) + self._process_diagnosis_action(action) if self._worker_group.state == WorkerState.FAILED: return run_result elif state == WorkerState.HEALTHY: @@ -835,13 +846,14 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: else: raise Exception(f"[{role}] worker group in {state.name} state") - def _process_diagnose_action(self, action: DiagnoseAction): - if action.action == DiagnosisAction.RESTART_WORKER: - self._remaining_failovers -= 1 - self._restart_workers(self._worker_group) - elif action.action == DiagnosisAction.RELAUNCH_WORKER: - self._stop_workers(self._worker_group) - self._worker_group.state = WorkerState.FAILED + def _process_diagnosis_action(self, action: DiagnosisAction): + if action.action_type == DiagnosisActionConstants.TYPE_NODE: + if action.action == DiagnosisActionConstants.RESTART_WORKER: + self._remaining_failovers -= 1 + self._restart_workers(self._worker_group) + elif action.action == DiagnosisActionConstants.RELAUNCH_WORKER: + self._stop_workers(self._worker_group) + self._worker_group.state = WorkerState.FAILED def _wait_async_saver(self): """ @@ -1211,7 
+1223,7 @@ def run(self, role: str = DEFAULT_ROLE) -> bool: ) raise RuntimeError("This node is down.") elif self._node_rank in stragglers: - logger.warn("This node is a straggler!") + logger.warning("This node is a straggler!") if self._config.exclude_straggler: raise RuntimeError("The node is a straggler and exits.") return success diff --git a/dlrover/python/master/diagnosis/diagnosis_data_manager.py b/dlrover/python/master/diagnosis/diagnosis_data_manager.py index d4ef761b6..f1dbe83e4 100644 --- a/dlrover/python/master/diagnosis/diagnosis_data_manager.py +++ b/dlrover/python/master/diagnosis/diagnosis_data_manager.py @@ -1,7 +1,21 @@ +# Copyright 2024 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Dict, List -from dlrover.python.common.time import has_expired -from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData + from dlrover.python.common.log import default_logger as logger +from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData +from dlrover.python.util.time_util import has_expired class DiagnosisDataManager: diff --git a/dlrover/python/master/diagnosis/diagnosis_manager.py b/dlrover/python/master/diagnosis/diagnosis_manager.py index 9ec767052..0e05a08f2 100644 --- a/dlrover/python/master/diagnosis/diagnosis_manager.py +++ b/dlrover/python/master/diagnosis/diagnosis_manager.py @@ -13,42 +13,43 @@ import threading import time - from typing import List from dlrover.python.common.log import default_logger as logger -from dlrover.python.diagnosis.common.constants import DiagnosisConstant +from dlrover.python.diagnosis.common.constants import ( + DiagnosisConstant, + DiagnosisActionConstants, +) +from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData from dlrover.python.diagnosis.common.inference_chain import ( InferenceAttribute, InferenceDescription, InferenceName, ) -from dlrover.python.diagnosis.common.diagnose_action import ( - DiagnoseAction, +from dlrover.python.diagnosis.inferencechain.coordinate_inferences import ( + coordinate_inferences, ) from dlrover.python.diagnosis.inferencechain.inference_chain import ( Inference, InferenceChain, + combine_inferences, ) - -from dlrover.python.master.diagnosis.diagnosis_data_manager import DiagnosisDataManager from dlrover.python.diagnosis.inferencechain.inferenceoperator.operator import ( get_master_observe_operators, ) -from dlrover.python.diagnosis.inferencechain.inference_chain import combine_inferences -from dlrover.python.diagnosis.inferencechain.coordinate_inferences import ( - coordinate_inferences, +from dlrover.python.master.diagnosis.diagnosis_data_manager import ( + DiagnosisDataManager, ) -from dlrover.python.master.node.job import JobContext +from dlrover.python.master.node.job_context import get_job_context class DiagnosisManager: - def __init__(self, job_context: JobContext = None): + def __init__(self): self._is_observing_started = False 
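Since the manager in this hunk no longer receives a JobContext argument, a bare constructor is enough; a small sketch of how the master side is expected to drive it (method names as shown in this patch):

    from dlrover.python.master.diagnosis.diagnosis_manager import DiagnosisManager

    mgr = DiagnosisManager()   # the shared job context is fetched internally

    # Starts the "diagnose_failures" daemon thread that observes and diagnoses.
    mgr.start_observing()

    # Pending actions live in the shared job context; the defaults fetch
    # actions for the local instance with any action type.
    actions = mgr.next_actions()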
self._data_manager: DiagnosisDataManager = DiagnosisDataManager(600) self._diagnostician: Diagnostician = Diagnostician(self._data_manager) - self._job_context = job_context + self._job_context = get_job_context() def collect_diagnosis_data(self, data: DiagnosisData): self._data_manager.store_data(data) @@ -116,8 +117,10 @@ def _diagnose(self): DiagnosisConstant.MASTER_DIAGNOSIS_OBSERVING_INTERVAL_SECS ) - def next_actions(self, rank) -> List[DiagnoseAction]: - return self._action_queue.next_actions(rank) + def next_actions( + self, instance=DiagnosisConstant.LOCAL_INSTANCE, action_type=DiagnosisActionConstants.ACTION_TYPE_ANY + ) -> List[DiagnosisAction]: + return self._job_context.next_actions(instance=instance, action_type=action_type) class Diagnostician: @@ -125,7 +128,9 @@ def __init__(self, data_manager): self._data_manager = data_manager self._pre_checks: List[Inference] = [] self._training_problems: List[Inference] = [] - self._observing_operators = get_master_observe_operators(data_mgr=data_manager) + self._observing_operators = get_master_observe_operators( + data_mgr=data_manager + ) def register_pre_check(self, pre_checks: List[Inference]): self._pre_checks = pre_checks @@ -137,9 +142,7 @@ def observe_training(self) -> List[Inference]: if len(self._training_problems) == 0: logger.warning("No training problem is registered.") return [] - ic = InferenceChain( - self._training_problems, self._observing_operators - ) + ic = InferenceChain(self._training_problems, self._observing_operators) return ic.infer() def diagnose_problem(self, inference: Inference) -> List[Inference]: diff --git a/dlrover/python/master/dist_master.py b/dlrover/python/master/dist_master.py index 65eefcf62..1c85ba43c 100644 --- a/dlrover/python/master/dist_master.py +++ b/dlrover/python/master/dist_master.py @@ -43,11 +43,11 @@ TaskRescheduleCallback, TFPSNodeHandlingCallback, ) +from dlrover.python.master.node.job_context import JobContext from dlrover.python.master.servicer import create_master_service from dlrover.python.master.shard.task_manager import TaskManager from dlrover.python.master.stats.job_collector import JobMetricCollector from dlrover.python.scheduler.job import JobArgs -from dlrover.python.master.node.job import JobContext def _create_elastic_ps_service_if_needed(params: JobArgs): @@ -124,9 +124,8 @@ def __init__( ) self.speed_monitor = SpeedMonitor() - self.job_context = JobContext() self.job_manager = ( - create_job_manager(args, self.speed_monitor, self.job_context) + create_job_manager(args, self.speed_monitor) if args.enable_elastic_scheduling else None ) @@ -145,7 +144,7 @@ def __init__( error_monitor ), } - self.diagnosis_manager = DiagnosisManager(self.job_context) + self.diagnosis_manager = DiagnosisManager() self.job_metric_collector = self._create_metric_collector_if_needed( args ) diff --git a/dlrover/python/master/node/dist_job_manager.py b/dlrover/python/master/node/dist_job_manager.py index 5effe062e..3d4df37bb 100644 --- a/dlrover/python/master/node/dist_job_manager.py +++ b/dlrover/python/master/node/dist_job_manager.py @@ -32,12 +32,14 @@ NodeType, TrainingExceptionLevel, ) -from dlrover.python.diagnosis.common.constants import DiagnosisConstant from dlrover.python.common.global_context import Context from dlrover.python.common.grpc import ParallelConfig from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node, NodeGroupResource -from dlrover.python.diagnosis.common.constants import DiagnosisAction +from 
dlrover.python.diagnosis.common.constants import ( + DiagnosisConstant, +) +from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction from dlrover.python.master.monitor.error_monitor import K8sJobErrorMonitor from dlrover.python.master.node.event_callback import ( ClusterContext, @@ -47,6 +49,10 @@ JobAutoScaler, new_job_auto_scaler, ) +from dlrover.python.master.node.job_context import ( + update_job_node, + update_job_nodes, +) from dlrover.python.master.node.job_manager import JobManager from dlrover.python.master.node.ps import ParameterServerManager from dlrover.python.master.node.status_flow import ( @@ -103,14 +109,12 @@ def __init__( job_scaler=None, error_monitor=None, external_config=None, - job_context=None, ): super().__init__( job_args=job_args, speed_monitor=speed_monitor, error_monitor=error_monitor, external_config=external_config, - job_context=job_context, ) self._remove_exited_node = job_args.remove_exited_node node_restart_count: Dict[str, int] = {} @@ -337,32 +341,25 @@ def _create_initial_scale_plan(self): def _init_training_node_manager(self): self._ps_manager = ParameterServerManager( - self._job_nodes.get(NodeType.PS, {}), self._job_resource, self._ps_relaunch_max_num, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - chief_nodes = self._job_nodes.get(NodeType.CHIEF, {}) - if not chief_nodes: - chief_nodes = self._job_nodes.get(NodeType.MASTER, {}) self._chief_manager = ChiefManager( - chief_nodes, self._job_resource, self._relaunch_on_worker_failure, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) self._worker_manager = WorkerManager( - self._job_nodes.get(NodeType.WORKER, {}), self._job_resource, self._relaunch_on_worker_failure, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) self._evaluator_manager = EvaluatorManager( - self._job_nodes.get(NodeType.EVALUATOR, {}), self._job_resource, self._relaunch_on_worker_failure, self._elastic_job.get_node_service_addr, @@ -373,7 +370,7 @@ def add_node_event_callback(self, node_event_callback): self._node_event_callbacks.append(node_event_callback) def _init_nodes(self): - self._job_nodes = self._job_resource.init_job_node_meta( + job_nodes = self._job_resource.init_job_node_meta( self._relaunch_on_worker_failure, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, @@ -387,28 +384,23 @@ def _init_nodes(self): self._pending_relaunch_count = 0 set_critical_node( - self._job_nodes, + job_nodes, self._ps_is_critical, self._ps_relaunch_max_num, self._critical_worker_index, ) - update_nodes_priority(self._job_nodes) - - self._ps_manager.update_nodes(self._job_nodes.get(NodeType.PS, {})) - chief_nodes = self._job_nodes.get(NodeType.CHIEF, {}) - if not chief_nodes: - chief_nodes = self._job_nodes.get(NodeType.MASTER, {}) - self._chief_manager.update_nodes(chief_nodes) - workers = self._job_nodes.get(NodeType.WORKER, {}) - self._worker_manager.update_nodes(workers) - evaluators = self._job_nodes.get(NodeType.EVALUATOR, {}) - self._evaluator_manager.update_nodes(evaluators) + update_nodes_priority(job_nodes) + update_job_nodes(job_nodes) + + self._ps_manager.update_nodes() + self._chief_manager.update_nodes() + self._worker_manager.update_nodes() + self._evaluator_manager.update_nodes() def _init_job_auto_scaler(self): self._job_autoscaler: JobAutoScaler = new_job_auto_scaler( self._job_args.distribution_strategy, self._job_resource, - self._job_nodes, self._job_optimizer, self._speed_monitor, self._ps_manager, 
@@ -481,7 +473,7 @@ def _diagnose_job(self): logger.warning(e) detail_trace_back = traceback.format_exc() logger.warning(detail_trace_back) - actions = self._job_context.next_actions(DiagnosisConstant.MASTER_RANK) + actions = self._job_context.next_actions() for action in actions: self._process_diagnosis_action(action) time.sleep(15) @@ -489,8 +481,9 @@ def _diagnose_job(self): def _get_dead_node_event(self, window_interval=900) -> List[NodeEvent]: now = time.time() dead_events: List[NodeEvent] = [] - logger.debug(f"Current job nodes are: {self._job_nodes}.") - for _, nodes in self._job_nodes.items(): + job_nodes = self._job_context.job_nodes() + logger.debug(f"Current job nodes are: {job_nodes}.") + for _, nodes in job_nodes.items(): for _, node in nodes.items(): if ( node.heartbeat_time > 0 @@ -540,7 +533,8 @@ def _get_dead_node_event(self, window_interval=900) -> List[NodeEvent]: def _get_nodes_time_info(self): result = {} - for _, nodes in self._job_nodes.items(): + job_nodes = self._job_context.job_nodes() + for _, nodes in job_nodes.items(): for _, node in nodes.items(): if node.heartbeat_time == 0: heartbeat_time = 0 @@ -587,7 +581,8 @@ def _process_list_nodes(self, nodes: List[Node]): if not nodes: return exist_nodes: Dict[str, List[int]] = {} - for node_type in self._job_nodes.keys(): + job_nodes = self._job_context.job_nodes() + for node_type in job_nodes.keys(): exist_nodes[node_type] = [] for node in nodes: exist_nodes[node.type].append(node.id) @@ -599,9 +594,9 @@ def _process_list_nodes(self, nodes: List[Node]): event = NodeEvent(type, node) self._process_event(event) - for node_type in self._job_nodes.keys(): + for node_type in job_nodes.keys(): # Avoid dictionary keys changed during iteration - type_nodes = list(self._job_nodes[node_type].values()) + type_nodes = list(job_nodes[node_type].values()) for node in type_nodes: if ( node.status != NodeStatus.INITIAL @@ -644,6 +639,7 @@ def _process_event(self, event: NodeEvent): node_type = event.node.type node_status = event.node.status node_id = event.node.id + job_nodes = self._job_context.job_nodes() # Skip deleted event of pod if the cluster has relaunched a new pod # with the same type and rank as the deleted pod. @@ -673,11 +669,11 @@ def _process_event(self, event: NodeEvent): ) return - if node_id not in self._job_nodes[node_type]: + if node_id not in job_nodes[node_type]: logger.info(f"The node {event.node.name} is released.") return else: - cur_node = self._job_nodes[node_type][node_id] + cur_node = job_nodes[node_type][node_id] logger.debug( f"Update node({cur_node.id}), " f"name: {cur_node.name}->{event.node.name}, " @@ -703,6 +699,7 @@ def _process_event(self, event: NodeEvent): restart_training=event.node.restart_training, relaunch_count=event.node.relaunch_count, ) + update_job_node(cur_node) # For the given node id, check whether it meets # the state change condition @@ -733,6 +730,8 @@ def _process_event(self, event: NodeEvent): cur_node.update_status(new_status) new_status = status_change_flow.to_status cur_node.set_exit_reason(event.node.exit_reason) + update_job_node(cur_node) + self._process_node_events(status_change_flow, cur_node) should_relaunch = self._should_relaunch( @@ -891,27 +890,32 @@ def _relaunch_node(self, node: Node): self._set_ps_addrs_in_plan(plan) if self._remove_exited_node: plan.remove_nodes.append(node) - node.relaunchable = False # Avoid repeatedly relaunching the node. + # Avoid repeatedly relaunching the node. 
+ node.relaunchable = False + update_job_node(node) self._scaler.scale(plan) def clear_exited_nodes(self): if not self._remove_exited_node: return + job_nodes = self._job_context.job_nodes() scale_plan = ScalePlan() with self._lock: - for _, nodes in self._job_nodes.items(): + for _, nodes in job_nodes.items(): for _, node in nodes.items(): if not node.is_released and node.exited(): scale_plan.remove_nodes.append(node) node.is_released = True + update_job_node(node) if len(scale_plan.remove_nodes) > 0: logger.info(f"Remove exited nodes {scale_plan.remove_nodes}") self._scaler.scale(scale_plan) def clear_all_nodes(self): scale_plan = ScalePlan() + job_nodes = self._job_context.job_nodes() with self._lock: - for _, nodes in self._job_nodes.items(): + for _, nodes in job_nodes.items(): for _, node in nodes.items(): if not node.is_released: scale_plan.remove_nodes.append(node) @@ -942,7 +946,8 @@ def all_workers_deleted(self): def all_critical_node_completed(self): alive_critical_nodes = [] - for _, nodes in self._job_nodes.items(): + job_nodes = self._job_context.job_nodes() + for _, nodes in job_nodes.items(): for node in nodes.values(): if node.critical and node.status in [ NodeStatus.INITIAL, @@ -957,7 +962,8 @@ def all_critical_node_completed(self): return completed def remove_worker(self, worker_id): - if self._job_nodes[NodeType.WORKER][worker_id].critical: + job_nodes = self._job_context.job_nodes() + if job_nodes[NodeType.WORKER][worker_id].critical: logger.info("Skip the critical worker %s", worker_id) else: logger.info("Delete worker %s", worker_id) @@ -987,27 +993,30 @@ def post_ps_ready(self): def stop(self): self._enable_relaunch_node = False + job_nodes = self._job_context.job_nodes() with self._lock: - for node_type in self._job_nodes.keys(): - for node in self._job_nodes[node_type].values(): + for node_type in job_nodes.keys(): + for node in job_nodes[node_type].values(): node.critical = False node.is_released = True node.relaunchable = False - for node in self._job_nodes[NodeType.WORKER].values(): + update_job_node(node) + for node in job_nodes[NodeType.WORKER].values(): node.eval_time = self._speed_monitor.get_worker_eval_time( node.id ) + update_job_node(node) self._stopped = True def update_node_resource_usage( self, node_type, node_id, cpu, memory, gpu_stats=[] ): - if not self._job_nodes: + node = self._job_context.job_node(node_type, node_id) + if node is None: logger.warning( - "Skip updating for job_nodes hasn't been initialized." 
+ f"Skip update node[{node_type}][{node_id}] resources" ) return - node = self._job_nodes[node_type][node_id] node.update_resource_usage(cpu, memory, gpu_stats) cpu_percent = node.used_resource.cpu / node.config_resource.cpu if cpu_percent < _dlrover_context.hang_cpu_usage_percentage: @@ -1018,17 +1027,21 @@ def update_node_resource_usage( if node.start_hang_time > 0: now = datetime.now() node.start_hang_time = 0 + update_job_node(node) def update_node_service_addr(self, node_type, node_id, service_addr): - node = self._job_nodes[node_type][node_id] + node = self._job_context.job_node(node_type, node_id) + if node is None: + logger.error(f"no Node[{node_type}][{node_id}] found") + return node.update_service_address(service_addr) node.status = NodeStatus.RUNNING node.is_released = False - self._job_nodes[node_type][node_id] = node + update_job_node(node) def get_cur_cluster_ps(self): """Get PS nodes in the current training cluster.""" - logger.info("job nodes are {}".format(self._job_nodes)) + logger.info("job nodes are {}".format(self._job_context.job_nodes())) return self._ps_manager.get_training_ps_cluster() def get_next_cluster_ps(self): @@ -1047,9 +1060,10 @@ def remove_training_nodes(self): """Remove all PS and workers""" self._job_autoscaler.stop_auto_scaling() plan = ScalePlan() - training_nodes = list( - self._job_nodes[NodeType.WORKER].values() - ) + list(self._job_nodes[NodeType.PS].values()) + job_nodes = self._job_context.job_nodes() + training_nodes = list(job_nodes[NodeType.WORKER].values()) + list( + job_nodes[NodeType.PS].values() + ) for node in training_nodes: if ( node.status in [NodeStatus.RUNNING, NodeStatus.PENDING] @@ -1060,6 +1074,7 @@ def remove_training_nodes(self): node.is_released = True node.status = NodeStatus.DELETED logger.info("Remove node %s", node.name) + update_job_node(node) plan.remove_nodes.append(node) self._scaler.scale(plan) @@ -1085,7 +1100,11 @@ def _report_event( ) def _process_error( - self, node: Node, restart_count: int, error_data: str, level: str + self, + node: Optional[Node], + restart_count: int, + error_data: str, + level: str, ) -> bool: if self._error_monitor: return self._error_monitor.process_error( @@ -1121,7 +1140,7 @@ def handle_training_failure( self, node_type, node_id, restart_count=-1, error_data="", level="" ): """Process the training failure reported by the node.""" - node = self._job_nodes[node_type][node_id] + node = self._job_context.job_node(node_type, node_id) if node.is_released: logger.info(f"The node {node.name} has been released.") return @@ -1140,29 +1159,33 @@ def get_opt_strategy(self) -> ParallelConfig: return strategy def update_node_paral_config(self, node_type, node_id, paral_config): - node = self._job_nodes[node_type][node_id] + node = self._job_context.job_node(node_type, node_id) + if node is None: + logger.warning(f"not found Node[{node_type}][{node_id}]") + return node.update_paral_config(paral_config) + update_job_node(node) def verify_restarting_worker_training(self, node_type, node_id): if node_type != NodeType.WORKER: return False return self._worker_manager.verify_restarting_training(node_id) - def collect_node_heart_beat(self, node_type, node_id, timestamp) -> List[DiagnosisAction]: + def collect_node_heart_beat( + self, node_type, node_id, timestamp + ) -> List[DiagnosisAction]: with self._lock: - if ( - node_type not in self._job_nodes - or node_id not in self._job_nodes[node_type] - ): + node = self._job_context.job_node(node_type, node_id) + if node is None: return [] - node = 
self._job_nodes[node_type][node_id] if node.heartbeat_time == 0: logger.info( f"Start receiving heartbeat from node {node_id}" f"-{node.name}" ) node.heartbeat_time = timestamp - return self._job_context.next_actions(node_id) + update_job_node(node) + return self._job_context.next_actions(instance=node_id) def update_node_required_info_callback(self): self._worker_manager.update_node_required_info(self._nodes_required) @@ -1172,7 +1195,9 @@ def update_succeeded_node(self, node_id, node_type): super().update_succeeded_node(node_id, node_type) -def create_job_manager(args: JobArgs, speed_monitor, job_context=None) -> DistributedJobManager: +def create_job_manager( + args: JobArgs, speed_monitor +) -> DistributedJobManager: critical_worker_index = get_critical_worker_index(args) # Custom distribution strategy does not exit if there are pending nodes wait_pending_relaunch = ( @@ -1197,5 +1222,4 @@ def create_job_manager(args: JobArgs, speed_monitor, job_context=None) -> Distri node_watcher=node_watcher, job_scaler=job_scaler, error_monitor=node_error_monitor, - job_context=job_context, ) diff --git a/dlrover/python/master/node/job.py b/dlrover/python/master/node/job.py deleted file mode 100644 index f2b10bfa7..000000000 --- a/dlrover/python/master/node/job.py +++ /dev/null @@ -1,16 +0,0 @@ -from dlrover.python.diagnosis.common.diagnose_action import DiagnoseActionQueue - - -class JobContext: - def __init__(self): - self._action_queue = DiagnoseActionQueue() - - def enqueue_actions(self, actions): - for action in actions: - self._action_queue.add_action(action) - - def next_actions(self, rank): - return self._action_queue.next_actions(rank) - - def update_context(self): - pass diff --git a/dlrover/python/master/node/job_auto_scaler.py b/dlrover/python/master/node/job_auto_scaler.py index 1a5f04b55..97a6efd5e 100644 --- a/dlrover/python/master/node/job_auto_scaler.py +++ b/dlrover/python/master/node/job_auto_scaler.py @@ -33,6 +33,7 @@ ) from dlrover.python.master.resource.optimizer import ResourcePlan from dlrover.python.master.scaler.base_scaler import ScalePlan, Scaler +from dlrover.python.master.node.job_context import get_job_context _dlrover_context = Context.singleton_instance() @@ -40,7 +41,6 @@ def new_job_auto_scaler( job_strategy, job_resource: JobResource, - job_nodes: Dict[str, Dict[int, Node]], job_optimizer: JobResourceOptimizer, speed_monitor: SpeedMonitor, ps_manager: ParameterServerManager, @@ -50,7 +50,6 @@ def new_job_auto_scaler( if job_strategy == DistributionStrategy.PS: return PSTrainingAutoScaler( job_resource, - job_nodes, job_optimizer, speed_monitor, ps_manager, @@ -60,7 +59,6 @@ def new_job_auto_scaler( elif job_strategy == DistributionStrategy.ALLREDUCE: return AllreduceTrainingAutoScaler( job_resource, - job_nodes, job_optimizer, speed_monitor, worker_manager, @@ -76,18 +74,17 @@ class JobAutoScaler(metaclass=ABCMeta): def __init__( self, job_resource: JobResource, - job_nodes: Dict[str, Dict[int, Node]], job_optimizer: JobResourceOptimizer, speed_monitor: SpeedMonitor, node_scaler: Scaler, scale_interval: int, ): self._job_resource = job_resource - self._job_nodes = job_nodes self._job_optimizer = job_optimizer self._speed_monitor = speed_monitor self._scaler = node_scaler self._scale_interval = scale_interval + self._job_context = get_job_context() self._suggested_stop = False self._autoscaling_started = False @@ -118,7 +115,6 @@ class PSTrainingAutoScaler(JobAutoScaler): def __init__( self, job_resource: JobResource, - job_nodes: Dict[str, Dict[int, Node]], 
job_optimizer: JobResourceOptimizer, speed_monitor: SpeedMonitor, ps_manager: ParameterServerManager, @@ -127,7 +123,6 @@ def __init__( ) -> None: super().__init__( job_resource, - job_nodes, job_optimizer, speed_monitor, node_scaler, @@ -224,7 +219,8 @@ def execute_job_optimization_plan(self, plan: ResourcePlan): scale_plan.merge(ps_plan) self._speed_monitor.reset_running_speed_monitor() elif node_type == NodeType.WORKER: - chief_num = len(self._job_nodes.get(NodeType.CHIEF, [])) + chief_nodes = self._job_context.job_nodes_by_type(NodeType.CHIEF) + chief_num = len(chief_nodes) worker_num = chief_num + group.count self._speed_monitor.set_target_worker_num(worker_num) worker_plan = self._worker_manager.adjust_worker(group) @@ -260,7 +256,7 @@ def _migrate_nodes(self, node_resources: Dict[str, NodeResource]): return scale_plan def _reduce_timeout_pending_node_resource(self): - """Cut down CPU cores of pending pod at the job starts""" + """Cut down CPU cores of pending pod when job starts""" scale_plan = ScalePlan() plan = self._ps_manager.reduce_pending_node_resource() scale_plan.merge(plan) @@ -278,7 +274,6 @@ class AllreduceTrainingAutoScaler(JobAutoScaler): def __init__( self, job_resource: JobResource, - job_nodes: Dict[str, Dict[int, Node]], job_optimizer: JobResourceOptimizer, speed_monitor: SpeedMonitor, worker_manager: WorkerManager, @@ -286,14 +281,12 @@ def __init__( ) -> None: super().__init__( job_resource, - job_nodes, job_optimizer, speed_monitor, node_scaler, 1800, ) self._worker_manager = worker_manager - self._workers = job_nodes[NodeType.WORKER] def start_auto_scaling(self): """Start auto-scaling nodes of a job""" @@ -334,7 +327,8 @@ def _periodic_adjust_worker(self): def _get_alive_worker_num(self): worker_num = 0 - for _, worker in self._workers.items(): + workers = self._job_context.job_nodes_by_type(NodeType.WORKER) + for _, worker in workers.items(): if worker.status in [ NodeStatus.RUNNING, NodeStatus.PENDING, diff --git a/dlrover/python/master/node/job_context.py b/dlrover/python/master/node/job_context.py new file mode 100644 index 000000000..1e2437d2c --- /dev/null +++ b/dlrover/python/master/node/job_context.py @@ -0,0 +1,138 @@ +# Copyright 2024 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import threading +import time +from typing import Dict, Optional + +from dlrover.python.common.constants import NodeType +from dlrover.python.common.node import Node +from dlrover.python.common.singleton import Singleton +from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisActionQueue +from dlrover.python.diagnosis.common.constants import ( + DiagnosisConstant, + DiagnosisActionConstants, +) +from datetime import datetime + + +class JobContext(Singleton): + """ + JobContext includes critical states of the training job that + will be shared across multiple components. 
+ """ + + def __init__(self): + self._action_queue = DiagnosisActionQueue() + self._job_nodes: Dict[str, Dict[int, Node]] = {} + self._ps_nodes: Dict[int, Node] = {} + self._locker = threading.Lock() + + def enqueue_actions(self, actions): + for action in actions: + self._action_queue.add_action(action) + + def next_actions( + self, instance=DiagnosisConstant.LOCAL_INSTANCE, action_type=DiagnosisActionConstants.ACTION_TYPE_ANY + ): + return self._action_queue.next_actions(instance=instance, action_type=action_type) + + def _update_job_nodes(self, job_nodes: Dict[str, Dict[int, Node]]): + with self._locker: + self._job_nodes = job_nodes + if NodeType.PS in self._job_nodes: + self._ps_nodes = copy.deepcopy(self._job_nodes[NodeType.PS]) + else: + self._ps_nodes = {} + + @property + def ps_nodes(self) -> Dict[int, Node]: + with self._locker: + return self._ps_nodes + + def job_nodes(self) -> Dict[str, Dict[int, Node]]: + """ + return a copy of job nodes + """ + with self._locker: + return copy.deepcopy(self._job_nodes) + + def job_nodes_if_updated(self, last_update_timestamp: float) -> Optional[Dict[str, Dict[int, Node]]]: + with self._locker: + if self._last_update_timestamp > last_update_timestamp: + return self._job_nodes + return None + + def job_node(self, node_type: str, node_id: int) -> Optional[Node]: + with self._locker: + node_type = self._preprocess(node_type) + if ( + node_type not in self._job_nodes + or node_id not in self._job_nodes[node_type] + ): + return None + return copy.deepcopy(self._job_nodes[node_type][node_id]) + + def job_nodes_by_type(self, node_type: str) -> Dict[int, Node]: + with self._locker: + node_type = self._preprocess(node_type) + if node_type not in self._job_nodes: + return {} + return copy.deepcopy(self._job_nodes[node_type]) + + def _preprocess(self, node_type: str) -> str: + if node_type == NodeType.CHIEF and node_type not in self._job_nodes: + return NodeType.MASTER + return node_type + + def _update_job_node(self, node: Node): + with self._locker: + if self._job_nodes is None: + self._job_nodes = {} + if node.type not in self._job_nodes: + self._job_nodes[node.type] = {} + self._job_nodes[node.type][node.id] = copy.deepcopy(node) + + if node.type == NodeType.PS: + if node.id not in self._ps_nodes: + self._ps_nodes[node.id] = copy.deepcopy(node) + else: + self._ps_nodes[node.id].update_from_node(node) + + def _clear_nodes(self): + with self._locker: + self._job_nodes = {} + self._last_update_timestamp = datetime.now().timestamp() + + +def get_job_context() -> JobContext: + job_context = JobContext.singleton_instance() + return job_context + + +def update_job_nodes(job_nodes: Dict[str, Dict[int, Node]]): + job_context = JobContext.singleton_instance() + job_context._update_job_nodes(copy.deepcopy(job_nodes)) + + +def update_job_node(node: Node): + if node is None: + return + job_context = JobContext.singleton_instance() + job_context._update_job_node(copy.deepcopy(node)) + + +def clear_job_nodes(): + job_context = JobContext.singleton_instance() + job_context._clear_nodes() diff --git a/dlrover/python/master/node/job_manager.py b/dlrover/python/master/node/job_manager.py index 3f1b8d2bd..f74451b57 100644 --- a/dlrover/python/master/node/job_manager.py +++ b/dlrover/python/master/node/job_manager.py @@ -16,6 +16,7 @@ from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node +from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction from 
dlrover.python.master.hyperparams.simple_strategy_generator import ( SimpleStrategyGenerator, ) @@ -28,8 +29,10 @@ from dlrover.python.master.resource.job import JobResource from dlrover.python.scheduler.job import JobArgs from dlrover.python.scheduler.kubernetes import k8sClient -from dlrover.python.master.node.job import JobContext -from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction +from dlrover.python.master.node.job_context import ( + get_job_context, + update_job_node, +) class JobManager(metaclass=ABCMeta): @@ -43,7 +46,6 @@ def __init__( speed_monitor=None, error_monitor=None, external_config=None, - job_context=None, ): self._job_resource = JobResource() self._job_args = job_args @@ -55,16 +57,10 @@ def __init__( self._stopped = False self._speed_monitor: SpeedMonitor = speed_monitor self._error_monitor: ErrorMonitor = error_monitor - - self._job_nodes: Dict[str, Dict[int, Node]] = {} self._nodes_required = (0, 0, 0) self._training_node_config = TrainingNodeConfig(external_config) - - if job_context is None: - self._job_context = JobContext() - else: - self._job_context = job_context + self._job_context = get_job_context() @abstractmethod def start(self): @@ -203,7 +199,9 @@ def handle_training_failure( pass @abstractmethod - def collect_node_heart_beat(self, node_type, node_id, timestamp) -> List[DiagnoseAction]: + def collect_node_heart_beat( + self, node_type, node_id, timestamp + ) -> List[DiagnosisAction]: """Collect the heart beat message of nodes.""" pass @@ -241,9 +239,9 @@ def get_elastic_run_configs(self) -> Dict[str, str]: return self._training_node_config.get_elastic_run_configs() def update_succeeded_node(self, node_id, node_type): - if ( - node_type in self._job_nodes - and node_id in self._job_nodes[node_type] - ): + node = self._job_context.job_node(node_type, node_id) + if node is not None: logger.info(f"Node {node_id}({node_type}) to succeeded.") - self._job_nodes[node_type][node_id].set_as_succeeded() + node.set_as_succeeded() + update_job_node(node) + diff --git a/dlrover/python/master/node/local_job_manager.py b/dlrover/python/master/node/local_job_manager.py index f13574c2d..868d9a1f0 100644 --- a/dlrover/python/master/node/local_job_manager.py +++ b/dlrover/python/master/node/local_job_manager.py @@ -11,14 +11,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
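As a quick orientation to the new accessors used above, here is a minimal sketch of the read-modify-write cycle the managers now follow against the shared context; the WORKER type and node id 0 are illustrative only:

from dlrover.python.common.constants import NodeType
from dlrover.python.master.node.job_context import (
    get_job_context,
    update_job_node,
)

job_context = get_job_context()

# job_node() returns a deep copy (or None for an unknown node), so a
# mutation is only visible to other components after it is written
# back with update_job_node().
node = job_context.job_node(NodeType.WORKER, 0)
if node is not None:
    node.set_as_succeeded()
    update_job_node(node)
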
+from typing import List + from dlrover.python.common.constants import NodeStatus, NodeType from dlrover.python.common.grpc import ParallelConfig from dlrover.python.common.node import Node +from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction from dlrover.python.master.monitor.error_monitor import SimpleErrorMonitor from dlrover.python.master.node.job_manager import JobManager from dlrover.python.scheduler.job import JobArgs -from typing import List -from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction +from dlrover.python.master.node.job_context import update_job_node class LocalJobManager(JobManager): @@ -38,16 +40,18 @@ def __init__( self._job_resource_optimizer = None def start(self): - self._job_nodes[NodeType.WORKER] = {} + workers = {} worker = self._job_args.node_args[NodeType.WORKER].group_resource self._training_node_config.set_node_num(worker.count) for i in range(worker.count): - self._job_nodes[NodeType.WORKER][i] = Node( + workers[i] = Node( name=NodeType.WORKER + f"-{i}", node_type=NodeType.WORKER, node_id=i, status=NodeStatus.RUNNING, ) + update_job_node(workers[i]) + def should_early_stop(self): return False @@ -58,21 +62,31 @@ def add_node_event_callback(self, node_event_callback): def update_node_resource_usage( self, node_type, node_id, cpu, memory, gpu_stats=[] ): - node = self._job_nodes[node_type][node_id] + node = self._job_context.job_node(node_type, node_id) + if node is None: + return node.update_resource_usage(cpu, memory, gpu_stats) + update_job_node(node) def handle_training_failure( self, node_type, node_id, restart_count=-1, error_data="", level="" ): """Process the training failure reported by the node.""" - node = self._job_nodes[node_type][node_id] + node = self._job_context.job_node(node_type, node_id) + if node is None: + return self._error_monitor.process_error( node, restart_count, error_data, level ) - def collect_node_heart_beat(self, node_type, node_id, timestamp) -> List[DiagnoseAction]: - node = self._job_nodes[node_type][node_id] + def collect_node_heart_beat( + self, node_type, node_id, timestamp + ) -> List[DiagnosisAction]: + node = self._job_context.job_node(node_type, node_id) + if node is None: + return [] node.heartbeat_time = timestamp + update_job_node(node) return [] def close_job(self): @@ -94,12 +108,11 @@ def remove_worker(self, worker_id): pass def get_running_nodes(self): - nodes = list(self._job_nodes[NodeType.WORKER].values()) - return nodes + nodes = self._job_context.job_nodes_by_type(NodeType.WORKER) + return nodes.values() def get_running_workers(self): - workers = list(self._job_nodes[NodeType.WORKER].values()) - return workers + return self._job_context.job_nodes_by_type(NodeType.WORKER) def post_ps_ready(self): pass @@ -149,8 +162,11 @@ def get_opt_strategy(self) -> ParallelConfig: return strategy def update_node_paral_config(self, node_type, node_id, paral_config): - node = self._job_nodes[node_type][node_id] + node = self._job_context.job_node(node_type, node_id) + if node is None: + return node.update_paral_config(paral_config) + update_job_node(node) def create_job_manager(args: JobArgs, speed_monitor) -> LocalJobManager: diff --git a/dlrover/python/master/node/ps.py b/dlrover/python/master/node/ps.py index a796bea32..1efecdd71 100644 --- a/dlrover/python/master/node/ps.py +++ b/dlrover/python/master/node/ps.py @@ -21,9 +21,11 @@ from dlrover.python.common.global_context import Context from dlrover.python.common.log import default_logger as logger from 
dlrover.python.common.node import Node, NodeGroupResource, NodeResource +from dlrover.python.master.node.job_context import update_job_node from dlrover.python.master.node.training_node import TrainingNodeManager from dlrover.python.master.resource.job import JobResource from dlrover.python.master.scaler.base_scaler import ScalePlan +from datetime import datetime _dlrover_ctx = Context.singleton_instance() @@ -31,7 +33,6 @@ class ParameterServerManager(TrainingNodeManager): def __init__( self, - ps_nodes: Dict[int, Node], job_resource: JobResource, max_relaunch_num, new_service_fn, @@ -39,8 +40,6 @@ def __init__( ): """ Args: - ps_nodes: A dictionary where the key is the index of PS pod - and the value is the PodInfo instance of PS pod. job_resource: the resource configuration of a job. max_relaunch_num: The maximum relaunch number of PS. new_service_fn: A callable function to generate a server name of @@ -49,7 +48,7 @@ def __init__( PS. """ super(ParameterServerManager, self).__init__( - ps_nodes, new_node_name_fn + NodeType.PS, new_node_name_fn ) self._max_relaunch_num = max_relaunch_num self._job_resource = job_resource @@ -58,13 +57,15 @@ def __init__( self._lock = threading.Lock() self._ps_cluster_changed = True self._migrated_ps_nodes: Dict[int, Node] = {} + self._updated_ps_nodes = False self._next_training_ps_cluster: List[Node] = [] self._training_ps_cluster: List[Node] = [] self._node_id_iter = itertools.count(self._job_resource.ps_num) self._init_training_ps_cluster() def _init_training_ps_cluster(self): - for node in self._nodes.values(): + nodes = self._job_context.ps_nodes + for node in nodes.values(): alive = node.status in [ NodeStatus.INITIAL, NodeStatus.PENDING, @@ -84,10 +85,11 @@ def relaunch_node(self, node: Node, remove_exited_node=False): with self._lock: node.is_released = True new_id = next(self._node_id_iter) - self._nodes[new_id] = node.get_relaunch_node_info(new_id) + new_node = node.get_relaunch_node_info(new_id) + update_job_node(new_node) if node in self._training_ps_cluster: i = self._training_ps_cluster.index(node) - self._training_ps_cluster[i] = self._nodes[new_id] + self._training_ps_cluster[i] = new_node logger.info("Relaunch node %s to %s", node.name, new_id) plan.launch_nodes.append( Node( @@ -146,7 +148,7 @@ def _scale_up_ps(self, up_num): critical=True, service_addr=service_addr, ) - self._nodes[ps_id] = ps + update_job_node(ps) new_ps.append(ps) logger.info("Create PS %s", ps) return new_ps @@ -183,6 +185,7 @@ def process_after_ps_cluster_ready(self): node.critical = False node.relaunchable = False node.is_released = True + update_job_node(node) if node.id in self._migrated_ps_nodes: self._migrated_ps_nodes.pop(node.id) plan.remove_nodes.append(node) @@ -191,7 +194,8 @@ def process_after_ps_cluster_ready(self): def _get_alive_ps(self) -> List[Node]: """Get all running PS pods""" alive_ps = [] - for node in self._nodes.values(): + nodes = self._job_context.ps_nodes + for node in nodes.values(): if node.status == NodeStatus.RUNNING and not node.is_released: alive_ps.append(node) return alive_ps @@ -205,7 +209,8 @@ def get_next_training_ps_cluster(self): return self._next_training_ps_cluster all_new_ps_ready = True - for node in self._nodes.values(): + nodes = self._job_context.ps_nodes + for node in nodes.values(): if self._wait_ps_node(node): all_new_ps_ready = False break @@ -226,7 +231,8 @@ def has_ps_failure(self): Check whether there is PS failure and the master does not relaunch the failed PS node. 
""" - for node in self._nodes.values(): + nodes = self._job_context.ps_nodes + for node in nodes.values(): if node.timeout(_dlrover_ctx.seconds_to_wait_failed_ps): return True return False @@ -257,6 +263,7 @@ def _pre_drop_migrated_ps(self, alive_ps: List[Node]): ): if node not in self._pre_dropped_ps: node.migrated = True + update_job_node(node) self._pre_dropped_ps.append(node) def get_total_request_cpu(self): @@ -282,7 +289,8 @@ def get_ready_for_new_ps_cluster(self): def get_ps_addrs(self): """Get the address list of ps services""" ps_addrs = {} - for ps in list(self._nodes.values()): + nodes = self._job_context.ps_nodes + for ps in list(nodes.values()): if ( ps.id not in self._migrated_ps_nodes and not ps.is_released @@ -297,7 +305,8 @@ def get_ps_addrs(self): def delete_running_ps(self): """Delete all running ps pods""" plan = ScalePlan() - for node in list(self._nodes.values()): + nodes = self._job_context.ps_nodes + for node in list(nodes.values()): if ( node.status in [NodeStatus.RUNNING, NodeStatus.PENDING] and not node.is_released @@ -311,6 +320,8 @@ def delete_running_ps(self): ) node.is_released = True node.status = NodeStatus.DELETED + update_job_node(node) + plan.remove_nodes.append(node) return plan @@ -327,9 +338,13 @@ def migrate_parameter_servers(self, ps_nodes: Dict[str, NodeResource]): def _migrate_parameter_server(self, name: str, cpu=0, memory=0): """Migrate the parameter server node into a new pod""" old_ps_id = int(name.split("-")[-1]) - original_pod = self._nodes[old_ps_id] if old_ps_id in self._migrated_ps_nodes: return + nodes = self._job_context.ps_nodes + if old_ps_id not in nodes: + logger.error(f"not found PS-{old_ps_id} in job") + return + original_pod = nodes[old_ps_id] resource = copy.deepcopy(original_pod.config_resource) with self._lock: @@ -352,7 +367,7 @@ def _migrate_parameter_server(self, name: str, cpu=0, memory=0): service_addr=service_addr, name=self._new_node_name_fn(NodeType.PS, new_ps_id), ) - self._nodes[new_ps_id] = new_node + update_job_node(new_node) self._migrated_ps_nodes[old_ps_id] = new_node logger.info("Migrated PS %s to PS %s", old_ps_id, new_ps_id) return new_node @@ -361,9 +376,10 @@ def exist_migrated_ps_nodes(self): return len(self._migrated_ps_nodes) > 0 def is_all_running(self): + nodes = self._job_context.job_nodes_by_type(self._node_type) running_ps = [ pod_info.id - for pod_info in self._nodes.values() + for pod_info in nodes.values() if pod_info.status == NodeStatus.RUNNING ] return len(running_ps) == self._job_resource.ps_num diff --git a/dlrover/python/master/node/training_node.py b/dlrover/python/master/node/training_node.py index d719763be..d83a3fa7d 100644 --- a/dlrover/python/master/node/training_node.py +++ b/dlrover/python/master/node/training_node.py @@ -33,6 +33,7 @@ from dlrover.python.common.global_context import Context from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node +from dlrover.python.master.node.job_context import get_job_context, update_job_node from dlrover.python.master.scaler.base_scaler import ScalePlan from dlrover.python.scheduler.job import JobArgs @@ -172,19 +173,21 @@ def reduce_timeout_pending_node_resource(node: Node): class TrainingNodeManager(object): def __init__( self, - nodes: Dict[int, Node], + node_type: str, new_node_name_fn, ): """ Args: - nodes: training nodes + node_type: node type new_node_name_fn: new node name function """ - self._nodes = nodes + self._job_context = get_job_context() + nodes = 
self._job_context.job_nodes_by_type(node_type) + self._node_type = node_type self._new_node_name_fn = new_node_name_fn self._lock = threading.Lock() - self._node_id_iter = itertools.count(len(self._nodes)) - self._node_rank_iter = itertools.count(len(self._nodes)) + self._node_id_iter = itertools.count(len(nodes)) + self._node_rank_iter = itertools.count(len(nodes)) self._pending_nodes: List[Node] = [] @property @@ -202,34 +205,38 @@ def first_pending_node(self): @property def cur_nodes(self): - cur_nodes = [node.name for node in self._nodes.values()] + nodes = self._job_context.job_nodes_by_type(self._node_type) + cur_nodes = [node.name for node in nodes.values()] return cur_nodes - def update_nodes(self, nodes): - self._nodes = nodes - self._node_id_iter = itertools.count(len(self._nodes)) - self._node_rank_iter = itertools.count(len(self._nodes)) + def update_nodes(self): + nodes = self._job_context.job_nodes_by_type(self._node_type) + self._node_id_iter = itertools.count(len(nodes)) + self._node_rank_iter = itertools.count(len(nodes)) def remove_node(self, node_id): plan = ScalePlan() - if node_id not in self._nodes: - logger.info("Delete non-existed worker %s", node_id) - return plan - worker = self._nodes[node_id] with self._lock: + worker = self._job_context.job_node(self._node_type, node_id) + if worker is None: + logger.info("Delete non-existed worker %s", node_id) + return plan if worker.status in [NodeStatus.DELETED, NodeStatus.INITIAL]: logger.error("Unknown deletable worker id: %s" % node_id) return worker.is_released = True + update_job_node(worker) plan.remove_nodes.append(worker) return plan def relaunch_node(self, node: Node, remove_exited_node=False): plan = ScalePlan() + nodes = self._job_context.job_nodes_by_type(self._node_type) with self._lock: new_id = next(self._node_id_iter) relaunch_node = node.get_relaunch_node_info(new_id) - self._nodes[new_id] = relaunch_node + nodes[new_id] = relaunch_node + update_job_node(relaunch_node) logger.info("Relaunch node %s to %s", node.name, new_id) plan.launch_nodes.append( Node( @@ -244,6 +251,7 @@ def relaunch_node(self, node: Node, remove_exited_node=False): ) if remove_exited_node and not node.is_released and node.exited(): node.is_released = True + update_job_node(node) plan.remove_nodes.append(node) return plan @@ -252,19 +260,22 @@ def reduce_pending_node_resource(self): plan = ScalePlan() # Avoid dictionary changed size during iteration. 
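The same cycle drives the refactored TrainingNodeManager methods that follow: take a copied snapshot of the managed node type, iterate the copy so the underlying mapping cannot change size mid-loop, and publish changes back. A condensed sketch, with the WORKER type and the PENDING check chosen purely for illustration:

from dlrover.python.common.constants import NodeStatus, NodeType
from dlrover.python.master.node.job_context import (
    get_job_context,
    update_job_node,
)

job_context = get_job_context()
nodes = job_context.job_nodes_by_type(NodeType.WORKER)  # deep-copied dict
for node in list(nodes.values()):
    if node.status == NodeStatus.PENDING:
        node.relaunchable = False
        update_job_node(node)  # push the change back to the shared context
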
- cur_nodes = list(self._nodes.values()) + nodes = self._job_context.job_nodes_by_type(self._node_type) + cur_nodes = list(nodes.values()) for node in cur_nodes: if node.status == NodeStatus.PENDING: reduced = reduce_timeout_pending_node_resource(node) if reduced: node.relaunchable = False + update_job_node(node) node_plan = self.relaunch_node(node) plan.remove_nodes.append(node) plan.merge(node_plan) return plan def get_pending_timeout_oom_recovered_node(self): - cur_nodes = list(self._nodes.values()) + nodes = self._job_context.job_nodes_by_type(self._node_type) + cur_nodes = list(nodes.values()) now = time.time() nodes = [] for node in cur_nodes: @@ -288,13 +299,15 @@ def get_running_nodes(self): """TensorFlow Chief nodes""" nodes = [] with self._lock: - for node in self._nodes.values(): + training_nodes = self._job_context.job_nodes_by_type(self._node_type) + for node in training_nodes.values(): if node.status == NodeStatus.RUNNING: nodes.append(node) return nodes def all_nodes_exited(self): - if len(self._nodes) == 0: + nodes = self._job_context.job_nodes_by_type(self._node_type) + if len(nodes) == 0: return True counter = self._get_node_counter() @@ -307,7 +320,7 @@ def all_nodes_exited(self): running_workers = [] pending_high_workers = [] pending_low_workers = [] - for worker_id, worker in self._nodes.items(): + for worker_id, worker in nodes.items(): if worker.is_released: continue if worker.config_resource.priority == PriorityClass.LOW: @@ -359,7 +372,8 @@ def all_nodes_failed(self): def running_nodes_hanged(self) -> List[bool]: cur_time = time.time() node_hang = [] - nodes = list(self._nodes.values()) # Avoid dictionary changed size. + nodes_dict = self._job_context.job_nodes_by_type(self._node_type) + nodes = list(nodes_dict.values()) # Avoid dictionary changed size. for node in nodes: if node.status == NodeStatus.RUNNING: timeout = NodeResourceLimit.MAX_HANG_TIMEOUT_SECS @@ -375,12 +389,14 @@ def running_nodes_hanged(self) -> List[bool]: f"{timeout} from {date_time}!!!" ) node.hang = hang + update_job_node(node) node_hang.append(hang) return node_hang def _get_node_counter(self): with self._lock: - return Counter([node.status for node in self._nodes.values()]) + nodes = self._job_context.job_nodes_by_type(self._node_type) + return Counter([node.status for node in nodes.values()]) def update_critical_node(self, critical_node_restarts): """Update critical node by a dict. 
@@ -389,7 +405,8 @@ def update_critical_node(self, critical_node_restarts): and values are the relaunchable number of nodes """ logger.info("Update critical worker {}".format(critical_node_restarts)) - for id, node in self._nodes.items(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for id, node in nodes.items(): if id in critical_node_restarts: node.critical = True node.max_relaunch_count = critical_node_restarts[id] diff --git a/dlrover/python/master/node/worker.py b/dlrover/python/master/node/worker.py index 2654aa2c6..66408ef80 100644 --- a/dlrover/python/master/node/worker.py +++ b/dlrover/python/master/node/worker.py @@ -24,6 +24,7 @@ from dlrover.python.common.global_context import Context from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node, NodeGroupResource, NodeResource +from dlrover.python.master.node.job_context import JobContext, update_job_node from dlrover.python.master.node.training_node import ( ALIVE_STATUS, TrainingNodeManager, @@ -37,7 +38,6 @@ class ChiefManager(TrainingNodeManager): def __init__( self, - chief_nodes: Dict[int, Node], job_resource: JobResource, max_relaunch_num, new_service_fn, @@ -45,8 +45,6 @@ def __init__( ): """ Args: - chief_nodes: A dictionary where the key is the index of - chief and the value is a Node instance. job_resource: the resource configuration of a job. max_relaunch_num: The maximum relaunch number of a chief. new_service_fn: A callable function to generate a server name of @@ -54,7 +52,9 @@ def __init__( new_node_name_fn: A callable function to generate a node name of chief. """ - super(ChiefManager, self).__init__(chief_nodes, new_node_name_fn) + super(ChiefManager, self).__init__( + NodeType.CHIEF, new_node_name_fn + ) self._job_resource = job_resource self._max_relaunch_num = max_relaunch_num self._new_service_fn = new_service_fn @@ -62,7 +62,8 @@ def __init__( def is_chief_running(self): """The chief worker with id=0 is responsible to initialize variables in TensorFlow 1.x PS strategy""" - for node in self._nodes.values(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for node in nodes.values(): if node.status == NodeStatus.RUNNING: return True return False @@ -71,7 +72,6 @@ def is_chief_running(self): class EvaluatorManager(TrainingNodeManager): def __init__( self, - evaluator_nodes: Dict[int, Node], job_resource: JobResource, max_relaunch_num, new_service_fn, @@ -79,8 +79,6 @@ def __init__( ): """ Args: - evaluator_nodes: A dictionary where the key is the index of - evaluator and the value is a Node instance. job_resource: the resource configuration of a job. max_relaunch_num: The maximum relaunch number of an evaluator. new_service_fn: A callable function to generate a server name of @@ -89,7 +87,7 @@ def __init__( evaluator. 
""" super(EvaluatorManager, self).__init__( - evaluator_nodes, new_node_name_fn + NodeType.EVALUATOR, new_node_name_fn ) self._job_resource = job_resource self._max_relaunch_num = max_relaunch_num @@ -98,7 +96,8 @@ def __init__( def is_chief_running(self): """The chief worker with id=0 is responsible to initialize variables in TensorFlow 1.x PS strategy""" - for node in self._nodes.values(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for node in nodes.values(): if node.status == NodeStatus.RUNNING: return True return False @@ -107,7 +106,6 @@ def is_chief_running(self): class WorkerManager(TrainingNodeManager): def __init__( self, - worker_nodes: Dict[int, Node], job_resource: JobResource, max_relaunch_num, new_service_fn, @@ -115,8 +113,6 @@ def __init__( ): """ Args: - worker_nodes: A dictionary where the key is the index of worker - and the value is a Node instance. job_resource: the resource configuration of a job. max_relaunch_num: The maximum relaunch number of worker. new_service_fn: A callable function to generate a server name of @@ -124,7 +120,9 @@ def __init__( new_node_name_fn: A callable function to generate a node name of worker. """ - super(WorkerManager, self).__init__(worker_nodes, new_node_name_fn) + super(WorkerManager, self).__init__( + NodeType.WORKER, new_node_name_fn + ) self._job_resource = job_resource self._max_relaunch_num = max_relaunch_num self._new_service_fn = new_service_fn @@ -143,7 +141,8 @@ def adjust_worker(self, worker_resource: NodeGroupResource): ) ) alive_workers = [] - for worker in self._nodes.values(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for worker in nodes.values(): if worker.status in ALIVE_STATUS: alive_workers.append(worker) alive_num = len(alive_workers) @@ -173,8 +172,8 @@ def _scale_up_workers(self, up_num): config_resource=copy.deepcopy(worker_resource), service_addr=service_addr, ) - self._nodes[worker_id] = new_node - logger.info("Create worker %s", self._nodes[worker_id]) + update_job_node(new_node) + logger.info("Create worker %s", new_node) plan.launch_nodes.append(new_node) return plan @@ -194,8 +193,9 @@ def _scale_down_workers(self, down_num, running_workers: List[Node]): def delete_exited_workers(self): """Delete failed, succeed, finished workers.""" plan = ScalePlan() + nodes = self._job_context.job_nodes_by_type(self._node_type) with self._lock: - for worker in self._nodes.values(): + for worker in nodes.values(): if ( worker.status in [ @@ -211,7 +211,8 @@ def delete_exited_workers(self): def delete_running_workers(self): plan = ScalePlan() - for worker in self._nodes.values(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for worker in nodes.values(): if not worker.critical and worker.status in [ NodeStatus.RUNNING, NodeStatus.PENDING, @@ -227,7 +228,11 @@ def delete_running_workers(self): return plan def remove_noncritical_worker(self, worker_id): - if self._nodes[worker_id].critical: + node = self._job_context.job_node(self._node_type, worker_id) + if node is None: + logger.error(f"not found node[{self._node_type}][{worker_id}]") + return + if node.critical: logger.info("Skip the critical worker %s", worker_id) else: return self.remove_node(worker_id) @@ -235,9 +240,10 @@ def remove_noncritical_worker(self, worker_id): def migrate_workers(self, workers: Dict[str, NodeResource]): """Migrate workers with the new resource""" plan = ScalePlan() + nodes = self._job_context.job_nodes_by_type(self._node_type) for name, resource in workers.items(): old_node_id = 
int(name.split("-")[-1]) - old_node = self._nodes[old_node_id] + old_node = nodes[old_node_id] if old_node.critical: continue old_node.migrated = True @@ -253,7 +259,7 @@ def migrate_workers(self, workers: Dict[str, NodeResource]): rank_index=task_id, name=self._new_node_name_fn(NodeType.WORKER, node_id), ) - self._nodes[node_id] = new_node + update_job_node(new_node) plan.launch_nodes.append(new_node) plan.remove_nodes.append(old_node) return plan @@ -264,17 +270,19 @@ def remove_not_joined_rdzv_workers(self, worker_ranks: List[int]): worker_ranks: The rank of worker which does not join rendezvous. """ plan = ScalePlan() - for node_id, node in self._nodes.items(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for node_id, node in nodes.items(): if node.rank_index in worker_ranks: p = self.remove_node(node.id) - self._nodes[node_id].relaunchable = False + nodes[node_id].relaunchable = False if p: plan.merge(p) return plan def has_exited_worker(self): """Check whether there is exited worker except evicted workers.""" - for worker in self._nodes.values(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for worker in nodes.values(): if ( worker.exit_reason == NodeExitReason.FATAL_ERROR or worker.status == NodeStatus.SUCCEEDED @@ -284,7 +292,8 @@ def has_exited_worker(self): def wait_worker_restart(self): """Check whether there are workers tha have remaining retries.""" - for worker in self._nodes.values(): + nodes = self._job_context.job_nodes_by_type(self._node_type) + for worker in nodes.values(): if ( worker.exit_reason == NodeExitReason.KILLED and worker.relaunch_count < worker.max_relaunch_count @@ -307,11 +316,15 @@ def verify_restarting_training(self, node_id): bool """ restart = False - worker = self._nodes[node_id] + worker = self._job_context.job_node(self._node_type, node_id) + if worker is None: + logger.error(f"not found worker-{node_id}") + return False if not worker.is_released: restart = worker.restart_training # Set False to avoid restart repeatedly. worker.restart_training = False + update_job_node(worker) return restart def _get_pending_timeout(self): @@ -349,7 +362,8 @@ def is_training_hang_by_pending(self, total_node_num) -> bool: if timeout <= 0: return False - cur_nodes = list(self._nodes.values()) + nodes = self._job_context.job_nodes_by_type(self._node_type) + cur_nodes = list(nodes.values()) # collect pending and running nodes pending_nodes: List[Node] = [] @@ -441,7 +455,9 @@ def is_training_hang_by_insufficient_worker(self) -> bool: logger.debug( f"Is training hang by insufficient worker with timeout: {timeout}." 
) - cur_nodes = list(self._nodes.values()) + + nodes = self._job_context.job_nodes_by_type(self._node_type) + cur_nodes = list(nodes.values()) # collect available nodes available_nodes: List[Node] = [] diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index d1163d704..b64708512 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -641,8 +641,12 @@ def _sync_training_ports( port=sync_ports.training_port, newport=sync_ports.next_check_port ) - def _report_heartbeat(self, node_type, node_id, message: grpc.HeartBeat) -> grpc.HeartbeatResponse: - actions = self._job_manager.collect_node_heart_beat(node_type, node_id, message.timestamp) + def _report_heartbeat( + self, node_type, node_id, message: grpc.HeartBeat + ) -> grpc.HeartbeatResponse: + actions = self._job_manager.collect_node_heart_beat( + node_type, node_id, message.timestamp + ) grpc_actions: List[grpc.DiagnosisAction] = [] for action in actions: grpc_action = grpc.DiagnosisAction( diff --git a/dlrover/python/tests/test_common_util.py b/dlrover/python/tests/test_common_util.py index c39eeeeba..895eebbc9 100644 --- a/dlrover/python/tests/test_common_util.py +++ b/dlrover/python/tests/test_common_util.py @@ -19,4 +19,4 @@ class CommonUtilTest(unittest.TestCase): def test_get_dlrover_version(self): self.assertIsNotNone(cu.get_dlrover_version()) - self.assertNotEquals(cu.get_dlrover_version(), "Unknown") + self.assertNotEqual(cu.get_dlrover_version(), "Unknown") diff --git a/dlrover/python/tests/test_diagnosis.py b/dlrover/python/tests/test_diagnosis.py index 42826505f..649ad46d8 100644 --- a/dlrover/python/tests/test_diagnosis.py +++ b/dlrover/python/tests/test_diagnosis.py @@ -16,7 +16,9 @@ from dlrover.python.diagnosis.common.constants import DiagnosisDataType from dlrover.python.diagnosis.common.diagnosis_data import TrainingLog -from dlrover.python.master.diagnosis.diagnosis_manager import DiagnosisDataManager +from dlrover.python.master.diagnosis.diagnosis_manager import ( + DiagnosisDataManager, +) class DiagnosisTest(unittest.TestCase): diff --git a/dlrover/python/tests/test_diagnosis_agent.py b/dlrover/python/tests/test_diagnosis_agent.py index 26849ec60..3f50ebb7b 100644 --- a/dlrover/python/tests/test_diagnosis_agent.py +++ b/dlrover/python/tests/test_diagnosis_agent.py @@ -19,9 +19,15 @@ from dlrover.python.common import env_utils from dlrover.python.common.constants import RendezvousName -from dlrover.python.elastic_agent.common.worker import WorkerContext -from dlrover.python.diagnosis.common.constants import DiagnosisAction +from dlrover.python.diagnosis.common.constants import DiagnosisActionConstants from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric +from dlrover.python.elastic_agent.common.worker_context import ( + update_worker_context, + get_worker_context, +) +from dlrover.python.elastic_agent.config.launch_config import ( + ElasticLaunchConfig, +) from dlrover.python.elastic_agent.diagnosis.diagnosis_agent import ( DiagnosisAgent, ) @@ -29,11 +35,8 @@ MasterClient, build_master_client, ) -from dlrover.python.elastic_agent.torch.training import ( - _create_worker_spec, -) +from dlrover.python.elastic_agent.torch.training import _create_worker_spec from dlrover.python.tests.test_utils import start_local_master -from dlrover.python.elastic_agent.config.launch_config import ElasticLaunchConfig class TestDiagnosisAgent(unittest.TestCase): @@ -58,7 +61,11 @@ def test_diagnose_training(self): file_path = os.path.join(path, 
file) errors = "error code is 11111" - agent = DiagnosisAgent.singleton_instance(file_path, errors) + + worker_context = get_worker_context() + agent = DiagnosisAgent.singleton_instance( + file_path, errors, worker_context + ) spec = _create_worker_spec( node_rank=0, @@ -74,29 +81,30 @@ def test_diagnose_training(self): ), failures={}, ) - wc = WorkerContext( + + update_worker_context( worker_spec=spec, remaining_failovers=2, restart_count=3, run_result=run_result, ) - action = agent.diagnose_training_failure(wc) - self.assertEqual(action, DiagnosisAction.RESTART_WORKER) + action = agent.diagnose_training_failure() + self.assertEqual(action.action, DiagnosisActionConstants.RESTART_WORKER) agent._errors = "error code is 507035" - action = agent.diagnose_training_failure(wc) - self.assertEqual(action, DiagnosisAction.RELAUNCH_WORKER) + action = agent.diagnose_training_failure() + self.assertEqual(action.action, DiagnosisActionConstants.RELAUNCH_WORKER) agent._errors = "error code is 11111" - wc.remaining_failovers = 0 - action = agent.diagnose_training_failure(wc) - self.assertEqual(action, DiagnosisAction.RELAUNCH_WORKER) + worker_context.remaining_failovers = 0 + action = agent.diagnose_training_failure() + self.assertEqual(action.action, DiagnosisActionConstants.RELAUNCH_WORKER) agent._errors = " #" - wc.remaining_failovers = 2 - action = agent.diagnose_training_failure(wc) - self.assertEqual(action, DiagnosisAction.RESTART_WORKER) + worker_context.remaining_failovers = 2 + action = agent.diagnose_training_failure() + self.assertEqual(action.action, DiagnosisActionConstants.RESTART_WORKER) def test_worker_training_metric(self): test = WorkerTrainingMetric( diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index 75b723f01..1f5e2ad99 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -36,6 +36,9 @@ RendezvousName, ) from dlrover.python.common.storage import PosixDiskStorage +from dlrover.python.elastic_agent.config.launch_config import ( + ElasticLaunchConfig, +) from dlrover.python.elastic_agent.master_client import ( MasterClient, build_master_client, @@ -58,7 +61,6 @@ node_health_check, ) from dlrover.python.tests.test_utils import start_local_master -from dlrover.python.elastic_agent.config.launch_config import ElasticLaunchConfig class ElasticTrainingAgentTest(unittest.TestCase): diff --git a/dlrover/python/tests/test_job_auto_scaler.py b/dlrover/python/tests/test_job_auto_scaler.py index e8da71378..7d322b033 100644 --- a/dlrover/python/tests/test_job_auto_scaler.py +++ b/dlrover/python/tests/test_job_auto_scaler.py @@ -32,6 +32,7 @@ MockK8sPSJobArgs, mock_k8s_client, ) +from dlrover.python.master.node.job_context import get_job_context, update_job_node _dlrover_context = Context.singleton_instance() @@ -50,7 +51,6 @@ def test_execute_job_optimization_plan(self): auto_scaler = PSTrainingAutoScaler( manager._job_resource, - manager._job_nodes, manager._job_optimizer, manager._speed_monitor, manager._ps_manager, @@ -64,11 +64,28 @@ def test_execute_job_optimization_plan(self): plan.node_resources["test-edljob-worker-0"] = NodeResource(8, 8192) plan.node_resources["test-edljob-worker-1"] = NodeResource(8, 8192) plan.node_resources["test-edljob-ps-1"] = NodeResource(8, 8192) - auto_scaler._ps_manager._nodes[1].status = NodeStatus.RUNNING - auto_scaler._worker_manager._nodes[0].critical = True + + job_context = get_job_context() + ps_nodes = 
job_context.job_nodes_by_type(NodeType.PS) + ps_node = ps_nodes[1] + ps_node.type = NodeType.PS + ps_node.status = NodeStatus.RUNNING + update_job_node(ps_node) + # auto_scaler._ps_manager._nodes[1].status = NodeStatus.RUNNING + worker_nodes = job_context.job_nodes_by_type(NodeType.WORKER) + worker_node = worker_nodes[0] + worker_node.type = NodeType.WORKER + worker_node.critical = True + update_job_node(worker_node) + # auto_scaler._worker_manager._nodes[0].critical = True scale_plan = auto_scaler.execute_job_optimization_plan(plan) - self.assertEqual(len(manager._ps_manager._nodes), 4) - self.assertEqual(len(manager._worker_manager._nodes), 7) + + ps_nodes = job_context.job_nodes_by_type(NodeType.PS) + # self.assertEqual(len(manager._ps_manager._nodes), 4) + self.assertEqual(len(ps_nodes), 4) + worker_nodes = job_context.job_nodes_by_type(NodeType.WORKER) + # self.assertEqual(len(manager._worker_manager._nodes), 7) + self.assertEqual(len(worker_nodes), 7) self.assertEqual(len(scale_plan.remove_nodes), 1) self.assertEqual(len(scale_plan.launch_nodes), 5) remove_node = scale_plan.remove_nodes[0] @@ -113,7 +130,6 @@ def test_reduce_timeout_pending_node_resource(self): auto_scaler = PSTrainingAutoScaler( manager._job_resource, - manager._job_nodes, manager._job_optimizer, manager._speed_monitor, manager._ps_manager, @@ -121,10 +137,16 @@ def test_reduce_timeout_pending_node_resource(self): manager._scaler, ) auto_scaler._autoscaling_started = True - ps0 = manager._ps_manager._nodes[0] + + job_context = get_job_context() + ps_nodes = job_context.job_nodes_by_type(NodeType.PS) + ps0 = ps_nodes[0] + ps0.type = NodeType.PS + # ps0 = manager._ps_manager._nodes[0] ps0.config_resource.cpu = 16 ps0.status = NodeStatus.PENDING ps0.create_time = datetime.now() + timedelta(days=-1) + update_job_node(ps0) plan = auto_scaler._reduce_timeout_pending_node_resource() self.assertEqual( plan.ps_addrs, @@ -146,14 +168,17 @@ def test_execute_job_optimization_plan(self): manager = create_job_manager(params, SpeedMonitor()) manager._init_nodes() - for worker in manager._job_nodes[NodeType.WORKER].values(): + job_context = get_job_context() + worker_nodes = job_context.job_nodes_by_type(NodeType.WORKER) + + for worker in worker_nodes.values(): worker.status = NodeStatus.RUNNING + update_job_node(worker) manager._scaler.scale = mock.MagicMock(return_value=True) auto_scaler = AllreduceTrainingAutoScaler( manager._job_resource, - manager._job_nodes, manager._job_optimizer, manager._speed_monitor, manager._worker_manager, diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py index 712779066..88bf27c95 100644 --- a/dlrover/python/tests/test_job_manager.py +++ b/dlrover/python/tests/test_job_manager.py @@ -71,6 +71,7 @@ mock_k8s_client, new_dataset_splitter, ) +from dlrover.python.master.node.job_context import get_job_context, update_job_node, clear_job_nodes _MOCK_JOB_UUID = "11111" @@ -118,6 +119,9 @@ class DistributedJobManagerTest(unittest.TestCase): def setUp(self) -> None: mock_k8s_client() + def tearDown(self): + clear_job_nodes() + def test_job_resource(self): job = JobResource() job.node_group_resources[NodeType.PS] = NodeGroupResource( @@ -212,15 +216,17 @@ def test_get_critical_worker_index(self): self.assertDictEqual(critical_worker, {}) def test_relaunch_node(self): - print("!!!!!!!!!!!Start test relaunch_node!!!!!!!!!!!!\n") + job_context = get_job_context() params = MockK8sPSJobArgs() params.initilize() manager = create_job_manager(params, SpeedMonitor()) 
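For context on the test changes around here, a minimal sketch of the updated call pattern now that create_job_manager no longer accepts a job_context argument; the MockK8sPSJobArgs/mock_k8s_client helpers are assumed to come from test_utils as in the existing tests:

from dlrover.python.common.constants import NodeType
from dlrover.python.master.monitor.speed_monitor import SpeedMonitor
from dlrover.python.master.node.dist_job_manager import create_job_manager
from dlrover.python.master.node.job_context import get_job_context
from dlrover.python.tests.test_utils import MockK8sPSJobArgs, mock_k8s_client

mock_k8s_client()
params = MockK8sPSJobArgs()
params.initilize()  # method name as spelled in the existing test helpers

# create_job_manager fetches the JobContext singleton internally.
manager = create_job_manager(params, SpeedMonitor())
manager.start()

# Callers read node state through the shared context instead of the
# removed manager._job_nodes attribute.
workers = get_job_context().job_nodes_by_type(NodeType.WORKER)
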
self.assertEqual(manager._ps_relaunch_max_num, 1) manager.start() self.assertEqual(manager._job_args.job_uuid, _MOCK_JOB_UUID) - self.assertEqual(len(manager._job_nodes), 4) - self.assertTrue(manager._job_nodes[NodeType.PS][0].critical) + + job_nodes = job_context.job_nodes() + self.assertEqual(len(job_nodes), 4) + self.assertTrue(job_nodes[NodeType.PS][0].critical) node = Node( node_type=NodeType.WORKER, @@ -242,21 +248,23 @@ def test_relaunch_node(self): manager.update_node_resource_usage( NodeType.WORKER, 0, 0.7, 2048, gpu_stats ) # noqa + job_nodes = job_context.job_nodes() self.assertEqual( - manager._job_nodes[NodeType.WORKER][0].used_resource.cpu, 0.7 + job_nodes[NodeType.WORKER][0].used_resource.cpu, 0.7 ) self.assertEqual( - manager._job_nodes[NodeType.WORKER][0].used_resource.memory, 2048 + job_nodes[NodeType.WORKER][0].used_resource.memory, 2048 ) self.assertEqual( - manager._job_nodes[NodeType.WORKER][0].used_resource.gpu_stats, + job_nodes[NodeType.WORKER][0].used_resource.gpu_stats, gpu_stats, # noqa ) node_event: NodeEvent = NodeEvent(NodeEventType.MODIFIED, node) manager._process_event(node_event) + job_nodes = job_context.job_nodes() self.assertEqual( - manager._job_nodes[NodeType.WORKER][1].status, NodeStatus.RUNNING + job_nodes[NodeType.WORKER][1].status, NodeStatus.RUNNING ) should_relaunch = manager._should_relaunch(node, NODE_STATE_FLOWS[5]) self.assertFalse(should_relaunch) @@ -349,15 +357,18 @@ def test_get_dead_node_event(self): manager.start() ts = int(time.time()) manager.collect_node_heart_beat(NodeType.WORKER, 0, ts) - worker0 = manager._job_nodes[NodeType.WORKER][0] + + job_context = get_job_context() + job_nodes = job_context.job_nodes() + worker0 = job_nodes[NodeType.WORKER][0] self.assertEqual(worker0.heartbeat_time, ts) - for node in manager._job_nodes[NodeType.WORKER].values(): + for node in job_nodes[NodeType.WORKER].values(): node.status = NodeStatus.RUNNING events = manager._get_dead_node_event() self.assertEqual(len(events), 0) for index, node in enumerate( - manager._job_nodes[NodeType.WORKER].values() + job_nodes[NodeType.WORKER].values() ): node.status = NodeStatus.RUNNING now = datetime.now() @@ -368,6 +379,7 @@ def test_get_dead_node_event(self): else: node.create_time = now - timedelta(seconds=1400) node.start_time = now - timedelta(seconds=1200) + update_job_node(node) events = manager._get_dead_node_event() self.assertEqual(len(events), 2) @@ -375,8 +387,9 @@ def test_get_dead_node_event(self): self.assertIsNotNone(nodes_time_info) self.assertEqual(len(nodes_time_info), 3) + job_nodes = job_context.job_nodes() for index, node in enumerate( - manager._job_nodes[NodeType.WORKER].values() + job_nodes[NodeType.WORKER].values() ): node.status = NodeStatus.RUNNING now = datetime.now() @@ -389,6 +402,7 @@ def test_get_dead_node_event(self): node.succeeded = True node.create_time = now - timedelta(seconds=1400) node.start_time = now - timedelta(seconds=1200) + update_job_node(node) events = manager._get_dead_node_event() self.assertEqual(len(events), 1) @@ -403,7 +417,7 @@ def test_relaunch_training_master(self): manager._init_nodes() master = Node(NodeType.MASTER, 0, NodeResource(1, 256)) - manager._job_nodes[NodeType.MASTER][0] = master + update_job_node(master) plan = manager._chief_manager.relaunch_node(master) self.assertEqual(plan.launch_nodes[0].id, 1) @@ -412,8 +426,11 @@ def test_process_list_nodes(self): params.initilize() manager = create_job_manager(params, SpeedMonitor()) manager._init_nodes() - for node in 
manager._job_nodes[NodeType.PS].values(): + job_context = get_job_context() + job_nodes = job_context.job_nodes() + for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.PENDING + update_job_node(node) nodes = [] for i in range(2): node = Node( @@ -425,7 +442,9 @@ def test_process_list_nodes(self): ) nodes.append(node) manager._process_list_nodes(nodes) - ps_ids = list(manager._job_nodes[NodeType.PS].keys()) + + job_nodes = job_context.job_nodes() + ps_ids = list(job_nodes[NodeType.PS].keys()) self.assertListEqual(ps_ids, [0, 1, 2]) def test_create_allreduce_job_manager(self): @@ -440,9 +459,13 @@ def test_create_allreduce_job_manager(self): manager._adjust_worker_for_estimator() manager._init_nodes() manager._init_job_auto_scaler() - self.assertEqual(len(manager._job_nodes[NodeType.WORKER]), 3) + + job_context = get_job_context() + job_nodes = job_context.job_nodes() + self.assertEqual(len(job_nodes[NodeType.WORKER]), 3) manager.start_auto_scaling() - self.assertEqual(len(manager._job_nodes[NodeType.WORKER]), 3) + job_nodes = job_context.job_nodes() + self.assertEqual(len(job_nodes[NodeType.WORKER]), 3) def test_recover_tasks_for_failed_workers(self): ds_name_0 = "test-0" @@ -509,27 +532,41 @@ def test_check_worker_status(self): manager._init_nodes() self.assertFalse(manager.all_workers_exited()) - for worker in manager._job_nodes[NodeType.WORKER].values(): + job_context = get_job_context() + job_nodes = job_context.job_nodes() + + for worker in job_nodes[NodeType.WORKER].values(): worker.status = NodeStatus.FINISHED - for worker in manager._job_nodes[NodeType.CHIEF].values(): + update_job_node(worker) + for worker in job_nodes[NodeType.CHIEF].values(): worker.status = NodeStatus.FINISHED - for worker in manager._job_nodes[NodeType.EVALUATOR].values(): + update_job_node(worker) + for worker in job_nodes[NodeType.EVALUATOR].values(): worker.status = NodeStatus.FINISHED + update_job_node(worker) self.assertTrue(manager.all_workers_exited()) - for worker in manager._job_nodes[NodeType.WORKER].values(): + job_nodes = job_context.job_nodes() + for worker in job_nodes[NodeType.WORKER].values(): worker.status = NodeStatus.FAILED - for worker in manager._job_nodes[NodeType.CHIEF].values(): + update_job_node(worker) + for worker in job_nodes[NodeType.CHIEF].values(): worker.status = NodeStatus.FAILED - for worker in manager._job_nodes[NodeType.EVALUATOR].values(): + update_job_node(worker) + for worker in job_nodes[NodeType.EVALUATOR].values(): worker.status = NodeStatus.FAILED + update_job_node(worker) self.assertTrue(manager.all_workers_failed()) - for worker in manager._job_nodes[NodeType.PS].values(): + job_nodes = job_context.job_nodes() + for worker in job_nodes[NodeType.PS].values(): worker.status = NodeStatus.FINISHED - manager._job_nodes[NodeType.WORKER][0].status = NodeStatus.RUNNING + update_job_node(worker) + job_nodes[NodeType.WORKER][0].status = NodeStatus.RUNNING + update_job_node(job_nodes[NodeType.WORKER][0]) self.assertFalse(manager.all_critical_node_completed()) - manager._job_nodes[NodeType.WORKER][0].status = NodeStatus.FINISHED + job_nodes[NodeType.WORKER][0].status = NodeStatus.FINISHED + update_job_node(job_nodes[NodeType.WORKER][0]) self.assertTrue(manager.all_critical_node_completed()) def test_tf_ps_node_handling(self): @@ -572,10 +609,13 @@ def test_all_running_node_hang(self): hang = manager.all_running_node_hanged() self.assertFalse(hang) - for _, nodes in manager._job_nodes.items(): + job_context = get_job_context() + job_nodes = 
job_context.job_nodes() + for _, nodes in job_nodes.items(): for _, node in nodes.items(): node.start_hang_time = time.time() - 3600 * 4 node.status = NodeStatus.RUNNING + update_job_node(node) manager.update_node_resource_usage(NodeType.WORKER, 0, 0.01, 256) hang = manager.all_running_node_hanged() self.assertTrue(hang) @@ -588,33 +628,43 @@ def test_early_stop_part1(self): params.initilize() manager = create_job_manager(params, SpeedMonitor()) manager._init_nodes() - for node in manager._job_nodes[NodeType.PS].values(): + + job_context = get_job_context() + job_nodes = job_context.job_nodes() + for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.PENDING node.is_recovered_oom = True node.create_time = datetime.now() + update_job_node(node) result, reason, msg = manager.should_early_stop() self.assertFalse(result) self.assertFalse(reason) self.assertFalse(msg) manager._remove_exited_node = True - manager._job_nodes[NodeType.WORKER][0].status = NodeStatus.FAILED + job_nodes = job_context.job_nodes() + job_nodes[NodeType.WORKER][0].status = NodeStatus.FAILED + update_job_node(job_nodes[NodeType.WORKER][0]) manager.clear_exited_nodes() - self.assertTrue(manager._job_nodes[NodeType.WORKER][0].is_released) + job_nodes = job_context.job_nodes() + self.assertTrue(job_nodes[NodeType.WORKER][0].is_released) - for node in manager._job_nodes[NodeType.PS].values(): + for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.PENDING node.create_time = datetime.now() + timedelta(days=-1) node.is_recovered_oom = True + update_job_node(node) result, reason, msg = manager.should_early_stop() self.assertTrue(result) self.assertTrue(reason) self.assertTrue(msg) - for node in manager._job_nodes[NodeType.PS].values(): + job_nodes = job_context.job_nodes() + for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.RUNNING node.create_time = datetime.now() + timedelta(days=-1) node.is_recovered_oom = True + update_job_node(node) result, reason, msg = manager.should_early_stop() self.assertFalse(result) self.assertFalse(reason) @@ -662,7 +712,9 @@ def test_when_node_not_init(self): params = MockK8sPSJobArgs() params.initilize() manager = create_job_manager(params, SpeedMonitor()) - self.assertTrue(not manager._job_nodes) + job_context = get_job_context() + job_nodes = job_context.job_nodes() + self.assertTrue(len(job_nodes) == 0) manager.update_node_resource_usage(NodeType.WORKER, 0, 10, 10240, None) @@ -684,8 +736,10 @@ def test_concurrency_heart_beat_collecting(self): manager = create_job_manager(params, SpeedMonitor()) manager.start() - self.assertEqual(len(manager._job_nodes[NodeType.WORKER]), worker_size) - for i, node in manager._job_nodes[NodeType.WORKER].items(): + job_context = get_job_context() + job_nodes = job_context.job_nodes() + self.assertEqual(len(job_nodes[NodeType.WORKER]), worker_size) + for i, node in job_nodes[NodeType.WORKER].items(): self.assertEqual(node.id, i) self.assertEqual(node.heartbeat_time, 0) futures = [] @@ -701,14 +755,15 @@ def test_concurrency_heart_beat_collecting(self): future.result() self.assertEqual(len(futures), worker_size) - for i, node in manager._job_nodes[NodeType.WORKER].items(): + job_nodes = job_context.job_nodes() + for i, node in job_nodes[NodeType.WORKER].items(): self.assertEqual(node.id, i) self.assertEqual(node.heartbeat_time, i) manager.stop() # test when job manager not init - manager._job_nodes = {} + clear_job_nodes() try: manager.collect_node_heart_beat("worker", 1, 111) except Exception: @@ -730,7 +785,10 @@ 
def test_local_job_manager(self): args.node_args[NodeType.WORKER].group_resource.count = 4 job_manager = LocalJobManager(args, error_monitor=SimpleErrorMonitor()) job_manager.start() - self.assertEqual(len(job_manager._job_nodes[NodeType.WORKER]), 4) + + job_context = get_job_context() + job_nodes = job_context.job_nodes() + self.assertEqual(len(job_nodes[NodeType.WORKER]), 4) gpu_stats: list[GPUStats] = [ GPUStats( index=0, @@ -743,7 +801,8 @@ def test_local_job_manager(self): NodeType.WORKER, 0, 10, 10240, gpu_stats ) - worker = job_manager._job_nodes[NodeType.WORKER][0] + job_nodes = job_context.job_nodes() + worker = job_nodes[NodeType.WORKER][0] self.assertEqual(worker.used_resource.cpu, 10) self.assertEqual(worker.used_resource.memory, 10240) self.assertEqual(worker.used_resource.gpu_stats, gpu_stats) @@ -752,17 +811,20 @@ def test_local_job_manager(self): optimizer_config = OptimizerConfig(1, "test_optimizer", 2) paral_config = ParallelConfig(dataloader_config, optimizer_config) job_manager.update_node_paral_config(NodeType.WORKER, 0, paral_config) - worker = job_manager._job_nodes[NodeType.WORKER][0] + + job_nodes = job_context.job_nodes() + worker = job_nodes[NodeType.WORKER][0] self.assertEqual(worker.paral_config, paral_config) job_manager.handle_training_failure(NodeType.WORKER, 3) try: self.assertFalse( - job_manager._job_nodes[NodeType.WORKER][0].is_succeeded() + job_nodes[NodeType.WORKER][0].is_succeeded() ) job_manager.update_succeeded_node(0, NodeType.WORKER) + job_nodes = job_context.job_nodes() self.assertTrue( - job_manager._job_nodes[NodeType.WORKER][0].is_succeeded() + job_nodes[NodeType.WORKER][0].is_succeeded() ) job_manager.update_succeeded_node(5, NodeType.WORKER) job_manager.update_succeeded_node(0, "unknown") diff --git a/dlrover/python/tests/test_log.py b/dlrover/python/tests/test_log.py index 3d90f69e9..16a5c5666 100644 --- a/dlrover/python/tests/test_log.py +++ b/dlrover/python/tests/test_log.py @@ -25,6 +25,7 @@ def test_default_log_level(self): from dlrover.python.common.log import default_logger as logger from dlrover.python.common.log import get_log_level + os.environ[BasicClass.LOG_LEVEL_ENV] = "INFO" self.assertEqual(get_log_level(), "INFO") logger.info("test123") diff --git a/dlrover/python/tests/test_master.py b/dlrover/python/tests/test_master.py index eef86ef4a..8cc59d66d 100644 --- a/dlrover/python/tests/test_master.py +++ b/dlrover/python/tests/test_master.py @@ -34,6 +34,7 @@ mock_k8s_client, start_local_master, ) +from dlrover.python.master.node.job_context import get_job_context, update_job_node, clear_job_nodes, update_job_nodes _dlrover_context = Context.singleton_instance() @@ -44,23 +45,28 @@ def setUp(self) -> None: params = MockK8sPSJobArgs() params.initilize() self.master = DistributedJobMaster(2222, params) + self.job_context = get_job_context() + + def tearDown(self): + clear_job_nodes() def test_exit_by_workers(self): self.master.job_manager._init_nodes() - job_nodes = self.master.job_manager._job_nodes + job_nodes = self.job_context.job_nodes() for node in job_nodes[NodeType.WORKER].values(): node.status = NodeStatus.FINISHED for node in job_nodes[NodeType.EVALUATOR].values(): node.status = NodeStatus.FINISHED for node in job_nodes[NodeType.CHIEF].values(): node.status = NodeStatus.FINISHED + update_job_nodes(job_nodes) self.master.run() self.assertEqual(self.master._exit_code, 0) self.assertEqual(self.master._exit_reason, JobExitReason.SUCCEEDED) def test_exit_by_tasks(self): self.master.job_manager._init_nodes() - job_nodes = 
self.master.job_manager._job_nodes + job_nodes = self.job_context.job_nodes() for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.FINISHED for node in job_nodes[NodeType.EVALUATOR].values(): @@ -69,6 +75,7 @@ def test_exit_by_tasks(self): node.status = NodeStatus.FINISHED job_nodes[NodeType.WORKER][0].status = NodeStatus.FINISHED + update_job_nodes(job_nodes) splitter = new_dataset_splitter( False, @@ -91,11 +98,14 @@ def test_exit_by_tasks(self): def test_early_stop(self): self.master.job_manager._init_nodes() - job_nodes = self.master.job_manager._job_nodes + job_nodes = self.job_context.job_nodes() + for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.PENDING node.is_recovered_oom = True node.create_time = datetime.now() + timedelta(days=-1) + update_job_nodes(job_nodes) + exit_code = self.master.run() self.master.job_manager.clear_all_nodes() self.assertEqual(exit_code, 1) diff --git a/dlrover/python/tests/test_ps_manager.py b/dlrover/python/tests/test_ps_manager.py index 7a83fbbfc..f74246dd4 100644 --- a/dlrover/python/tests/test_ps_manager.py +++ b/dlrover/python/tests/test_ps_manager.py @@ -20,6 +20,7 @@ from dlrover.python.master.resource.job import JobResource from dlrover.python.scheduler.factory import new_elastic_job from dlrover.python.tests.test_utils import mock_k8s_client +from dlrover.python.master.node.job_context import get_job_context, update_job_nodes, clear_job_nodes, update_job_node class PSManagerTest(unittest.TestCase): @@ -32,20 +33,24 @@ def setUp(self) -> None: self._elastic_job = new_elastic_job( PlatformType.KUBERNETES, "test", "default" ) - self._job_nodes = self._job_resource.init_job_node_meta( + self._job_context = get_job_context() + job_nodes = self._job_resource.init_job_node_meta( 1, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + update_job_nodes(job_nodes) self._ps_manager = ParameterServerManager( - self._job_nodes[NodeType.PS], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + def tearDown(self) -> None: + clear_job_nodes() + def test_get_training_ps_cluster(self): ps_nodes = self._ps_manager.get_training_ps_cluster() self.assertEqual(len(ps_nodes), 2) @@ -59,9 +64,11 @@ def test_get_training_ps_cluster(self): ) def test_cut_pending_ps_cpu(self): - for _, node in self._ps_manager._nodes.items(): + nodes = self._job_context.job_nodes_by_type(NodeType.PS) + for _, node in nodes.items(): node.status = NodeStatus.PENDING node.create_time = datetime.now() + timedelta(days=-1) + update_job_node(node) plan = self._ps_manager.reduce_pending_node_resource() self.assertEqual(len(plan.launch_nodes), 2) @@ -72,8 +79,11 @@ def test_scale_up_ps(self): self._ps_manager._scale_up_ps(2) training_ps = self._ps_manager.get_next_training_ps_cluster() self.assertEqual(len(training_ps), 2) - for node in self._ps_manager._nodes.values(): + + nodes = self._job_context.job_nodes_by_type(NodeType.PS) + for node in nodes.values(): node.status = NodeStatus.RUNNING + update_job_node(node) training_ps = self._ps_manager.get_next_training_ps_cluster() self.assertEqual(len(training_ps), 4) @@ -83,15 +93,17 @@ def test_scale_down_ps(self): self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + update_job_nodes(job_nodes) ps_manager = ParameterServerManager( - job_nodes[NodeType.PS], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - for node in ps_manager._nodes.values(): + nodes = 
self._job_context.job_nodes_by_type(NodeType.PS) + for node in nodes.values(): node.status = NodeStatus.RUNNING + update_job_node(node) ps_manager._scale_down_ps(1) self.assertEqual(len(ps_manager._pre_dropped_ps), 1) self.assertEqual(ps_manager._pre_dropped_ps[0].id, 1) @@ -105,17 +117,20 @@ def test_delete_running_ps(self): self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + update_job_nodes(job_nodes) ps_manager = ParameterServerManager( - job_nodes[NodeType.PS], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - for node in ps_manager._nodes.values(): + nodes = self._job_context.job_nodes_by_type(NodeType.PS) + for node in nodes.values(): node.status = NodeStatus.RUNNING + update_job_node(node) plan = ps_manager.delete_running_ps() + job_nodes = self._job_context.job_nodes() self.assertEqual(len(plan.remove_nodes), 2) self.assertTrue(job_nodes[NodeType.PS][0].is_released) self.assertTrue(job_nodes[NodeType.PS][1].is_released) @@ -126,27 +141,39 @@ def test_migrate_parameter_servers(self): self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + update_job_nodes(job_nodes) ps_manager = ParameterServerManager( - job_nodes[NodeType.PS], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - for node in ps_manager._nodes.values(): + print(f"!!!!!!!!!!!!!migrated_ps_nodes={ps_manager._migrated_ps_nodes}\n") + + nodes = self._job_context.job_nodes_by_type(NodeType.PS) + for node in nodes.values(): node.status = NodeStatus.RUNNING + update_job_node(node) + + job_nodes = self._job_context.job_nodes() node_name = job_nodes[NodeType.PS][0].name nodes = {node_name: NodeResource(20, 2048)} plan = ps_manager.migrate_parameter_servers(nodes) self.assertEqual(len(plan.launch_nodes), 1) self.assertEqual(ps_manager._migrated_ps_nodes[0].id, 2) self.assertTrue(ps_manager.exist_migrated_ps_nodes()) + print(f"!!!!!!!!!!!!!migrated_ps_nodes={ps_manager._migrated_ps_nodes}\n") - ps_manager._pre_drop_migrated_ps(list(ps_manager._nodes.values())) + nodes = self._job_context.job_nodes_by_type(NodeType.PS) + ps_manager._pre_drop_migrated_ps(list(nodes.values())) self.assertEqual(len(ps_manager._pre_dropped_ps), 0) - for node in ps_manager._nodes.values(): + for node in nodes.values(): node.status = NodeStatus.RUNNING - ps_manager._pre_drop_migrated_ps(list(ps_manager._nodes.values())) + update_job_node(node) + nodes = self._job_context.job_nodes_by_type(NodeType.PS) + print(f"!!!!!!!!!!!!!nodes={nodes}\n") + print(f"!!!!!!!!!!!!!migrated_ps_nodes={ps_manager._migrated_ps_nodes}\n") + ps_manager._pre_drop_migrated_ps(list(nodes.values())) self.assertEqual(len(ps_manager._pre_dropped_ps), 1) training_ps = ps_manager.get_next_training_ps_cluster() @@ -161,22 +188,25 @@ def test_parameter_server_failure(self): self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + update_job_nodes(job_nodes) ps_manager = ParameterServerManager( - job_nodes[NodeType.PS], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - for node in ps_manager._nodes.values(): + nodes = self._job_context.job_nodes_by_type(NodeType.PS) + for node in nodes.values(): node.status = NodeStatus.RUNNING + update_job_node(node) ps_failure = ps_manager.has_ps_failure() self.assertFalse(ps_failure) - latest_ps_index = len(ps_manager._nodes) - 1 - ps = ps_manager._nodes[latest_ps_index] + latest_ps_index = len(nodes) - 1 + ps = 
nodes[latest_ps_index] ps_manager._ps_cluster_changed = True ps.status = NodeStatus.INITIAL ps.init_time -= 600 + update_job_node(ps) ps_failure = ps_manager.has_ps_failure() self.assertTrue(ps_failure) cluster = ps_manager.get_next_training_ps_cluster() diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py index e5eba00fa..82bf542f5 100644 --- a/dlrover/python/tests/test_servicer.py +++ b/dlrover/python/tests/test_servicer.py @@ -36,6 +36,7 @@ from dlrover.python.master.elastic_training.sync_service import SyncService from dlrover.python.master.monitor.speed_monitor import SpeedMonitor from dlrover.python.master.node.dist_job_manager import create_job_manager +from dlrover.python.master.node.job_context import JobContext from dlrover.python.master.servicer import MasterServicer from dlrover.python.master.shard.task_manager import TaskManager from dlrover.python.master.stats.job_collector import JobMetricCollector @@ -59,7 +60,12 @@ def setUp(self) -> None: worker_resource.node_resource.gpu_type = "a100" speed_monitor = SpeedMonitor() self.task_manager = TaskManager(False, speed_monitor) - self.job_manager = create_job_manager(params, speed_monitor) + + job_context = JobContext() + self.job_manager = create_job_manager( + params, speed_monitor, job_context + ) + self.job_manager._init_nodes() self.job_manager._init_job_auto_scaler() for node in self.job_manager._job_nodes[NodeType.WORKER].values(): diff --git a/dlrover/python/util/time_util.py b/dlrover/python/util/time_util.py new file mode 100644 index 000000000..79194ea30 --- /dev/null +++ b/dlrover/python/util/time_util.py @@ -0,0 +1,20 @@ +# Copyright 2024 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
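The test diffs above all follow the same lifecycle: setUp() seeds the shared JobContext singleton through update_job_nodes() instead of handing node dicts to the manager constructors, and tearDown() calls clear_job_nodes() so state cannot leak between cases. A minimal sketch of that pattern, assuming a hypothetical build_job_nodes() helper that returns the usual Dict[str, Dict[int, Node]] mapping:

import unittest

from dlrover.python.common.constants import NodeType
from dlrover.python.master.node.job_context import (
    clear_job_nodes,
    get_job_context,
    update_job_nodes,
)


class NodeManagerLifecycleTest(unittest.TestCase):
    def setUp(self) -> None:
        # Seed the shared context; managers now read nodes from it
        # instead of receiving them as constructor arguments.
        self.job_context = get_job_context()
        job_nodes = build_job_nodes()  # hypothetical helper, e.g. built
        # from JobResource.init_job_node_meta() as the real tests do.
        update_job_nodes(job_nodes)

    def tearDown(self) -> None:
        # Reset the singleton so no state leaks into the next test.
        clear_job_nodes()

    def test_workers_registered(self):
        workers = self.job_context.job_nodes_by_type(NodeType.WORKER)
        self.assertGreater(len(workers), 0)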
+ +from datetime import datetime, timedelta + + +def has_expired(timestamp: float, time_period: int) -> bool: + dt = datetime.fromtimestamp(timestamp) + expired_dt = dt + timedelta(seconds=time_period) + return expired_dt < datetime.now() diff --git a/dlrover/trainer/torch/elastic_run.py b/dlrover/trainer/torch/elastic_run.py index d656e5f88..47926f492 100644 --- a/dlrover/trainer/torch/elastic_run.py +++ b/dlrover/trainer/torch/elastic_run.py @@ -114,12 +114,12 @@ TrainingExceptionLevel, ) from dlrover.python.common.log import default_logger as logger -from dlrover.python.elastic_agent.master_client import MasterClient -from dlrover.python.elastic_agent.torch.training import ( - launch_agent, +from dlrover.python.elastic_agent.config.launch_config import ( + ElasticLaunchConfig, ) +from dlrover.python.elastic_agent.master_client import MasterClient +from dlrover.python.elastic_agent.torch.training import launch_agent from dlrover.trainer.torch.utils import version_less_than_230 -from dlrover.python.elastic_agent.config.launch_config import ElasticLaunchConfig def parse_args(args): From dc0f60afb68396f541aaf770da7d7f43b941e3c0 Mon Sep 17 00:00:00 2001 From: bsang Date: Thu, 31 Oct 2024 15:55:33 +0800 Subject: [PATCH 06/10] update --- dlrover/python/common/node.py | 2 + .../elastic_agent/common/worker_context.py | 2 + dlrover/python/master/node/job_context.py | 30 ++++-- dlrover/python/master/node/ps.py | 2 +- dlrover/python/tests/test_ps_manager.py | 10 +- dlrover/python/tests/test_servicer.py | 38 +++++-- dlrover/python/tests/test_sync_service.py | 17 +++- dlrover/python/tests/test_worker_manager.py | 99 ++++++++++++------- 8 files changed, 131 insertions(+), 69 deletions(-) diff --git a/dlrover/python/common/node.py b/dlrover/python/common/node.py index c27b7fe1b..a3657bf16 100644 --- a/dlrover/python/common/node.py +++ b/dlrover/python/common/node.py @@ -366,4 +366,6 @@ def to_dict(self): return d def update_from_node(self, node): + if self == node: + return self.__dict__.update(node.__dict__) diff --git a/dlrover/python/elastic_agent/common/worker_context.py b/dlrover/python/elastic_agent/common/worker_context.py index 4161668d1..88cbcacee 100644 --- a/dlrover/python/elastic_agent/common/worker_context.py +++ b/dlrover/python/elastic_agent/common/worker_context.py @@ -75,9 +75,11 @@ def next_actions( instance=instance, action_type=action_type ) + def get_worker_context() -> WorkerContext: return WorkerContext.singleton_instance() + def update_worker_context( worker_spec: WorkerSpec = None, remaining_failovers: int = 0, diff --git a/dlrover/python/master/node/job_context.py b/dlrover/python/master/node/job_context.py index 1e2437d2c..c4569a528 100644 --- a/dlrover/python/master/node/job_context.py +++ b/dlrover/python/master/node/job_context.py @@ -37,6 +37,7 @@ def __init__(self): self._action_queue = DiagnosisActionQueue() self._job_nodes: Dict[str, Dict[int, Node]] = {} self._ps_nodes: Dict[int, Node] = {} + self._workers: Dict[int, Node] = {} self._locker = threading.Lock() def enqueue_actions(self, actions): @@ -50,17 +51,27 @@ def next_actions( def _update_job_nodes(self, job_nodes: Dict[str, Dict[int, Node]]): with self._locker: - self._job_nodes = job_nodes + self._job_nodes = copy.deepcopy(job_nodes) if NodeType.PS in self._job_nodes: self._ps_nodes = copy.deepcopy(self._job_nodes[NodeType.PS]) else: self._ps_nodes = {} + if NodeType.WORKER in self._job_nodes: + self._workers = copy.deepcopy(self._job_nodes[NodeType.WORKER]) + else: + self._workers = {} + @property def 
ps_nodes(self) -> Dict[int, Node]: with self._locker: return self._ps_nodes + @property + def workers(self) -> Dict[int, Node]: + with self._locker: + return self._workers + def job_nodes(self) -> Dict[str, Dict[int, Node]]: """ return a copy of job nodes @@ -68,12 +79,6 @@ def job_nodes(self) -> Dict[str, Dict[int, Node]]: with self._locker: return copy.deepcopy(self._job_nodes) - def job_nodes_if_updated(self, last_update_timestamp: float) -> Optional[Dict[str, Dict[int, Node]]]: - with self._locker: - if self._last_update_timestamp > last_update_timestamp: - return self._job_nodes - return None - def job_node(self, node_type: str, node_id: int) -> Optional[Node]: with self._locker: node_type = self._preprocess(node_type) @@ -102,6 +107,7 @@ def _update_job_node(self, node: Node): self._job_nodes = {} if node.type not in self._job_nodes: self._job_nodes[node.type] = {} + self._job_nodes[node.type][node.id] = copy.deepcopy(node) if node.type == NodeType.PS: @@ -110,10 +116,16 @@ def _update_job_node(self, node: Node): else: self._ps_nodes[node.id].update_from_node(node) + if node.type == NodeType.WORKER: + if node.id not in self._workers: + self._workers[node.id] = copy.deepcopy(node) + else: + self._workers[node.id].update_from_node(node) + def _clear_nodes(self): with self._locker: self._job_nodes = {} - self._last_update_timestamp = datetime.now().timestamp() + self._ps_nodes = {} def get_job_context() -> JobContext: @@ -130,7 +142,7 @@ def update_job_node(node: Node): if node is None: return job_context = JobContext.singleton_instance() - job_context._update_job_node(copy.deepcopy(node)) + job_context._update_job_node(node) def clear_job_nodes(): diff --git a/dlrover/python/master/node/ps.py b/dlrover/python/master/node/ps.py index 1efecdd71..b586876b2 100644 --- a/dlrover/python/master/node/ps.py +++ b/dlrover/python/master/node/ps.py @@ -368,7 +368,7 @@ def _migrate_parameter_server(self, name: str, cpu=0, memory=0): name=self._new_node_name_fn(NodeType.PS, new_ps_id), ) update_job_node(new_node) - self._migrated_ps_nodes[old_ps_id] = new_node + self._migrated_ps_nodes[old_ps_id] = self._job_context.ps_nodes[new_node.id] logger.info("Migrated PS %s to PS %s", old_ps_id, new_ps_id) return new_node diff --git a/dlrover/python/tests/test_ps_manager.py b/dlrover/python/tests/test_ps_manager.py index f74246dd4..5f27ffe25 100644 --- a/dlrover/python/tests/test_ps_manager.py +++ b/dlrover/python/tests/test_ps_manager.py @@ -148,9 +148,8 @@ def test_migrate_parameter_servers(self): self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - print(f"!!!!!!!!!!!!!migrated_ps_nodes={ps_manager._migrated_ps_nodes}\n") - nodes = self._job_context.job_nodes_by_type(NodeType.PS) + nodes = self._job_context.ps_nodes for node in nodes.values(): node.status = NodeStatus.RUNNING update_job_node(node) @@ -162,17 +161,14 @@ def test_migrate_parameter_servers(self): self.assertEqual(len(plan.launch_nodes), 1) self.assertEqual(ps_manager._migrated_ps_nodes[0].id, 2) self.assertTrue(ps_manager.exist_migrated_ps_nodes()) - print(f"!!!!!!!!!!!!!migrated_ps_nodes={ps_manager._migrated_ps_nodes}\n") - nodes = self._job_context.job_nodes_by_type(NodeType.PS) + nodes = self._job_context.ps_nodes ps_manager._pre_drop_migrated_ps(list(nodes.values())) self.assertEqual(len(ps_manager._pre_dropped_ps), 0) for node in nodes.values(): node.status = NodeStatus.RUNNING update_job_node(node) - nodes = self._job_context.job_nodes_by_type(NodeType.PS) - print(f"!!!!!!!!!!!!!nodes={nodes}\n") - 
print(f"!!!!!!!!!!!!!migrated_ps_nodes={ps_manager._migrated_ps_nodes}\n") + nodes = self._job_context.ps_nodes ps_manager._pre_drop_migrated_ps(list(nodes.values())) self.assertEqual(len(ps_manager._pre_dropped_ps), 1) diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py index 82bf542f5..1b95327e8 100644 --- a/dlrover/python/tests/test_servicer.py +++ b/dlrover/python/tests/test_servicer.py @@ -36,7 +36,7 @@ from dlrover.python.master.elastic_training.sync_service import SyncService from dlrover.python.master.monitor.speed_monitor import SpeedMonitor from dlrover.python.master.node.dist_job_manager import create_job_manager -from dlrover.python.master.node.job_context import JobContext +from dlrover.python.master.node.job_context import get_job_context, update_job_node, clear_job_nodes from dlrover.python.master.servicer import MasterServicer from dlrover.python.master.shard.task_manager import TaskManager from dlrover.python.master.stats.job_collector import JobMetricCollector @@ -61,15 +61,17 @@ def setUp(self) -> None: speed_monitor = SpeedMonitor() self.task_manager = TaskManager(False, speed_monitor) - job_context = JobContext() self.job_manager = create_job_manager( - params, speed_monitor, job_context + params, speed_monitor ) + self.job_context = get_job_context() self.job_manager._init_nodes() self.job_manager._init_job_auto_scaler() - for node in self.job_manager._job_nodes[NodeType.WORKER].values(): + job_nodes = self.job_context.job_nodes_by_type(NodeType.WORKER) + for node in job_nodes.values(): node.status = NodeStatus.RUNNING + update_job_node(node) self.job_metric_collector = JobMetricCollector( "1", "default", "local", "dlrover" ) @@ -93,6 +95,7 @@ def setUp(self) -> None: def tearDown(self) -> None: os.environ.clear() + clear_job_nodes() def test_query_running_nodes(self): request = elastic_training_pb2.Message() @@ -173,10 +176,15 @@ def test_metric_service(self): reporter._runtime_stats = [] self.assertEqual(reporter._model_info.op_stats.flops, 10000) - worker0 = self.job_manager._job_nodes[NodeType.WORKER][0] + job_nodes = self.job_context.job_nodes() + worker0 = job_nodes[NodeType.WORKER][0] worker0.status = NodeStatus.RUNNING - ps0 = self.job_manager._job_nodes[NodeType.PS][0] + update_job_node(worker0) + + ps0 = job_nodes[NodeType.PS][0] ps0.status = NodeStatus.RUNNING + update_job_node(ps0) + request = grpc.GlobalStep() self.task_manager._speed_monitor.add_running_worker(NodeType.WORKER, 0) self.task_manager._speed_monitor.set_target_worker_num(1) @@ -211,8 +219,10 @@ def test_metric_service(self): def test_query_ps_nodes(self): self.job_manager._init_nodes() - for node in self.job_manager._job_nodes[NodeType.PS].values(): + nodes = self.job_context.job_nodes_by_type(NodeType.PS) + for node in nodes.values(): node.status = NodeStatus.RUNNING + update_job_node(node) res = self.servicer._query_ps_nodes() self.assertEqual(len(res.nodes), 3) self.assertEqual( @@ -401,7 +411,8 @@ def test_report_heartbeat(self): request.node_type = NodeType.WORKER request.node_id = 0 self.servicer.get(request, None) - worker0 = self.servicer._job_manager._job_nodes[NodeType.WORKER][0] + + worker0 = self.job_context.job_node(NodeType.WORKER, 0) self.assertEqual(worker0.heartbeat_time, ts) def test_sync_checkpoint(self): @@ -456,6 +467,10 @@ def setUp(self) -> None: job_metric_collector=self.job_metric_collector, elastic_ps_service=self.elastic_ps_service, ) + self.job_context = get_job_context() + + def tearDown(self) -> None: + clear_job_nodes() 
def test_update_node_addr(self): request = grpc.NodeMeta() @@ -467,11 +482,14 @@ def test_update_node_addr(self): request.addr = "localhost:5001" self.job_manager._init_nodes() self.servicer._update_node_address(request) + node = self.job_context.job_node(task_type, task_id) self.assertEqual( - self.job_manager._job_nodes[task_type][task_id].service_addr, addr + node.service_addr, addr ) - for node in self.job_manager._job_nodes[NodeType.PS].values(): + ps_nodes = self.job_context.job_nodes_by_type(NodeType.PS) + for node in ps_nodes.values(): node.status = NodeStatus.RUNNING + update_job_node(node) res = self.servicer._query_ps_nodes() self.assertEqual(addr, res.nodes[task_id].addr) self.assertEqual("", res.nodes[0].addr) diff --git a/dlrover/python/tests/test_sync_service.py b/dlrover/python/tests/test_sync_service.py index bbc8140a1..e9d2844d3 100644 --- a/dlrover/python/tests/test_sync_service.py +++ b/dlrover/python/tests/test_sync_service.py @@ -18,6 +18,7 @@ from dlrover.python.master.monitor.speed_monitor import SpeedMonitor from dlrover.python.master.node.dist_job_manager import create_job_manager from dlrover.python.tests.test_utils import MockK8sPSJobArgs, mock_k8s_client +from dlrover.python.master.node.job_context import get_job_context, update_job_node, clear_job_nodes class SyncServiceTest(unittest.TestCase): @@ -28,22 +29,30 @@ def setUp(self) -> None: params.initilize() self.job_manager = create_job_manager(params, SpeedMonitor()) self.job_manager._init_nodes() + self.job_context = get_job_context() + + def tearDown(self): + clear_job_nodes() def test_sync(self): sync_service = SyncService(self.job_manager) - for node in self.job_manager._job_nodes[NodeType.CHIEF].values(): + job_nodes = self.job_context.job_nodes() + for node in job_nodes[NodeType.CHIEF].values(): node.status = NodeStatus.RUNNING + update_job_node(node) - for node in self.job_manager._job_nodes[NodeType.WORKER].values(): + for node in job_nodes[NodeType.WORKER].values(): node.status = NodeStatus.RUNNING + update_job_node(node) sync_name = "sync-0" - for node in self.job_manager._job_nodes[NodeType.CHIEF].values(): + job_nodes = self.job_context.job_nodes() + for node in job_nodes[NodeType.CHIEF].values(): sync_service.join_sync(sync_name, node.type, node.id) finished = sync_service.sync_finished(sync_name) self.assertFalse(finished) - for node in self.job_manager._job_nodes[NodeType.WORKER].values(): + for node in job_nodes[NodeType.WORKER].values(): sync_service.join_sync(sync_name, node.type, node.id) finished = sync_service.sync_finished(sync_name) self.assertTrue(finished) diff --git a/dlrover/python/tests/test_worker_manager.py b/dlrover/python/tests/test_worker_manager.py index 2588d6db9..d74679822 100644 --- a/dlrover/python/tests/test_worker_manager.py +++ b/dlrover/python/tests/test_worker_manager.py @@ -28,6 +28,7 @@ from dlrover.python.master.resource.job import JobResource from dlrover.python.scheduler.factory import new_elastic_job from dlrover.python.tests.test_utils import mock_k8s_client +from dlrover.python.master.node.job_context import get_job_context, update_job_node, clear_job_nodes, update_job_nodes _dlrover_ctx = Context.singleton_instance() @@ -42,27 +43,33 @@ def setUp(self) -> None: self._elastic_job = new_elastic_job( PlatformType.KUBERNETES, "test", "default" ) - self._job_nodes = self._job_resource.init_job_node_meta( + + self.job_context = get_job_context() + job_nodes = self._job_resource.init_job_node_meta( 1, self._elastic_job.get_node_service_addr, 
self._elastic_job.get_node_name, ) + update_job_nodes(job_nodes) self._worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) + def tearDown(self) -> None: + clear_job_nodes() + def test_scale_up_workers(self): self._worker_manager._scale_up_workers(3) - self.assertEqual(len(self._worker_manager._nodes), 8) - self.assertEqual(self._worker_manager._nodes[7].id, 7) + workers = self.job_context.workers + self.assertEqual(len(workers), 8) + self.assertEqual(workers[7].id, 7) def test_scale_down_workers(self): - workers = list(self._worker_manager._nodes.values()) + workers = list(self.job_context.workers.values()) self._worker_manager._scale_down_workers(2, workers) released_workers = [] for worker in workers: @@ -71,8 +78,11 @@ def test_scale_down_workers(self): self.assertEqual(len(released_workers), 2) def test_delete_exited_workers(self): - self._worker_manager._nodes[3].status = NodeStatus.FINISHED - self._worker_manager._nodes[4].status = NodeStatus.FAILED + workers = self.job_context.workers + workers[3].status = NodeStatus.FINISHED + update_job_node(workers[3]) + workers[4].status = NodeStatus.FAILED + update_job_node(workers[4]) plan = self._worker_manager.delete_exited_workers() node_names = [node.name for node in plan.remove_nodes] @@ -82,8 +92,9 @@ def test_delete_exited_workers(self): ) def test_delete_running_workers(self): - for node in self._worker_manager._nodes.values(): + for node in self.job_context.workers.values(): node.status = NodeStatus.RUNNING + update_job_node(node) plan = self._worker_manager.delete_running_workers() node_names = [node.name for node in plan.remove_nodes] self.assertListEqual( @@ -99,19 +110,19 @@ def test_delete_running_workers(self): def test_relaunch_node(self): worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - failed_worker = self._job_nodes[NodeType.WORKER][4] + failed_worker = self.job_context.workers[4] failed_worker.status = NodeStatus.FAILED + update_job_node(failed_worker) plan = worker_manager.relaunch_node( failed_worker, remove_exited_node=True ) self.assertEqual(plan.launch_nodes[0].config_resource.cpu, 16) - self.assertEqual(worker_manager._nodes[5].id, 5) + self.assertEqual(self.job_context.workers[5].id, 5) self.assertEqual(plan.remove_nodes[0].config_resource.cpu, 16) def test_relaunch_chief_node(self): @@ -120,71 +131,78 @@ def test_relaunch_chief_node(self): node_id=0, config_resource=NodeResource(cpu=16, memory=10240), ) + job_nodes = { + NodeType.MASTER: {0: tf_master_node}, + } + update_job_nodes(job_nodes) manager = ChiefManager( - {0: tf_master_node}, self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) plan = manager.relaunch_node(tf_master_node) + nodes = self.job_context.job_nodes_by_type(NodeType.CHIEF) self.assertEqual(plan.launch_nodes[0].config_resource.cpu, 16) - self.assertEqual(manager._nodes[1].id, 1) + self.assertEqual(nodes[1].id, 1) def test_reduce_pending_node_resource(self): worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - for node in worker_manager._nodes.values(): + for node in self.job_context.workers.values(): node.status = NodeStatus.PENDING node.create_time = datetime.now() + timedelta(days=-1) + 
update_job_node(node) plan = worker_manager.reduce_pending_node_resource() self.assertEqual(len(plan.launch_nodes), 5) - for node in worker_manager._nodes.values(): + for node in self.job_context.workers.values(): node.config_resource.gpu_num = 1 + update_job_node(node) plan = worker_manager.reduce_pending_node_resource() self.assertTrue(plan.empty()) def test_pending_without_workers(self): worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, self._elastic_job.get_node_name, ) - for node in worker_manager._nodes.values(): + for node in self.job_context.workers.values(): node.status = NodeStatus.FAILED node.exit_reason = NodeExitReason.FATAL_ERROR + update_job_node(node) exited = worker_manager.has_exited_worker() self.assertTrue(exited) - for node in worker_manager._nodes.values(): + for node in self.job_context.workers.values(): node.exit_reason = NodeExitReason.KILLED + update_job_node(node) exited = worker_manager.has_exited_worker() self.assertFalse(exited) - worker_manager._nodes[0].status = NodeStatus.SUCCEEDED + self.job_context.workers[0].status = NodeStatus.SUCCEEDED + update_job_node(self.job_context.workers[0]) exited = worker_manager.has_exited_worker() self.assertTrue(exited) wait = worker_manager.wait_worker_restart() self.assertTrue(wait) - for node in worker_manager._nodes.values(): + for node in self.job_context.workers.values(): node.relaunch_count = node.max_relaunch_count + update_job_node(node) wait = worker_manager.wait_worker_restart() self.assertFalse(wait) def test_verify_restarting_training(self): worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, @@ -192,16 +210,17 @@ def test_verify_restarting_training(self): ) reset = worker_manager.verify_restarting_training(0) self.assertFalse(reset) - worker_manager._nodes[0].restart_training = True + self.job_context.workers[0].restart_training = True + update_job_node(self.job_context.workers[0]) reset = worker_manager.verify_restarting_training(0) self.assertTrue(reset) - worker_manager._nodes[0].is_released = True + self.job_context.workers[0].is_released = True + update_job_node(self.job_context.workers[0]) reset = worker_manager.verify_restarting_training(0) self.assertFalse(reset) def test_is_training_hang_by_pending(self): worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, @@ -234,11 +253,12 @@ def test_is_training_hang_by_pending(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + update_job_node(mock_node) self.assertFalse( worker_manager.is_training_hang_by_pending(worker_num) ) mock_nodes.clear() + clear_job_nodes() # mock with 3 running + 1 pending long time for index in range(4): @@ -255,9 +275,10 @@ def test_is_training_hang_by_pending(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + update_job_node(mock_node) self.assertTrue(worker_manager.is_training_hang_by_pending(worker_num)) mock_nodes.clear() + clear_job_nodes() # mock with 4 running + 1 pending long time worker_num = 5 @@ -275,11 +296,12 @@ def test_is_training_hang_by_pending(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + 
update_job_node(mock_node) self.assertFalse( worker_manager.is_training_hang_by_pending(worker_num) ) mock_nodes.clear() + clear_job_nodes() # mock with 3 running + 1 initial long time worker_num = 4 @@ -297,7 +319,7 @@ def test_is_training_hang_by_pending(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + update_job_node(mock_node) self.assertTrue(worker_manager.is_training_hang_by_pending(worker_num)) # ============================================= @@ -321,7 +343,7 @@ def test_is_training_hang_by_pending(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-10) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + update_job_node(mock_node) self.assertFalse( worker_manager.is_training_hang_by_pending(worker_num) ) @@ -337,7 +359,7 @@ def test_is_training_hang_by_pending(self): ) mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + update_job_node(mock_node) self.assertTrue(worker_manager.is_training_hang_by_pending(worker_num)) # mock with 2 pending long time @@ -352,7 +374,7 @@ def test_is_training_hang_by_pending(self): ) mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + update_job_node(mock_node) self.assertTrue(worker_manager.is_training_hang_by_pending(worker_num)) # mock with 2 pending + 1 running long time @@ -369,7 +391,7 @@ def test_is_training_hang_by_pending(self): mock_node.status = NodeStatus.RUNNING mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + update_job_node(mock_node) self.assertFalse( worker_manager.is_training_hang_by_pending(worker_num) ) @@ -386,14 +408,13 @@ def test_is_training_hang_by_pending(self): ) mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + update_job_node(mock_node) self.assertFalse( worker_manager.is_training_hang_by_pending(worker_num) ) def test_is_training_hang_by_insufficient_worker(self): worker_manager = WorkerManager( - self._job_nodes[NodeType.WORKER], self._job_resource, 3, self._elastic_job.get_node_service_addr, @@ -426,7 +447,7 @@ def test_is_training_hang_by_insufficient_worker(self): if index == 0: mock_node.status = NodeStatus.PENDING mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + update_job_node(mock_node) for _ in range(5): if worker_manager.is_training_hang_by_insufficient_worker(): is_insufficient += 1 @@ -434,6 +455,7 @@ def test_is_training_hang_by_insufficient_worker(self): self.assertEqual(is_insufficient, 0) mock_nodes.clear() is_insufficient = 0 + clear_job_nodes() # mock with 3 running for index in range(3): @@ -445,7 +467,7 @@ def test_is_training_hang_by_insufficient_worker(self): NodeStatus.RUNNING, ) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes + update_job_node(mock_node) for _ in range(5): if worker_manager.is_training_hang_by_insufficient_worker(): is_insufficient += 1 @@ -453,6 +475,7 @@ def test_is_training_hang_by_insufficient_worker(self): self.assertTrue(is_insufficient >= 2) mock_nodes.clear() is_insufficient = 0 + clear_job_nodes() # mock with 3 running + 1 released for index in range(4): @@ -466,8 +489,8 @@ def test_is_training_hang_by_insufficient_worker(self): if index == 0: mock_node.status = NodeStatus.DELETED 
mock_node.is_released = True + update_job_node(mock_node) mock_nodes[index] = mock_node - worker_manager._nodes = mock_nodes for _ in range(5): if worker_manager.is_training_hang_by_insufficient_worker(): is_insufficient += 1 From 9c036f50a5fc0f14f3bf36356687b9e3c60501f2 Mon Sep 17 00:00:00 2001 From: bsang Date: Thu, 31 Oct 2024 16:18:45 +0800 Subject: [PATCH 07/10] update --- dlrover/python/tests/test_elastic_training_agent.py | 1 + dlrover/python/tests/test_pod_scaler.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index 1f5e2ad99..48289ee35 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -112,6 +112,7 @@ def setUp(self) -> None: def tearDown(self): self._master.stop() + os.environ.clear() def test_node_unit(self): node_unit = int(self.rdzv_handler._rdzv_params.get("node_unit", "1")) diff --git a/dlrover/python/tests/test_pod_scaler.py b/dlrover/python/tests/test_pod_scaler.py index efec29bdf..431be9e11 100644 --- a/dlrover/python/tests/test_pod_scaler.py +++ b/dlrover/python/tests/test_pod_scaler.py @@ -40,6 +40,9 @@ def setUp(self) -> None: os.environ["POD_IP"] = "127.0.0.1" mock_k8s_client() + def tearDown(self) -> None: + os.environ.clear() + def test_init_pod_template(self): error_monitor = SimpleErrorMonitor() scaler = PodScaler("elasticjob-sample", "default", error_monitor) From 43dce774755821a9fefb1b72032abb9a0d86ea70 Mon Sep 17 00:00:00 2001 From: bsang Date: Thu, 31 Oct 2024 17:38:21 +0800 Subject: [PATCH 08/10] update unit test --- dlrover/python/master/node/job_context.py | 1 + dlrover/python/master/node/ps.py | 2 +- dlrover/python/tests/test_job_manager.py | 60 ++++++++++----------- dlrover/python/tests/test_ps_manager.py | 10 ++-- dlrover/python/tests/test_worker_manager.py | 11 +--- 5 files changed, 37 insertions(+), 47 deletions(-) diff --git a/dlrover/python/master/node/job_context.py b/dlrover/python/master/node/job_context.py index c4569a528..d58f05d7d 100644 --- a/dlrover/python/master/node/job_context.py +++ b/dlrover/python/master/node/job_context.py @@ -126,6 +126,7 @@ def _clear_nodes(self): with self._locker: self._job_nodes = {} self._ps_nodes = {} + self._workers = {} def get_job_context() -> JobContext: diff --git a/dlrover/python/master/node/ps.py b/dlrover/python/master/node/ps.py index 906bbde69..a97edf99d 100644 --- a/dlrover/python/master/node/ps.py +++ b/dlrover/python/master/node/ps.py @@ -424,7 +424,7 @@ def is_training_hang_by_pending(self, total_node_num, job_type) -> bool: return False # collect pending and running nodes - cur_nodes = list(self._nodes.values()) + cur_nodes = list(self._job_context.ps_nodes.values()) pending_ps: List[Node] = [] running_ps: List[Node] = [] for node in cur_nodes: diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py index 2880699da..b2d78b8d4 100644 --- a/dlrover/python/tests/test_job_manager.py +++ b/dlrover/python/tests/test_job_manager.py @@ -75,7 +75,7 @@ mock_k8s_client, new_dataset_splitter, ) -from dlrover.python.master.node.job_context import get_job_context, update_job_node, clear_job_nodes +from dlrover.python.master.node.job_context import get_job_context, update_job_node, clear_job_nodes, update_job_nodes _MOCK_JOB_UUID = "11111" @@ -122,6 +122,7 @@ def test_get_node_state_flow(self): class DistributedJobManagerTest(unittest.TestCase): def setUp(self) -> None: 
mock_k8s_client() + self.job_context = get_job_context() def tearDown(self): clear_job_nodes() @@ -220,7 +221,6 @@ def test_get_critical_worker_index(self): self.assertDictEqual(critical_worker, {}) def test_relaunch_node(self): - job_context = get_job_context() params = MockK8sPSJobArgs() params.initilize() manager = create_job_manager(params, SpeedMonitor()) @@ -228,7 +228,7 @@ def test_relaunch_node(self): manager.start() self.assertEqual(manager._job_args.job_uuid, _MOCK_JOB_UUID) - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() self.assertEqual(len(job_nodes), 4) self.assertTrue(job_nodes[NodeType.PS][0].critical) @@ -252,7 +252,7 @@ def test_relaunch_node(self): manager.update_node_resource_usage( NodeType.WORKER, 0, 0.7, 2048, gpu_stats ) # noqa - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() self.assertEqual( job_nodes[NodeType.WORKER][0].used_resource.cpu, 0.7 ) @@ -266,7 +266,7 @@ def test_relaunch_node(self): node_event: NodeEvent = NodeEvent(NodeEventType.MODIFIED, node) manager._process_event(node_event) - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() self.assertEqual( job_nodes[NodeType.WORKER][1].status, NodeStatus.RUNNING ) @@ -362,15 +362,16 @@ def test_get_dead_node_event(self): ts = int(time.time()) manager.collect_node_heart_beat(NodeType.WORKER, 0, ts) - job_context = get_job_context() - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() worker0 = job_nodes[NodeType.WORKER][0] self.assertEqual(worker0.heartbeat_time, ts) for node in job_nodes[NodeType.WORKER].values(): node.status = NodeStatus.RUNNING + update_job_node(node) events = manager._get_dead_node_event() self.assertEqual(len(events), 0) + job_nodes = self.job_context.job_nodes() for index, node in enumerate( job_nodes[NodeType.WORKER].values() ): @@ -391,7 +392,7 @@ def test_get_dead_node_event(self): self.assertIsNotNone(nodes_time_info) self.assertEqual(len(nodes_time_info), 3) - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() for index, node in enumerate( job_nodes[NodeType.WORKER].values() ): @@ -430,8 +431,7 @@ def test_process_list_nodes(self): params.initilize() manager = create_job_manager(params, SpeedMonitor()) manager._init_nodes() - job_context = get_job_context() - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.PENDING update_job_node(node) @@ -447,7 +447,7 @@ def test_process_list_nodes(self): nodes.append(node) manager._process_list_nodes(nodes) - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() ps_ids = list(job_nodes[NodeType.PS].keys()) self.assertListEqual(ps_ids, [0, 1, 2]) @@ -456,7 +456,7 @@ def test_process_list_nodes_for_empty_case(self, mock_method): params = MockK8sPSJobArgs() params.initilize() manager = create_job_manager(params, SpeedMonitor()) - manager._job_nodes = { + job_nodes = { NodeType.PS: { 0: Node( node_type=NodeType.PS, @@ -476,6 +476,7 @@ def test_process_list_nodes_for_empty_case(self, mock_method): ) }, } + update_job_nodes(job_nodes) manager._process_list_nodes([]) self.assertEqual(mock_method.call_count, 2) @@ -492,11 +493,10 @@ def test_create_allreduce_job_manager(self): manager._init_nodes() manager._init_job_auto_scaler() - job_context = get_job_context() - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() 
self.assertEqual(len(job_nodes[NodeType.WORKER]), 3) manager.start_auto_scaling() - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() self.assertEqual(len(job_nodes[NodeType.WORKER]), 3) def test_recover_tasks_for_failed_workers(self): @@ -564,21 +564,18 @@ def test_check_worker_status(self): manager._init_nodes() self.assertFalse(manager.all_workers_exited()) - job_context = get_job_context() - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() for worker in job_nodes[NodeType.WORKER].values(): worker.status = NodeStatus.FINISHED - update_job_node(worker) for worker in job_nodes[NodeType.CHIEF].values(): worker.status = NodeStatus.FINISHED - update_job_node(worker) for worker in job_nodes[NodeType.EVALUATOR].values(): worker.status = NodeStatus.FINISHED - update_job_node(worker) + update_job_nodes(job_nodes) self.assertTrue(manager.all_workers_exited()) - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() for worker in job_nodes[NodeType.WORKER].values(): worker.status = NodeStatus.FAILED update_job_node(worker) @@ -590,7 +587,7 @@ def test_check_worker_status(self): update_job_node(worker) self.assertTrue(manager.all_workers_failed()) - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() for worker in job_nodes[NodeType.PS].values(): worker.status = NodeStatus.FINISHED update_job_node(worker) @@ -641,8 +638,7 @@ def test_all_running_node_hang(self): hang = manager.all_running_node_hanged() self.assertFalse(hang) - job_context = get_job_context() - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() for _, nodes in job_nodes.items(): for _, node in nodes.items(): node.start_hang_time = time.time() - 3600 * 4 @@ -661,8 +657,7 @@ def test_early_stop_part1(self): manager = create_job_manager(params, SpeedMonitor()) manager._init_nodes() - job_context = get_job_context() - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.PENDING node.is_recovered_oom = True @@ -674,11 +669,11 @@ def test_early_stop_part1(self): self.assertFalse(msg) manager._remove_exited_node = True - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() job_nodes[NodeType.WORKER][0].status = NodeStatus.FAILED update_job_node(job_nodes[NodeType.WORKER][0]) manager.clear_exited_nodes() - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() self.assertTrue(job_nodes[NodeType.WORKER][0].is_released) for node in job_nodes[NodeType.PS].values(): @@ -691,7 +686,7 @@ def test_early_stop_part1(self): self.assertTrue(reason) self.assertTrue(msg) - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() for node in job_nodes[NodeType.PS].values(): node.status = NodeStatus.RUNNING node.create_time = datetime.now() + timedelta(days=-1) @@ -797,8 +792,7 @@ def test_concurrency_heart_beat_collecting(self): manager = create_job_manager(params, SpeedMonitor()) manager.start() - job_context = get_job_context() - job_nodes = job_context.job_nodes() + job_nodes = self.job_context.job_nodes() self.assertEqual(len(job_nodes[NodeType.WORKER]), worker_size) for i, node in job_nodes[NodeType.WORKER].items(): self.assertEqual(node.id, i) @@ -816,15 +810,15 @@ def test_concurrency_heart_beat_collecting(self): future.result() self.assertEqual(len(futures), worker_size) - job_nodes = job_context.job_nodes() + job_nodes = 
self.job_context.job_nodes() for i, node in job_nodes[NodeType.WORKER].items(): self.assertEqual(node.id, i) self.assertEqual(node.heartbeat_time, i) manager.stop() + clear_job_nodes() # test when job manager not init - clear_job_nodes() try: manager.collect_node_heart_beat("worker", 1, 111) except Exception: diff --git a/dlrover/python/tests/test_ps_manager.py b/dlrover/python/tests/test_ps_manager.py index 0a053df37..dea0e6c15 100644 --- a/dlrover/python/tests/test_ps_manager.py +++ b/dlrover/python/tests/test_ps_manager.py @@ -223,7 +223,6 @@ def test_parameter_server_failure(self): def test_is_training_hang_by_pending_ps(self): _dlrover_ctx.pending_fail_strategy = 1 ps_manager = ParameterServerManager( - self._job_nodes[NodeType.PS], self._job_resource, 3, self._elastic_job.get_node_service_addr, @@ -260,7 +259,7 @@ def test_is_training_hang_by_pending_ps(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - ps_manager._nodes = mock_nodes + update_job_node(mock_node) self.assertFalse( ps_manager.is_training_hang_by_pending( ps_num, DistributionStrategy.ALLREDUCE @@ -272,6 +271,7 @@ def test_is_training_hang_by_pending_ps(self): ) ) mock_nodes.clear() + clear_job_nodes() # mock with 3 running + 1 pending long time for index in range(4): @@ -288,7 +288,7 @@ def test_is_training_hang_by_pending_ps(self): else: mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node - ps_manager._nodes = mock_nodes + update_job_node(mock_node) self.assertFalse( ps_manager.is_training_hang_by_pending( ps_num, DistributionStrategy.ALLREDUCE @@ -300,6 +300,7 @@ def test_is_training_hang_by_pending_ps(self): ) ) mock_nodes.clear() + clear_job_nodes() # mock with 4 running for index in range(4): @@ -311,7 +312,7 @@ def test_is_training_hang_by_pending_ps(self): NodeStatus.RUNNING, ) mock_nodes[index] = mock_node - ps_manager._nodes = mock_nodes + update_job_node(mock_node) self.assertFalse( ps_manager.is_training_hang_by_pending( ps_num, DistributionStrategy.ALLREDUCE @@ -323,3 +324,4 @@ def test_is_training_hang_by_pending_ps(self): ) ) mock_nodes.clear() + clear_job_nodes() diff --git a/dlrover/python/tests/test_worker_manager.py b/dlrover/python/tests/test_worker_manager.py index 9c716f472..92c0b538d 100644 --- a/dlrover/python/tests/test_worker_manager.py +++ b/dlrover/python/tests/test_worker_manager.py @@ -126,7 +126,6 @@ def test_relaunch_node(self): self.assertEqual(plan.launch_nodes[0].config_resource.cpu, 16) self.assertEqual(self.job_context.workers[5].id, 5) self.assertEqual(plan.launch_nodes[0].max_relaunch_count, 3) - self.assertEqual(worker_manager._nodes[5].id, 5) self.assertEqual(plan.remove_nodes[0].config_resource.cpu, 16) def test_relaunch_chief_node(self): @@ -306,8 +305,7 @@ def test_is_training_hang_by_pending_workers(self): mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node update_job_node(mock_node) - self.assertTrue(worker_manager.is_training_hang_by_pending(worker_num)) - worker_manager._nodes = mock_nodes + self.assertTrue( worker_manager.is_training_hang_by_pending( worker_num, DistributionStrategy.ALLREDUCE @@ -368,8 +366,6 @@ def test_is_training_hang_by_pending_workers(self): mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node update_job_node(mock_node) - self.assertTrue(worker_manager.is_training_hang_by_pending(worker_num)) - worker_manager._nodes = mock_nodes self.assertTrue( 
worker_manager.is_training_hang_by_pending( worker_num, DistributionStrategy.ALLREDUCE @@ -426,8 +422,7 @@ def test_is_training_hang_by_pending_workers(self): mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node update_job_node(mock_node) - self.assertTrue(worker_manager.is_training_hang_by_pending(worker_num)) - worker_manager._nodes = mock_nodes + self.assertTrue( worker_manager.is_training_hang_by_pending( worker_num, DistributionStrategy.ALLREDUCE @@ -452,8 +447,6 @@ def test_is_training_hang_by_pending_workers(self): mock_node.create_time = datetime.now() + timedelta(minutes=-20) mock_nodes[index] = mock_node update_job_node(mock_node) - self.assertTrue(worker_manager.is_training_hang_by_pending(worker_num)) - worker_manager._nodes = mock_nodes self.assertTrue( worker_manager.is_training_hang_by_pending( worker_num, DistributionStrategy.ALLREDUCE From 82dd34363c475e9ef7a7492a364ebef58c6e2471 Mon Sep 17 00:00:00 2001 From: bsang Date: Thu, 31 Oct 2024 21:43:08 +0800 Subject: [PATCH 09/10] fix pre-commit --- .../diagnosis/common/diagnosis_action.py | 29 +++--- .../python/elastic_agent/common/__init__.py | 1 - .../elastic_agent/common/worker_context.py | 22 ++--- .../elastic_agent/config/launch_config.py | 14 +-- .../diagnosis/diagnosis_agent.py | 8 +- dlrover/python/elastic_agent/master_client.py | 22 +++-- .../python/elastic_agent/torch/training.py | 89 +------------------ .../master/diagnosis/diagnosis_manager.py | 12 ++- dlrover/python/master/dist_master.py | 1 - .../python/master/monitor/error_monitor.py | 22 +++-- .../python/master/node/dist_job_manager.py | 9 +- dlrover/python/master/node/job_auto_scaler.py | 8 +- dlrover/python/master/node/job_context.py | 16 ++-- dlrover/python/master/node/job_manager.py | 10 +-- .../python/master/node/local_job_manager.py | 3 +- dlrover/python/master/node/ps.py | 15 +++- dlrover/python/master/node/training_node.py | 9 +- dlrover/python/master/node/worker.py | 10 +-- dlrover/python/master/servicer.py | 6 +- dlrover/python/tests/test_diagnosis_agent.py | 18 ++-- dlrover/python/tests/test_job_auto_scaler.py | 5 +- dlrover/python/tests/test_job_manager.py | 27 +++--- dlrover/python/tests/test_master.py | 6 +- dlrover/python/tests/test_ps_manager.py | 7 +- dlrover/python/tests/test_servicer.py | 14 +-- dlrover/python/tests/test_sync_service.py | 6 +- dlrover/python/tests/test_worker_manager.py | 7 +- 27 files changed, 183 insertions(+), 213 deletions(-) diff --git a/dlrover/python/diagnosis/common/diagnosis_action.py b/dlrover/python/diagnosis/common/diagnosis_action.py index 611feeef0..c0eb791c6 100644 --- a/dlrover/python/diagnosis/common/diagnosis_action.py +++ b/dlrover/python/diagnosis/common/diagnosis_action.py @@ -10,10 +10,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import json import threading from abc import ABCMeta from datetime import datetime -from typing import List, Dict +from typing import Dict, List from dlrover.python.common.constants import NodeType from dlrover.python.common.log import default_logger as logger @@ -22,11 +23,16 @@ DiagnosisConstant, ) from dlrover.python.util.time_util import has_expired -import json class DiagnosisAction(metaclass=ABCMeta): - def __init__(self, action_type: str, instance: int, timestamp=0, expired_time_period=0): + def __init__( + self, + action_type: str, + instance: int, + timestamp=0, + expired_time_period=0, + ): self.action_type = action_type self.instance = instance if timestamp == 0: @@ -63,7 +69,10 @@ def __init__( instance=DiagnosisConstant.LOCAL_INSTANCE, ): super().__init__( - DiagnosisActionConstants.TYPE_NODE, instance, timestamp, expired_time_period + DiagnosisActionConstants.TYPE_NODE, + instance, + timestamp, + expired_time_period, ) self.action = action self.node_type = node_type @@ -73,13 +82,6 @@ def update_action(self, action: str): def is_same_action(action1: DiagnosisAction, action2: DiagnosisAction) -> bool: - if ( - action1.action_type == DiagnosisActionConstants.TYPE_NODE - and action2.action_type == DiagnosisActionConstants.TYPE_NODE - and action1.action == action2.action - and action1.rank == action2.rank - ): - return True return False @@ -125,7 +127,10 @@ def next_actions( remain_actions = [] actions = self._actions[instance] for action in actions: - if action_type == DiagnosisActionConstants.TYPE_NODE or action_type == action.action_type: + if ( + action_type == DiagnosisActionConstants.TYPE_NODE + or action_type == action.action_type + ): deque_actions.append(action) else: remain_actions.append(action) diff --git a/dlrover/python/elastic_agent/common/__init__.py b/dlrover/python/elastic_agent/common/__init__.py index 73d1d62c2..0c742f3fa 100644 --- a/dlrover/python/elastic_agent/common/__init__.py +++ b/dlrover/python/elastic_agent/common/__init__.py @@ -10,4 +10,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
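The reworked diagnosis_action module pairs typed actions with a per-instance queue. A usage sketch assuming the constructor keywords shown above and the RESTART_WORKER constant referenced by the agent-side handling; the 60-second expiry is an arbitrary example value, in seconds per the has_expired() helper:

from dlrover.python.common.constants import NodeType
from dlrover.python.diagnosis.common.constants import (
    DiagnosisActionConstants,
    DiagnosisConstant,
)
from dlrover.python.diagnosis.common.diagnosis_action import (
    DiagnosisActionQueue,
    DiagnosisNodeAction,
)

queue = DiagnosisActionQueue()

# Ask the local agent to restart its worker processes.
action = DiagnosisNodeAction(
    action=DiagnosisActionConstants.RESTART_WORKER,
    node_type=NodeType.WORKER,
    instance=DiagnosisConstant.LOCAL_INSTANCE,
    expired_time_period=60,
)
queue.add_action(action)

# Consumers drain the actions queued for their own instance, optionally
# filtered by action type.
pending = queue.next_actions(
    instance=DiagnosisConstant.LOCAL_INSTANCE,
    action_type=DiagnosisActionConstants.TYPE_NODE,
)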
- diff --git a/dlrover/python/elastic_agent/common/worker_context.py b/dlrover/python/elastic_agent/common/worker_context.py index 88cbcacee..b45636d2c 100644 --- a/dlrover/python/elastic_agent/common/worker_context.py +++ b/dlrover/python/elastic_agent/common/worker_context.py @@ -15,14 +15,14 @@ from torch.distributed.elastic.agent.server.api import RunResult, WorkerSpec -from dlrover.python.diagnosis.common.diagnosis_action import ( - DiagnosisAction, - DiagnosisActionQueue, -) from dlrover.python.common.singleton import Singleton from dlrover.python.diagnosis.common.constants import ( - DiagnosisConstant, DiagnosisActionConstants, + DiagnosisConstant, +) +from dlrover.python.diagnosis.common.diagnosis_action import ( + DiagnosisAction, + DiagnosisActionQueue, ) @@ -58,7 +58,7 @@ def _update_context( restart_count: int = 0, run_result: RunResult = None, ): - self._worker_spec: WorkerSpec = worker_spec + self._worker_spec = worker_spec self.remaining_failovers = remaining_failovers self.restart_count = restart_count self._run_result = run_result @@ -67,9 +67,9 @@ def enqueue_diagnose_action(self, action: DiagnosisAction): self._diagnose_action_queue.add_action(action) def next_actions( - self, - instance=DiagnosisConstant.LOCAL_INSTANCE, - action_type=DiagnosisActionConstants.ACTION_TYPE_ANY, + self, + instance=DiagnosisConstant.LOCAL_INSTANCE, + action_type=DiagnosisActionConstants.ACTION_TYPE_ANY, ) -> List[DiagnosisAction]: return self._diagnose_action_queue.next_actions( instance=instance, action_type=action_type @@ -87,4 +87,6 @@ def update_worker_context( run_result: RunResult = None, ): worker_context = get_worker_context() - worker_context._update_context(worker_spec, remaining_failovers, restart_count, run_result) + worker_context._update_context( + worker_spec, remaining_failovers, restart_count, run_result + ) diff --git a/dlrover/python/elastic_agent/config/launch_config.py b/dlrover/python/elastic_agent/config/launch_config.py index bcd53fe76..2f5d0ba01 100644 --- a/dlrover/python/elastic_agent/config/launch_config.py +++ b/dlrover/python/elastic_agent/config/launch_config.py @@ -13,26 +13,16 @@ import os from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Union import torch -from torch.distributed.elastic.multiprocessing import ( - PContext, - SignalException, - Std, -) +from torch.distributed.elastic.multiprocessing import Std from torch.distributed.launcher.api import LaunchConfig from dlrover.python.common.constants import ( Accelerators, AscendConstants, - ConfigPath, - JobConstant, NodeEnv, - NodeErrorMessage, - NodeStatus, - RendezvousName, - TrainingExceptionLevel, ) from dlrover.python.common.log import default_logger as logger diff --git a/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py b/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py index 600736cad..fbc60f7d6 100644 --- a/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py +++ b/dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py @@ -28,7 +28,10 @@ DiagnosisConstant, InferenceConfigKey, ) -from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction, DiagnosisNodeAction +from dlrover.python.diagnosis.common.diagnosis_action import ( + DiagnosisAction, + DiagnosisNodeAction, +) from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric from dlrover.python.diagnosis.common.inference_chain import ( Inference, @@ -190,7 +193,8 @@ def diagnose_training_failure(self) -> 
DiagnosisAction: f"[{self._worker_context.worker_spec.role}] Worker group " f"{self._worker_context.run_result.state.name}, " f"is failure node: {failure_node}, " - f"no attempts({self._worker_context.worker_spec.max_restarts}) " + f"no attempts(" + f"{self._worker_context.worker_spec.max_restarts}) " "left; will relaunch." ) return DiagnosisNodeAction( diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index 6b4961884..9f4622be0 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -11,12 +11,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib import os import socket import threading import time from contextlib import closing -from typing import Dict, List +from typing import Dict, List, Optional from dlrover.proto import elastic_training_pb2, elastic_training_pb2_grpc from dlrover.python.common import env_utils, grpc @@ -86,6 +87,10 @@ def __init__(self, master_addr, node_id, node_type, timeout=5): self._worker_local_process_id = int(os.getenv("LOCAL_RANK", 0)) self._ddp_server_port = self.find_free_port() + self._diagnosis_action_module = importlib.import_module( + "dlrover.python.diagnosis.common.diagnosis_action" + ) + def __del__(self): if self._channel: self._channel.close() @@ -237,12 +242,17 @@ def report_heart_beat(self, timestamp) -> List[DiagnosisAction]: response: grpc.HeartbeatResponse = self._get(message) actions: List[DiagnosisAction] = [] for grpc_action in response.diagnosis_actions: - action = DiagnosisAction( - rank=grpc_action.rank, - timestamp=grpc_action.timestamp, - expired_time_period=grpc_action.expired_time_period, - action=grpc_action.action, + action_cls: Optional[DiagnosisData] = getattr( + self._diagnosis_action_module, + grpc_action.action_cls, ) + if action_cls is None: + logger.warning( + "Invalid diagnosis action " + f"action type: {grpc_action.action_cls}" + ) + continue + action = action_cls.from_json(grpc_action.action_content) actions.append(action) return actions diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index b506072ef..ded25e3d4 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -23,7 +23,6 @@ from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from contextlib import closing -from dataclasses import dataclass from datetime import datetime, timedelta from typing import ( Any, @@ -36,7 +35,6 @@ Union, ) -import torch import torch.distributed.elastic.timer as timer from torch.distributed import PrefixStore, Store from torch.distributed.elastic import events, metrics @@ -55,18 +53,14 @@ ) from torch.distributed.elastic.metrics import put_metric from torch.distributed.elastic.metrics.api import prof -from torch.distributed.elastic.multiprocessing import ( - PContext, - SignalException, - Std, -) +from torch.distributed.elastic.multiprocessing import PContext, SignalException from torch.distributed.elastic.multiprocessing.errors import ( ChildFailedError, ProcessFailure, ) from torch.distributed.elastic.rendezvous import RendezvousParameters from torch.distributed.elastic.rendezvous.api import RendezvousHandler -from torch.distributed.launcher.api import LaunchConfig, _get_entrypoint_name +from torch.distributed.launcher.api import _get_entrypoint_name from dlrover.python.common import env_utils from 
dlrover.python.common.constants import ( @@ -74,7 +68,6 @@ AscendConstants, ConfigPath, JobConstant, - NodeEnv, NodeErrorMessage, NodeStatus, RendezvousName, @@ -89,7 +82,6 @@ from dlrover.python.common.log import default_logger as logger from dlrover.python.diagnosis.common.constants import DiagnosisActionConstants from dlrover.python.diagnosis.common.diagnosis_action import ( - DiagnosisAction, DiagnosisNodeAction, ) from dlrover.python.diagnosis.common.inference_chain import ( @@ -152,81 +144,6 @@ class RendezvousOutSyncError(Exception): pass -@dataclass -class ElasticLaunchConfig(LaunchConfig): - """ - Creates a rendezvous config of elastic training. - - Args: - network_check: whether to check the network available before training. - comm_perf_test: whether to test the communication performance. - node_unit: the number of unit of nodes. The number of nodes must be - a multiple of node_unit. - auto_config: indicate if automatically configure the nnodes and - nproc_per_node. - auto_tunning: whether to auto-tune the parallelism configuration. - exclude_straggler: The node will exit if it is a straggler in network - check and exclude_straggler is True. - save_at_breakpoint: indicate if save the checkpoint from the shared - memory into the disk after a failure occurs. - accelerator: the type of accelerator processor like nvidia.com/gpu, - ascend-npu. - training_log_file: the training log file of this training job - failure_node_errors: the error information that indicate the node - is a failure node - """ - - network_check: bool = False - comm_perf_test: bool = False - node_unit: int = 1 - training_port: int = AscendConstants.HCCL_PORT_START_DEFAULT - auto_config: bool = False - auto_tunning: bool = False - exclude_straggler: bool = False - save_at_breakpoint: bool = False - accelerator: str = "" - log_dir: Optional[str] = None # Keep Compatibility with PyTorch>=2.3.0 - redirects: Union[Std, Dict[int, Std]] = Std.NONE - tee: Union[Std, Dict[int, Std]] = Std.NONE - training_log_file: str = "" - failure_node_errors: str = "" - - def set_node_unit(self, node_unit): - """Set the number unit of nodes.""" - self.node_unit = node_unit - self.rdzv_configs["node_unit"] = node_unit - - def auto_configure_params(self): - self.training_log_file = os.getenv(NodeEnv.TRAINING_LOG_FILE, "") - self.failure_node_errors = os.getenv(NodeEnv.FAILURE_NODE_ERRORS, "") - if len(self.failure_node_errors) > 0: - errors = self.failure_node_errors.strip() - if errors[0] != "#" or errors[-1] != "#": - logger.warning("invalid failure node errors: %s", errors) - self.failure_node_errors = "" - - device = "" - if torch.cuda.is_available(): - device = torch.cuda.get_device_name() - if "Ascend" in device: - self.accelerator = Accelerators.ASCEND_NPU - logger.info( - f"Use {self.accelerator} device for training, " - f"cuda is available: {torch.cuda.is_available()}." - ) - - if not self.auto_config: - return - - if NodeEnv.NODE_NUM in os.environ: - self.min_nodes = int(os.environ[NodeEnv.NODE_NUM]) - self.max_nodes = int(os.environ[NodeEnv.NODE_NUM]) - if torch.cuda.is_available(): - self.nproc_per_node = torch.cuda.device_count() - if self.min_nodes >= 4: - self.network_check = True - - class MasterRendezvousHandler(RendezvousHandler): """The rendezvous handler completes rendezvous by connecting with the ElasticJob master. 
The master will collect all nodes @@ -940,7 +857,7 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: else: raise Exception(f"[{role}] worker group in {state.name} state") - def _process_diagnosis_action(self, action: DiagnosisAction): + def _process_diagnosis_action(self, action): if action.action_type == DiagnosisActionConstants.TYPE_NODE: if action.action == DiagnosisActionConstants.RESTART_WORKER: self._remaining_failovers -= 1 diff --git a/dlrover/python/master/diagnosis/diagnosis_manager.py b/dlrover/python/master/diagnosis/diagnosis_manager.py index 0e05a08f2..b7220579d 100644 --- a/dlrover/python/master/diagnosis/diagnosis_manager.py +++ b/dlrover/python/master/diagnosis/diagnosis_manager.py @@ -17,8 +17,8 @@ from dlrover.python.common.log import default_logger as logger from dlrover.python.diagnosis.common.constants import ( - DiagnosisConstant, DiagnosisActionConstants, + DiagnosisConstant, ) from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData @@ -35,7 +35,7 @@ InferenceChain, combine_inferences, ) -from dlrover.python.diagnosis.inferencechain.inferenceoperator.operator import ( +from dlrover.python.diagnosis.inferencechain.inferenceoperator.operator import ( # noqa: E501 get_master_observe_operators, ) from dlrover.python.master.diagnosis.diagnosis_data_manager import ( @@ -118,9 +118,13 @@ def _diagnose(self): ) def next_actions( - self, instance=DiagnosisConstant.LOCAL_INSTANCE, action_type=DiagnosisActionConstants.ACTION_TYPE_ANY + self, + instance=DiagnosisConstant.LOCAL_INSTANCE, + action_type=DiagnosisActionConstants.ACTION_TYPE_ANY, ) -> List[DiagnosisAction]: - return self._job_context.next_actions(instance=instance, action_type=action_type) + return self._job_context.next_actions( + instance=instance, action_type=action_type + ) class Diagnostician: diff --git a/dlrover/python/master/dist_master.py b/dlrover/python/master/dist_master.py index 1c85ba43c..022a9c0ae 100644 --- a/dlrover/python/master/dist_master.py +++ b/dlrover/python/master/dist_master.py @@ -43,7 +43,6 @@ TaskRescheduleCallback, TFPSNodeHandlingCallback, ) -from dlrover.python.master.node.job_context import JobContext from dlrover.python.master.servicer import create_master_service from dlrover.python.master.shard.task_manager import TaskManager from dlrover.python.master.stats.job_collector import JobMetricCollector diff --git a/dlrover/python/master/monitor/error_monitor.py b/dlrover/python/master/monitor/error_monitor.py index 38411120e..279071498 100644 --- a/dlrover/python/master/monitor/error_monitor.py +++ b/dlrover/python/master/monitor/error_monitor.py @@ -22,7 +22,11 @@ class ErrorMonitor(metaclass=ABCMeta): @abstractmethod def process_error( - self, node: Node, restart_count: int, error_data: str, level: str + self, + node: Node, + restart_count: int, + error_data: str, + level: str, ) -> bool: """ Handle the error of training processes. @@ -34,7 +38,7 @@ def process_error( level: the error level. Returns: - bool: wether to relaunch the node. + bool: whether to relaunch the node.
""" pass @@ -57,7 +61,11 @@ def __init__(self): self._restart_errors: Dict[int, str] = {} def process_error( - self, node: Node, restart_count: int, error_data: str, level: str + self, + node: Node, + restart_count: int, + error_data: str, + level: str, ) -> bool: if level == TrainingExceptionLevel.PROCESS_ERROR: return self._handle_process_error(node, restart_count, error_data) @@ -108,7 +116,11 @@ def __init__(self, namespace="", cordon_node_eanbled=False): self._restart_errors: Dict[int, str] = {} def process_error( - self, node: Node, restart_count: int, error_data: str, level: str + self, + node: Node, + restart_count: int, + error_data: str, + level: str, ) -> bool: if level == TrainingExceptionLevel.PROCESS_ERROR: return self._handle_process_error(node, restart_count, error_data) @@ -151,5 +163,5 @@ def _handle_node_error(self, node: Node, error_data: str): if self.cordon_node_eanbled: succeed = self._k8s_client.cordon_node(node.host_name) if succeed: - logger.info(f"Node {node.name} is marked unschedulable.") + logger.info(f"Node {node.name} is marked as unschedulable.") return True diff --git a/dlrover/python/master/node/dist_job_manager.py b/dlrover/python/master/node/dist_job_manager.py index aab195ffd..2b6cc2286 100644 --- a/dlrover/python/master/node/dist_job_manager.py +++ b/dlrover/python/master/node/dist_job_manager.py @@ -36,9 +36,6 @@ from dlrover.python.common.grpc import ParallelConfig from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node, NodeGroupResource -from dlrover.python.diagnosis.common.constants import ( - DiagnosisConstant, -) from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction from dlrover.python.master.monitor.error_monitor import K8sJobErrorMonitor from dlrover.python.master.node.event_callback import ( @@ -1113,7 +1110,7 @@ def _process_error( error_data: str, level: str, ) -> bool: - if self._error_monitor: + if self._error_monitor and node is not None: return self._error_monitor.process_error( node, restart_count, error_data, level ) @@ -1202,9 +1199,7 @@ def update_succeeded_node(self, node_id, node_type): super().update_succeeded_node(node_id, node_type) -def create_job_manager( - args: JobArgs, speed_monitor -) -> DistributedJobManager: +def create_job_manager(args: JobArgs, speed_monitor) -> DistributedJobManager: critical_worker_index = get_critical_worker_index(args) # Custom distribution strategy does not exit if there are pending nodes wait_pending_relaunch = ( diff --git a/dlrover/python/master/node/job_auto_scaler.py b/dlrover/python/master/node/job_auto_scaler.py index 97a6efd5e..1530ea72f 100644 --- a/dlrover/python/master/node/job_auto_scaler.py +++ b/dlrover/python/master/node/job_auto_scaler.py @@ -23,8 +23,9 @@ ) from dlrover.python.common.global_context import Context from dlrover.python.common.log import default_logger as logger -from dlrover.python.common.node import Node, NodeResource +from dlrover.python.common.node import NodeResource from dlrover.python.master.monitor.speed_monitor import SpeedMonitor +from dlrover.python.master.node.job_context import get_job_context from dlrover.python.master.node.ps import ParameterServerManager from dlrover.python.master.node.worker import WorkerManager from dlrover.python.master.resource.job import ( @@ -33,7 +34,6 @@ ) from dlrover.python.master.resource.optimizer import ResourcePlan from dlrover.python.master.scaler.base_scaler import ScalePlan, Scaler -from dlrover.python.master.node.job_context import get_job_context 
_dlrover_context = Context.singleton_instance() @@ -219,7 +219,9 @@ def execute_job_optimization_plan(self, plan: ResourcePlan): scale_plan.merge(ps_plan) self._speed_monitor.reset_running_speed_monitor() elif node_type == NodeType.WORKER: - chief_nodes = self._job_context.job_nodes_by_type(NodeType.CHIEF) + chief_nodes = self._job_context.job_nodes_by_type( + NodeType.CHIEF + ) chief_num = len(chief_nodes) worker_num = chief_num + group.count self._speed_monitor.set_target_worker_num(worker_num) diff --git a/dlrover/python/master/node/job_context.py b/dlrover/python/master/node/job_context.py index d58f05d7d..2165908ee 100644 --- a/dlrover/python/master/node/job_context.py +++ b/dlrover/python/master/node/job_context.py @@ -13,18 +13,18 @@ import copy import threading -import time from typing import Dict, Optional from dlrover.python.common.constants import NodeType from dlrover.python.common.node import Node from dlrover.python.common.singleton import Singleton -from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisActionQueue from dlrover.python.diagnosis.common.constants import ( - DiagnosisConstant, DiagnosisActionConstants, + DiagnosisConstant, +) +from dlrover.python.diagnosis.common.diagnosis_action import ( + DiagnosisActionQueue, ) -from datetime import datetime class JobContext(Singleton): @@ -45,9 +45,13 @@ def enqueue_actions(self, actions): self._action_queue.add_action(action) def next_actions( - self, instance=DiagnosisConstant.LOCAL_INSTANCE, action_type=DiagnosisActionConstants.ACTION_TYPE_ANY + self, + instance=DiagnosisConstant.LOCAL_INSTANCE, + action_type=DiagnosisActionConstants.ACTION_TYPE_ANY, ): - return self._action_queue.next_actions(instance=instance, action_type=action_type) + return self._action_queue.next_actions( + instance=instance, action_type=action_type + ) def _update_job_nodes(self, job_nodes: Dict[str, Dict[int, Node]]): with self._locker: diff --git a/dlrover/python/master/node/job_manager.py b/dlrover/python/master/node/job_manager.py index f74451b57..43674c884 100644 --- a/dlrover/python/master/node/job_manager.py +++ b/dlrover/python/master/node/job_manager.py @@ -15,13 +15,16 @@ from typing import Dict, List from dlrover.python.common.log import default_logger as logger -from dlrover.python.common.node import Node from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction from dlrover.python.master.hyperparams.simple_strategy_generator import ( SimpleStrategyGenerator, ) from dlrover.python.master.monitor.error_monitor import ErrorMonitor from dlrover.python.master.monitor.speed_monitor import SpeedMonitor +from dlrover.python.master.node.job_context import ( + get_job_context, + update_job_node, +) from dlrover.python.master.node.training_node import ( SyncNodeTrainingPorts, TrainingNodeConfig, @@ -29,10 +32,6 @@ from dlrover.python.master.resource.job import JobResource from dlrover.python.scheduler.job import JobArgs from dlrover.python.scheduler.kubernetes import k8sClient -from dlrover.python.master.node.job_context import ( - get_job_context, - update_job_node, -) class JobManager(metaclass=ABCMeta): @@ -244,4 +243,3 @@ def update_succeeded_node(self, node_id, node_type): logger.info(f"Node {node_id}({node_type}) to succeeded.") node.set_as_succeeded() update_job_node(node) - diff --git a/dlrover/python/master/node/local_job_manager.py b/dlrover/python/master/node/local_job_manager.py index 868d9a1f0..6f7c166b3 100644 --- a/dlrover/python/master/node/local_job_manager.py +++ 
b/dlrover/python/master/node/local_job_manager.py @@ -18,9 +18,9 @@ from dlrover.python.common.node import Node from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction from dlrover.python.master.monitor.error_monitor import SimpleErrorMonitor +from dlrover.python.master.node.job_context import update_job_node from dlrover.python.master.node.job_manager import JobManager from dlrover.python.scheduler.job import JobArgs -from dlrover.python.master.node.job_context import update_job_node class LocalJobManager(JobManager): @@ -52,7 +52,6 @@ def start(self): ) update_job_node(workers[i]) - def should_early_stop(self): return False diff --git a/dlrover/python/master/node/ps.py b/dlrover/python/master/node/ps.py index a97edf99d..448aa9eb6 100644 --- a/dlrover/python/master/node/ps.py +++ b/dlrover/python/master/node/ps.py @@ -26,10 +26,13 @@ from dlrover.python.common.global_context import Context from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node, NodeGroupResource, NodeResource -from dlrover.python.master.node.training_node import TrainingNodeManager, skip_pending_judgement +from dlrover.python.master.node.job_context import update_job_node +from dlrover.python.master.node.training_node import ( + TrainingNodeManager, + skip_pending_judgement, +) from dlrover.python.master.resource.job import JobResource from dlrover.python.master.scaler.base_scaler import ScalePlan -from dlrover.python.master.node.job_context import update_job_node _dlrover_ctx = Context.singleton_instance() @@ -91,7 +94,9 @@ def relaunch_node(self, node: Node, remove_exited_node=False): update_job_node(new_node) if node in self._training_ps_cluster: i = self._training_ps_cluster.index(node) - self._training_ps_cluster[i] = self._job_context.ps_nodes[new_node.id] + self._training_ps_cluster[i] = self._job_context.ps_nodes[ + new_node.id + ] logger.info("Relaunch node %s to %s", node.name, new_id) plan.launch_nodes.append( Node( @@ -367,7 +372,9 @@ def _migrate_parameter_server(self, name: str, cpu=0, memory=0): name=self._new_node_name_fn(NodeType.PS, new_ps_id), ) update_job_node(new_node) - self._migrated_ps_nodes[old_ps_id] = self._job_context.ps_nodes[new_node.id] + self._migrated_ps_nodes[old_ps_id] = self._job_context.ps_nodes[ + new_node.id + ] logger.info("Migrated PS %s to PS %s", old_ps_id, new_ps_id) return new_node diff --git a/dlrover/python/master/node/training_node.py b/dlrover/python/master/node/training_node.py index 043a4fec6..f4a8ab811 100644 --- a/dlrover/python/master/node/training_node.py +++ b/dlrover/python/master/node/training_node.py @@ -33,7 +33,10 @@ from dlrover.python.common.global_context import Context from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node -from dlrover.python.master.node.job_context import get_job_context, update_job_node +from dlrover.python.master.node.job_context import ( + get_job_context, + update_job_node, +) from dlrover.python.master.scaler.base_scaler import ScalePlan from dlrover.python.scheduler.job import JobArgs @@ -312,7 +315,9 @@ def get_running_nodes(self): """TensorFlow Chief nodes""" nodes = [] with self._lock: - training_nodes = self._job_context.job_nodes_by_type(self._node_type) + training_nodes = self._job_context.job_nodes_by_type( + self._node_type + ) for node in training_nodes.values(): if node.status == NodeStatus.RUNNING: nodes.append(node) diff --git a/dlrover/python/master/node/worker.py 
b/dlrover/python/master/node/worker.py index d29a2116c..87168ea1a 100644 --- a/dlrover/python/master/node/worker.py +++ b/dlrover/python/master/node/worker.py @@ -25,7 +25,7 @@ from dlrover.python.common.global_context import Context from dlrover.python.common.log import default_logger as logger from dlrover.python.common.node import Node, NodeGroupResource, NodeResource -from dlrover.python.master.node.job_context import JobContext, update_job_node +from dlrover.python.master.node.job_context import update_job_node from dlrover.python.master.node.training_node import ( ALIVE_STATUS, TrainingNodeManager, @@ -56,9 +56,7 @@ def __init__( new_node_name_fn: A callable function to generate a node name of chief. """ - super(ChiefManager, self).__init__( - NodeType.CHIEF, new_node_name_fn - ) + super(ChiefManager, self).__init__(NodeType.CHIEF, new_node_name_fn) self._job_resource = job_resource self._max_relaunch_num = max_relaunch_num self._new_service_fn = new_service_fn @@ -124,9 +122,7 @@ def __init__( new_node_name_fn: A callable function to generate a node name of worker. """ - super(WorkerManager, self).__init__( - NodeType.WORKER, new_node_name_fn - ) + super(WorkerManager, self).__init__(NodeType.WORKER, new_node_name_fn) self._job_resource = job_resource self._max_relaunch_num = max_relaunch_num self._new_service_fn = new_service_fn diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index b64708512..6a4d9dda8 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -650,10 +650,8 @@ def _report_heartbeat( grpc_actions: List[grpc.DiagnosisAction] = [] for action in actions: grpc_action = grpc.DiagnosisAction( - rank=action.rank, - timestamp=action.timestamp, - action=action.action, - expired_time_period=action.expired_time_period, + action.__class__.__name__, + action.to_json(), ) grpc_actions.append(grpc_action) return grpc.HeartbeatResponse(diagnosis_actions=grpc_actions) diff --git a/dlrover/python/tests/test_diagnosis_agent.py b/dlrover/python/tests/test_diagnosis_agent.py index 3f50ebb7b..6afc137bf 100644 --- a/dlrover/python/tests/test_diagnosis_agent.py +++ b/dlrover/python/tests/test_diagnosis_agent.py @@ -22,8 +22,8 @@ from dlrover.python.diagnosis.common.constants import DiagnosisActionConstants from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric from dlrover.python.elastic_agent.common.worker_context import ( - update_worker_context, get_worker_context, + update_worker_context, ) from dlrover.python.elastic_agent.config.launch_config import ( ElasticLaunchConfig, @@ -90,21 +90,29 @@ def test_diagnose_training(self): ) action = agent.diagnose_training_failure() - self.assertEqual(action.action, DiagnosisActionConstants.RESTART_WORKER) + self.assertEqual( + action.action, DiagnosisActionConstants.RESTART_WORKER + ) agent._errors = "error code is 507035" action = agent.diagnose_training_failure() - self.assertEqual(action.action, DiagnosisActionConstants.RELAUNCH_WORKER) + self.assertEqual( + action.action, DiagnosisActionConstants.RELAUNCH_WORKER + ) agent._errors = "error code is 11111" worker_context.remaining_failovers = 0 action = agent.diagnose_training_failure() - self.assertEqual(action.action, DiagnosisActionConstants.RELAUNCH_WORKER) + self.assertEqual( + action.action, DiagnosisActionConstants.RELAUNCH_WORKER + ) agent._errors = " #" worker_context.remaining_failovers = 2 action = agent.diagnose_training_failure() - self.assertEqual(action.action, 
DiagnosisActionConstants.RESTART_WORKER) + self.assertEqual( + action.action, DiagnosisActionConstants.RESTART_WORKER + ) def test_worker_training_metric(self): test = WorkerTrainingMetric( diff --git a/dlrover/python/tests/test_job_auto_scaler.py b/dlrover/python/tests/test_job_auto_scaler.py index 7d322b033..d43dfa988 100644 --- a/dlrover/python/tests/test_job_auto_scaler.py +++ b/dlrover/python/tests/test_job_auto_scaler.py @@ -26,13 +26,16 @@ AllreduceTrainingAutoScaler, PSTrainingAutoScaler, ) +from dlrover.python.master.node.job_context import ( + get_job_context, + update_job_node, +) from dlrover.python.master.resource.optimizer import ResourcePlan from dlrover.python.tests.test_utils import ( MockK8sAllreduceJobArgs, MockK8sPSJobArgs, mock_k8s_client, ) -from dlrover.python.master.node.job_context import get_job_context, update_job_node _dlrover_context = Context.singleton_instance() diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py index b2d78b8d4..091c270ac 100644 --- a/dlrover/python/tests/test_job_manager.py +++ b/dlrover/python/tests/test_job_manager.py @@ -51,6 +51,12 @@ TaskRescheduleCallback, TFPSNodeHandlingCallback, ) +from dlrover.python.master.node.job_context import ( + clear_job_nodes, + get_job_context, + update_job_node, + update_job_nodes, +) from dlrover.python.master.node.local_job_manager import LocalJobManager from dlrover.python.master.node.status_flow import ( NODE_STATE_FLOWS, @@ -75,7 +81,6 @@ mock_k8s_client, new_dataset_splitter, ) -from dlrover.python.master.node.job_context import get_job_context, update_job_node, clear_job_nodes, update_job_nodes _MOCK_JOB_UUID = "11111" @@ -253,9 +258,7 @@ def test_relaunch_node(self): NodeType.WORKER, 0, 0.7, 2048, gpu_stats ) # noqa job_nodes = self.job_context.job_nodes() - self.assertEqual( - job_nodes[NodeType.WORKER][0].used_resource.cpu, 0.7 - ) + self.assertEqual(job_nodes[NodeType.WORKER][0].used_resource.cpu, 0.7) self.assertEqual( job_nodes[NodeType.WORKER][0].used_resource.memory, 2048 ) @@ -372,9 +375,7 @@ def test_get_dead_node_event(self): self.assertEqual(len(events), 0) job_nodes = self.job_context.job_nodes() - for index, node in enumerate( - job_nodes[NodeType.WORKER].values() - ): + for index, node in enumerate(job_nodes[NodeType.WORKER].values()): node.status = NodeStatus.RUNNING now = datetime.now() node.heartbeat_time = (now - timedelta(seconds=1000)).timestamp() @@ -393,9 +394,7 @@ def test_get_dead_node_event(self): self.assertEqual(len(nodes_time_info), 3) job_nodes = self.job_context.job_nodes() - for index, node in enumerate( - job_nodes[NodeType.WORKER].values() - ): + for index, node in enumerate(job_nodes[NodeType.WORKER].values()): node.status = NodeStatus.RUNNING now = datetime.now() node.heartbeat_time = (now - timedelta(seconds=1000)).timestamp() @@ -873,14 +872,10 @@ def test_local_job_manager(self): job_manager.handle_training_failure(NodeType.WORKER, 3) try: - self.assertFalse( - job_nodes[NodeType.WORKER][0].is_succeeded() - ) + self.assertFalse(job_nodes[NodeType.WORKER][0].is_succeeded()) job_manager.update_succeeded_node(0, NodeType.WORKER) job_nodes = job_context.job_nodes() - self.assertTrue( - job_nodes[NodeType.WORKER][0].is_succeeded() - ) + self.assertTrue(job_nodes[NodeType.WORKER][0].is_succeeded()) job_manager.update_succeeded_node(5, NodeType.WORKER) job_manager.update_succeeded_node(0, "unknown") except Exception: diff --git a/dlrover/python/tests/test_master.py b/dlrover/python/tests/test_master.py index 
8cc59d66d..9e630f1ef 100644 --- a/dlrover/python/tests/test_master.py +++ b/dlrover/python/tests/test_master.py @@ -28,13 +28,17 @@ _create_master_service_on_k8s, ) from dlrover.python.master.main import update_context +from dlrover.python.master.node.job_context import ( + clear_job_nodes, + get_job_context, + update_job_nodes, +) from dlrover.python.master.shard.dataset_splitter import new_dataset_splitter from dlrover.python.tests.test_utils import ( MockK8sPSJobArgs, mock_k8s_client, start_local_master, ) -from dlrover.python.master.node.job_context import get_job_context, update_job_node, clear_job_nodes, update_job_nodes _dlrover_context = Context.singleton_instance() diff --git a/dlrover/python/tests/test_ps_manager.py b/dlrover/python/tests/test_ps_manager.py index dea0e6c15..99cb815d9 100644 --- a/dlrover/python/tests/test_ps_manager.py +++ b/dlrover/python/tests/test_ps_manager.py @@ -22,11 +22,16 @@ ) from dlrover.python.common.global_context import Context from dlrover.python.common.node import Node, NodeGroupResource, NodeResource +from dlrover.python.master.node.job_context import ( + clear_job_nodes, + get_job_context, + update_job_node, + update_job_nodes, +) from dlrover.python.master.node.ps import ParameterServerManager from dlrover.python.master.resource.job import JobResource from dlrover.python.scheduler.factory import new_elastic_job from dlrover.python.tests.test_utils import mock_k8s_client -from dlrover.python.master.node.job_context import get_job_context, update_job_nodes, clear_job_nodes, update_job_node _dlrover_ctx = Context.singleton_instance() diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py index 1b95327e8..480f5eae9 100644 --- a/dlrover/python/tests/test_servicer.py +++ b/dlrover/python/tests/test_servicer.py @@ -36,7 +36,11 @@ from dlrover.python.master.elastic_training.sync_service import SyncService from dlrover.python.master.monitor.speed_monitor import SpeedMonitor from dlrover.python.master.node.dist_job_manager import create_job_manager -from dlrover.python.master.node.job_context import get_job_context, update_job_node, clear_job_nodes +from dlrover.python.master.node.job_context import ( + clear_job_nodes, + get_job_context, + update_job_node, +) from dlrover.python.master.servicer import MasterServicer from dlrover.python.master.shard.task_manager import TaskManager from dlrover.python.master.stats.job_collector import JobMetricCollector @@ -61,9 +65,7 @@ def setUp(self) -> None: speed_monitor = SpeedMonitor() self.task_manager = TaskManager(False, speed_monitor) - self.job_manager = create_job_manager( - params, speed_monitor - ) + self.job_manager = create_job_manager(params, speed_monitor) self.job_context = get_job_context() self.job_manager._init_nodes() @@ -483,9 +485,7 @@ def test_update_node_addr(self): self.job_manager._init_nodes() self.servicer._update_node_address(request) node = self.job_context.job_node(task_type, task_id) - self.assertEqual( - node.service_addr, addr - ) + self.assertEqual(node.service_addr, addr) ps_nodes = self.job_context.job_nodes_by_type(NodeType.PS) for node in ps_nodes.values(): node.status = NodeStatus.RUNNING diff --git a/dlrover/python/tests/test_sync_service.py b/dlrover/python/tests/test_sync_service.py index e9d2844d3..daa83b0a9 100644 --- a/dlrover/python/tests/test_sync_service.py +++ b/dlrover/python/tests/test_sync_service.py @@ -17,8 +17,12 @@ from dlrover.python.master.elastic_training.sync_service import SyncService from 
dlrover.python.master.monitor.speed_monitor import SpeedMonitor from dlrover.python.master.node.dist_job_manager import create_job_manager +from dlrover.python.master.node.job_context import ( + clear_job_nodes, + get_job_context, + update_job_node, +) from dlrover.python.tests.test_utils import MockK8sPSJobArgs, mock_k8s_client -from dlrover.python.master.node.job_context import get_job_context, update_job_node, clear_job_nodes class SyncServiceTest(unittest.TestCase): diff --git a/dlrover/python/tests/test_worker_manager.py b/dlrover/python/tests/test_worker_manager.py index 92c0b538d..5c3ec0d11 100644 --- a/dlrover/python/tests/test_worker_manager.py +++ b/dlrover/python/tests/test_worker_manager.py @@ -25,11 +25,16 @@ ) from dlrover.python.common.global_context import Context from dlrover.python.common.node import Node, NodeGroupResource, NodeResource +from dlrover.python.master.node.job_context import ( + clear_job_nodes, + get_job_context, + update_job_node, + update_job_nodes, +) from dlrover.python.master.node.worker import ChiefManager, WorkerManager from dlrover.python.master.resource.job import JobResource from dlrover.python.scheduler.factory import new_elastic_job from dlrover.python.tests.test_utils import mock_k8s_client -from dlrover.python.master.node.job_context import get_job_context, update_job_node, clear_job_nodes, update_job_nodes _dlrover_ctx = Context.singleton_instance() From 1a2679c0f7be35e1635d62fdcd44e3fa75888137 Mon Sep 17 00:00:00 2001 From: bsang Date: Thu, 31 Oct 2024 21:47:38 +0800 Subject: [PATCH 10/10] fix pre-commit --- dlrover/python/tests/test_job_auto_scaler.py | 29 ++++++++++---------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/dlrover/python/tests/test_job_auto_scaler.py b/dlrover/python/tests/test_job_auto_scaler.py index d43dfa988..02fd8b4d0 100644 --- a/dlrover/python/tests/test_job_auto_scaler.py +++ b/dlrover/python/tests/test_job_auto_scaler.py @@ -27,6 +27,7 @@ PSTrainingAutoScaler, ) from dlrover.python.master.node.job_context import ( + clear_job_nodes, get_job_context, update_job_node, ) @@ -43,6 +44,10 @@ class JobAutoScalerTest(unittest.TestCase): def setUp(self) -> None: mock_k8s_client() + self.job_context = get_job_context() + + def tearDown(self) -> None: + clear_job_nodes() def test_execute_job_optimization_plan(self): params = MockK8sPSJobArgs() @@ -68,26 +73,21 @@ def test_execute_job_optimization_plan(self): plan.node_resources["test-edljob-worker-1"] = NodeResource(8, 8192) plan.node_resources["test-edljob-ps-1"] = NodeResource(8, 8192) - job_context = get_job_context() - ps_nodes = job_context.job_nodes_by_type(NodeType.PS) + ps_nodes = self.job_context.job_nodes_by_type(NodeType.PS) ps_node = ps_nodes[1] ps_node.type = NodeType.PS ps_node.status = NodeStatus.RUNNING update_job_node(ps_node) - # auto_scaler._ps_manager._nodes[1].status = NodeStatus.RUNNING - worker_nodes = job_context.job_nodes_by_type(NodeType.WORKER) + worker_nodes = self.job_context.job_nodes_by_type(NodeType.WORKER) worker_node = worker_nodes[0] worker_node.type = NodeType.WORKER worker_node.critical = True update_job_node(worker_node) - # auto_scaler._worker_manager._nodes[0].critical = True scale_plan = auto_scaler.execute_job_optimization_plan(plan) - ps_nodes = job_context.job_nodes_by_type(NodeType.PS) - # self.assertEqual(len(manager._ps_manager._nodes), 4) + ps_nodes = self.job_context.job_nodes_by_type(NodeType.PS) self.assertEqual(len(ps_nodes), 4) - worker_nodes = job_context.job_nodes_by_type(NodeType.WORKER) - # 
self.assertEqual(len(manager._worker_manager._nodes), 7) + worker_nodes = self.job_context.job_nodes_by_type(NodeType.WORKER) self.assertEqual(len(worker_nodes), 7) self.assertEqual(len(scale_plan.remove_nodes), 1) self.assertEqual(len(scale_plan.launch_nodes), 5) @@ -141,11 +141,9 @@ def test_reduce_timeout_pending_node_resource(self): ) auto_scaler._autoscaling_started = True - job_context = get_job_context() - ps_nodes = job_context.job_nodes_by_type(NodeType.PS) + ps_nodes = self.job_context.job_nodes_by_type(NodeType.PS) ps0 = ps_nodes[0] ps0.type = NodeType.PS - # ps0 = manager._ps_manager._nodes[0] ps0.config_resource.cpu = 16 ps0.status = NodeStatus.PENDING ps0.create_time = datetime.now() + timedelta(days=-1) @@ -164,6 +162,10 @@ def test_reduce_timeout_pending_node_resource(self): class AllreduceAutoScalerTest(unittest.TestCase): def setUp(self) -> None: mock_k8s_client() + self.job_context = get_job_context() + + def tearDown(self) -> None: + clear_job_nodes() def test_execute_job_optimization_plan(self): params = MockK8sAllreduceJobArgs() @@ -171,8 +173,7 @@ def test_execute_job_optimization_plan(self): manager = create_job_manager(params, SpeedMonitor()) manager._init_nodes() - job_context = get_job_context() - worker_nodes = job_context.job_nodes_by_type(NodeType.WORKER) + worker_nodes = self.job_context.job_nodes_by_type(NodeType.WORKER) for worker in worker_nodes.values(): worker.status = NodeStatus.RUNNING