diff --git a/codecov.yml b/codecov.yml index fadfcdcc7..db4bfeaa9 100644 --- a/codecov.yml +++ b/codecov.yml @@ -15,7 +15,7 @@ coverage: flag_coverage_not_uploaded_behavior: include patch: default: - target: 80% + target: 85% threshold: 3% removed_code_behavior: fully_covered_patch project: diff --git a/dlrover/python/common/comm.py b/dlrover/python/common/comm.py index 6e691f285..4bcc90b8c 100644 --- a/dlrover/python/common/comm.py +++ b/dlrover/python/common/comm.py @@ -228,6 +228,12 @@ class HeartBeat(Message): timestamp: int = 0 +@dataclass +class PreCheckRequest(Message): + timestamp: int = 0 + type: str = "INITIAL" + + @dataclass class DatasetShardParams(Message): batch_size: int = 0 diff --git a/dlrover/python/common/constants.py b/dlrover/python/common/constants.py index d5babb9e9..3a510c218 100644 --- a/dlrover/python/common/constants.py +++ b/dlrover/python/common/constants.py @@ -385,8 +385,12 @@ class JobConstant(object): # sleep 5s before next node check round NODE_CHECK_NEXT_ROUND_TIMEOUT = 5 + # default interval seconds for loop in training agent TRAINING_AGENT_LOOP_DEFAULT_INTERVAL = 15 + # interval seconds for pre-check waiting + PRE_CHECK_WAIT_SECS = 5 + class Accelerators(object): NVIDIA_GPU = "nvidia.com/gpu" diff --git a/dlrover/python/common/global_context.py b/dlrover/python/common/global_context.py index 1f29c81c0..8d9fc24d8 100644 --- a/dlrover/python/common/global_context.py +++ b/dlrover/python/common/global_context.py @@ -61,6 +61,7 @@ class DefaultValues(object): GPU_NUM_PER_NODE = 8 NPU_NUM_PER_NODE = 16 MAX_METRIC_REC = 30 + PRE_CHECK_ENABLE = True class Context(Singleton): @@ -107,6 +108,7 @@ def __init__(self): # The strategy of 'hang detection': # 0: log only; 1: notify; 2: with fault tolerance self.hang_detection = DefaultValues.HANG_DETECTION + self.pre_check_enable = DefaultValues.PRE_CHECK_ENABLE self.gpu_per_node = DefaultValues.GPU_NUM_PER_NODE self.npu_per_node = DefaultValues.NPU_NUM_PER_NODE self.max_metric_records = DefaultValues.MAX_METRIC_REC diff --git a/dlrover/python/diagnosis/common/diagnosis_action.py b/dlrover/python/diagnosis/common/diagnosis_action.py index 57b39797c..40c27fef2 100644 --- a/dlrover/python/diagnosis/common/diagnosis_action.py +++ b/dlrover/python/diagnosis/common/diagnosis_action.py @@ -260,6 +260,10 @@ def clear(self): with self._lock: self._actions.clear() + def len(self): + with self._lock: + return sum(len(d) for d in self._actions.values()) + def next_action( self, instance=DiagnosisConstant.LOCAL_INSTANCE, diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index 22346f56f..d0f8190fa 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -450,6 +450,11 @@ def get_elastic_run_config(self) -> Dict[str, str]: response: comm.ElasticRunConfig = self._get(request) return response.configs + def get_pre_check_result(self) -> bool: + request = comm.PreCheckRequest() + # result is boolean + return self._get(request) + def report_event( self, event_type: str = "", diff --git a/dlrover/python/master/args.py b/dlrover/python/master/args.py index a72e3cd35..dc3ce4f17 100644 --- a/dlrover/python/master/args.py +++ b/dlrover/python/master/args.py @@ -17,6 +17,17 @@ from dlrover.python.common.log import default_logger as logger +def str2bool(value): + if isinstance(value, bool): + return value + if value.lower() in {"true", "yes", "t", "y", "1"}: + return True + elif value.lower() in {"false", "no", "n", "0"}: + return 
False
+    else:
+        raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
 def add_params(parser):
     parser.add_argument("--job_name", help="ElasticJob name", required=True)
     parser.add_argument(
@@ -54,6 +65,12 @@ def add_params(parser):
         type=str,
         help="The service type of master: grpc/http.",
     )
+    parser.add_argument(
+        "--pre_check",
+        default=DefaultValues.PRE_CHECK_ENABLE,
+        type=str2bool,
+        help="Enable pre-training check or not.",
+    )


 def print_args(args, exclude_args=[], groups=None):
diff --git a/dlrover/python/master/diagnosis/diagnosis.py b/dlrover/python/master/diagnosis/diagnosis.py
index c4c38457e..30f4fc50d 100644
--- a/dlrover/python/master/diagnosis/diagnosis.py
+++ b/dlrover/python/master/diagnosis/diagnosis.py
@@ -44,6 +44,9 @@ def get_observing_operators(self) -> List[InferenceOperator]:
     def get_resolving_operators(self) -> List[InferenceOperator]:
         return self._resolvers

+    def register_pre_check_problems(self, problems: List[Inference]):
+        self._training_problems = problems
+
     def register_training_problems(self, problems: List[Inference]):
         self._training_problems = problems

diff --git a/dlrover/python/master/diagnosis/diagnosis_manager.py b/dlrover/python/master/diagnosis/diagnosis_manager.py
index 6df80f6c0..1d0f7d1f8 100644
--- a/dlrover/python/master/diagnosis/diagnosis_manager.py
+++ b/dlrover/python/master/diagnosis/diagnosis_manager.py
@@ -14,6 +14,7 @@
 import threading
 import time

+from dlrover.python.common.global_context import Context
 from dlrover.python.common.log import default_logger as logger
 from dlrover.python.diagnosis.common.constants import DiagnosisConstant
 from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData
@@ -33,8 +34,13 @@
 from dlrover.python.master.diagnosis.diagnosis_data_manager import (
     DiagnosisDataManager,
 )
+from dlrover.python.master.diagnosis.precheck_operator import (
+    NoPreCheckOperator,
+)
 from dlrover.python.master.node.job_context import get_job_context

+_dlrover_ctx = Context.singleton_instance()
+

 class DiagnosisManager:
     """
@@ -47,15 +53,80 @@ def __init__(self):
         self._diagnostician: Diagnostician = Diagnostician(self._data_manager)
         self._job_context = get_job_context()

+    @classmethod
+    def get_pre_check_operators(cls):
+        return [NoPreCheckOperator()]
+
     def collect_diagnosis_data(self, data: DiagnosisData):
         self._data_manager.store_data(data)

     def pre_check(self):
-        logger.info("Start Diagnosis Manager to pre-check training...")
-        pass
+        if not _dlrover_ctx.pre_check_enable:
+            return
+
+        start = time.time()
+        pre_check_ops = self.get_pre_check_operators()
+        logger.info(
+            "Start training pre-check " f"with operators: {pre_check_ops}."
+        )
+
+        for pre_check_op in pre_check_ops:
+            current_start = time.time()
+            current_op_result = None
+            pre_check_op_name = pre_check_op.__class__.__name__
+
+            try:
+                # retry loop for each operator
+                for i in range(pre_check_op.get_retry_limit_times()):
+                    check_start = time.time()
+
+                    # do check
+                    current_op_result = pre_check_op.check()
+                    logger.info(
+                        f"{pre_check_op_name} "
+                        f"check({i}) "
+                        f"cost: {time.time()-check_start:.2f}s, "
+                        f"result: {current_op_result}"
+                    )
+
+                    if not current_op_result.is_success():
+                        # try to recover and wait
+                        pre_check_op.recover()
+                        time.sleep(pre_check_op.get_retry_interval_secs())
+
+                        # check again after recover
+                        current_op_result = pre_check_op.check()
+                    else:
+                        break
+            except Exception as e:
+                logger.error(f"Pre-check operator got unexpected error: {e}")
+                continue
+
+            if not current_op_result.is_success():
+                action = pre_check_op.get_failed_action()
+                self._job_context.enqueue_action(action)
+                logger.warning(
+                    "Training pre-check failed "
+                    f"by {pre_check_op_name} "
+                    f"with result: {current_op_result}, "
+                    f"cost: {time.time()-current_start:.2f}s. "
+                    f"Invoke action: {action}."
+                )
+                return
+            else:
+                self._job_context.set_pre_check_pass()
+                logger.info(
+                    f"{pre_check_op_name} finished "
+                    f"with result: {current_op_result}, "
+                    f"cost: {time.time()-current_start:.2f}s."
+                )
+
+        logger.info(
+            "Training pre-check complete, " f"cost: {time.time()-start:.2f}s."
+        )

     def start_observing(self):
-        logger.info("Start Diagnosis Manager to observing training...")
+        logger.info("Start observing training...")
         self._is_observing_started = True

         self._diagnostician.register_training_problems(
diff --git a/dlrover/python/master/diagnosis/precheck_operator.py b/dlrover/python/master/diagnosis/precheck_operator.py
new file mode 100644
index 000000000..146a9da23
--- /dev/null
+++ b/dlrover/python/master/diagnosis/precheck_operator.py
@@ -0,0 +1,82 @@
+# Copyright 2025 The DLRover Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import List
+
+from dlrover.python.diagnosis.common.diagnosis_action import (
+    DiagnosisAction,
+    NoAction,
+)
+
+
+@dataclass
+class PreCheckResult(object):
+
+    # The default success result is 0. Other result codes (>0) should be
+    # defined by each pre-check operator itself.
+    result: int = 0
+
+    # A short description of the result.
+    result_msg: str = ""
+
+    # Abnormal nodes' ids.
+    abnormal_nodes: List[int] = field(default_factory=list)
+
+    def is_success(self):
+        return self.result == 0
+
+
+class PreCheckOperator(ABC):
+    @classmethod
+    def get_retry_interval_secs(cls) -> int:
+        """The retry interval seconds, can be overridden in subclasses."""
+        return 5
+
+    @classmethod
+    def get_retry_limit_times(cls) -> int:
+        """
+        The retry limit times, can be overridden in subclasses. For most
+        pre-checks, the retry limit should be > 1 (retry at least once).
+ + The failed action will be executed if result still not ok after + several retry times. + """ + return 3 + + @abstractmethod + def check(self) -> PreCheckResult: + """The abstraction of the main check procedure.""" + pass + + @abstractmethod + def recover(self): + """The abstraction of the procedure if check failed.""" + pass + + @abstractmethod + def get_failed_action(self) -> DiagnosisAction: + """The abstraction of the action when operator check failed.""" + pass + + +class NoPreCheckOperator(PreCheckOperator): + def check(self): + return PreCheckResult() + + def recover(self): + return + + def get_failed_action(self) -> DiagnosisAction: + return NoAction() diff --git a/dlrover/python/master/dist_master.py b/dlrover/python/master/dist_master.py index c10ec853e..b003270f1 100644 --- a/dlrover/python/master/dist_master.py +++ b/dlrover/python/master/dist_master.py @@ -206,7 +206,6 @@ def prepare(self): def pre_check(self): logger.info("Pre-check before running.") self.diagnosis_manager.pre_check() - # TODO def _add_node_event_callback(self): """Add NodeEventCallbacks for the listeners of Pod events.""" diff --git a/dlrover/python/master/node/job_context.py b/dlrover/python/master/node/job_context.py index 816339b9d..677663fbb 100644 --- a/dlrover/python/master/node/job_context.py +++ b/dlrover/python/master/node/job_context.py @@ -17,6 +17,7 @@ from typing import Dict, Optional, Union from dlrover.python.common.constants import NodeType +from dlrover.python.common.global_context import Context from dlrover.python.common.node import Node from dlrover.python.common.singleton import Singleton from dlrover.python.diagnosis.common.constants import ( @@ -27,6 +28,8 @@ DiagnosisActionQueue, ) +_dlrover_context = Context.singleton_instance() + class JobContext(Singleton): """ @@ -38,6 +41,7 @@ def __init__(self): self._action_queue = DiagnosisActionQueue() self._job_nodes: Dict[str, Dict[int, Node]] = {} self._failed_nodes: Dict[int, int] = {} + self._pre_check_pass = False self._locker = threading.Lock() def enqueue_action(self, action): @@ -193,6 +197,14 @@ def report_failed_node(self, node_id: Union[int, str] = None): def get_failed_node_cnt(self): return len(self._failed_nodes) + def set_pre_check_pass(self): + self._pre_check_pass = True + + def is_pre_check_pass(self): + if _dlrover_context.pre_check_enable: + return self._pre_check_pass + return True + def get_job_context() -> JobContext: job_context = JobContext.singleton_instance() diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index 3e1d16484..2b46ec325 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -49,6 +49,7 @@ RendezvousManager, ) from dlrover.python.master.monitor.speed_monitor import SpeedMonitor +from dlrover.python.master.node.job_context import get_job_context from dlrover.python.master.node.job_manager import JobManager from dlrover.python.master.node.training_node import SyncNodeTrainingPorts from dlrover.python.master.shard.dataset_splitter import new_dataset_splitter @@ -161,6 +162,10 @@ def get(self, request, _): elif isinstance(req_message, comm.ElasticRunConfigRequest): configs = self._job_manager.get_elastic_run_configs() message = comm.ElasticRunConfig(configs=configs) + elif isinstance(req_message, comm.PreCheckRequest): + message = self._get_pre_check_result( + node_type, node_id, req_message + ) elif isinstance(req_message, comm.HeartBeat): message = self._report_heartbeat(node_type, node_id, req_message) @@ -664,6 +669,9 @@ def 
_sync_training_ports( port=sync_ports.training_port, newport=sync_ports.next_check_port ) + def _get_pre_check_result(self, node_id, message: comm.PreCheckRequest): + return get_job_context().is_pre_check_pass() + def _report_event(self, message: comm.Event): if self._error_monitor: self._error_monitor.report_event( diff --git a/dlrover/python/master/watcher/k8s_watcher.py b/dlrover/python/master/watcher/k8s_watcher.py index 15fce0547..2e6b8ca7d 100644 --- a/dlrover/python/master/watcher/k8s_watcher.py +++ b/dlrover/python/master/watcher/k8s_watcher.py @@ -106,6 +106,11 @@ def _convert_pod_event_to_node_event(event): if metadata.deletion_timestamp: status = NodeStatus.DELETED + logger.debug( + f"Got monitor event for pod: {pod_name}, " + f"node: {host_name}, ip: {host_ip}, status: {status}." + ) + restart = _verify_restarting_training(evt_obj) if restart: logger.info(f"{evt_obj.metadata.name} need to restart.") diff --git a/dlrover/python/tests/test_args.py b/dlrover/python/tests/test_args.py index eb49386db..fa9860770 100644 --- a/dlrover/python/tests/test_args.py +++ b/dlrover/python/tests/test_args.py @@ -13,10 +13,28 @@ import unittest -from dlrover.python.master.args import parse_master_args +from dlrover.python.master.args import parse_master_args, str2bool class ArgsTest(unittest.TestCase): + def test_str2bool(self): + self.assertTrue(str2bool("TRUE")) + self.assertTrue(str2bool("True")) + self.assertTrue(str2bool("true")) + self.assertTrue(str2bool("yes")) + self.assertTrue(str2bool("t")) + self.assertTrue(str2bool("y")) + self.assertTrue(str2bool("1")) + self.assertTrue(str2bool(True)) + + self.assertFalse(str2bool("FALSE")) + self.assertFalse(str2bool("False")) + self.assertFalse(str2bool("false")) + self.assertFalse(str2bool("no")) + self.assertFalse(str2bool("n")) + self.assertFalse(str2bool("0")) + self.assertFalse(str2bool(False)) + def test_parse_master_args(self): original_args = [ "--job_name", @@ -30,6 +48,7 @@ def test_parse_master_args(self): self.assertEqual(parsed_args.pending_timeout, 900) self.assertEqual(parsed_args.pending_fail_strategy, 1) self.assertTrue(parsed_args.service_type, "grpc") + self.assertTrue(parsed_args.pre_check) original_args = [ "--job_name", @@ -42,8 +61,11 @@ def test_parse_master_args(self): "2", "--service_type", "http", + "--pre_check", + "false", ] parsed_args = parse_master_args(original_args) self.assertEqual(parsed_args.pending_timeout, 600) self.assertEqual(parsed_args.pending_fail_strategy, 2) self.assertTrue(parsed_args.service_type, "http") + self.assertFalse(parsed_args.pre_check) diff --git a/dlrover/python/tests/test_diagnosis_manager.py b/dlrover/python/tests/test_diagnosis_manager.py index 4a721b9e1..dbc3993a5 100644 --- a/dlrover/python/tests/test_diagnosis_manager.py +++ b/dlrover/python/tests/test_diagnosis_manager.py @@ -15,11 +15,17 @@ import unittest from typing import List from unittest import mock +from unittest.mock import MagicMock +from dlrover.python.common.constants import NodeStatus from dlrover.python.diagnosis.common.constants import ( DiagnosisActionType, DiagnosisDataType, ) +from dlrover.python.diagnosis.common.diagnosis_action import ( + DiagnosisAction, + NodeAction, +) from dlrover.python.diagnosis.common.diagnosis_data import ( DiagnosisData, TrainingLog, @@ -38,6 +44,11 @@ DiagnosisDataManager, ) from dlrover.python.master.diagnosis.diagnosis_manager import DiagnosisManager +from dlrover.python.master.diagnosis.precheck_operator import ( + PreCheckOperator, + PreCheckResult, +) +from 
dlrover.python.master.node.job_context import get_job_context class DiagnosisManagerTest(unittest.TestCase): @@ -105,3 +116,37 @@ def test_diagnosis_manager(self): # explore solutions to observed problems action = mgr._diagnostician.resolve_problems(observed_problems) self.assertEqual(action.action_type, DiagnosisActionType.NONE) + + def test_pre_check(self): + job_context = get_job_context() + mgr = DiagnosisManager() + mgr.pre_check() + self.assertEqual(job_context._action_queue.len(), 0) + + mgr.get_pre_check_operators = MagicMock(return_value=[TestOperator()]) + mgr.pre_check() + self.assertTrue(isinstance(job_context.next_action(1), NodeAction)) + + +class TestOperator(PreCheckOperator): + @classmethod + def get_retry_interval_secs(cls) -> int: + return 1 + + @classmethod + def get_retry_limit_times(cls) -> int: + return 1 + + def check(self) -> PreCheckResult: + return PreCheckResult(1, "test", [1]) + + def recover(self): + pass + + def get_failed_action(self) -> DiagnosisAction: + return NodeAction( + node_id=1, + node_status=NodeStatus.FAILED, + reason="hang", + action_type=DiagnosisActionType.MASTER_RELAUNCH_WORKER, + ) diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py index c451fbb58..8751a1f22 100644 --- a/dlrover/python/tests/test_job_manager.py +++ b/dlrover/python/tests/test_job_manager.py @@ -995,3 +995,6 @@ def test_local_job_manager(self): worker = job_nodes[NodeType.WORKER][0] self.assertEqual(worker.paral_config, paral_config) job_manager.handle_training_failure(NodeType.WORKER, 3) + + job_context.set_pre_check_pass() + self.assertTrue(job_context.is_pre_check_pass()) diff --git a/dlrover/python/tests/test_pre_check_operator.py b/dlrover/python/tests/test_pre_check_operator.py new file mode 100644 index 000000000..1615c338e --- /dev/null +++ b/dlrover/python/tests/test_pre_check_operator.py @@ -0,0 +1,39 @@ +# Copyright 2025 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from dlrover.python.diagnosis.common.diagnosis_action import NoAction +from dlrover.python.master.diagnosis.precheck_operator import ( + NoPreCheckOperator, +) + + +class PreCheckOperatorTest(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + pass + + def test_no_pre_check_op(self): + op = NoPreCheckOperator() + self.assertTrue(op.check()) + op.recover() + self.assertEqual(op.get_retry_interval_secs(), 5) + self.assertEqual(op.get_retry_limit_times(), 3) + self.assertTrue(isinstance(op.get_failed_action(), NoAction)) + + +if __name__ == "__main__": + unittest.main() diff --git a/dlrover/trainer/tests/torch/elastic_run_test.py b/dlrover/trainer/tests/torch/elastic_run_test.py index de02757c9..8b11100d2 100644 --- a/dlrover/trainer/tests/torch/elastic_run_test.py +++ b/dlrover/trainer/tests/torch/elastic_run_test.py @@ -13,9 +13,11 @@ import socket import telnetlib +import threading +import time import unittest from unittest import mock -from unittest.mock import patch +from unittest.mock import MagicMock, patch from dlrover.python.elastic_agent.master_client import ( MasterClient, @@ -28,6 +30,7 @@ _elastic_config_from_args, _launch_dlrover_local_master, parse_args, + wait_pre_check, ) MC_PATH = "dlrover.python.elastic_agent.master_client.MasterClient" @@ -92,7 +95,7 @@ def test_elastic_config_from_args(self): self.assertEqual(config.node_unit, 4) self.assertEqual(config.rdzv_configs["node_unit"], 4) self.assertEqual(config.training_port, 1000) - self.assertEqual(cmd, "/usr/local/bin/python") + self.assertTrue("bin/python" in cmd) self.assertListEqual(cmd_args, ["-u", "test.py", "--batch_size", "16"]) @patch(f"{MC_PATH}.get_elastic_run_config") @@ -136,3 +139,26 @@ def test_elastic_config_from_master_2(self): args = parse_args(args) config, cmd, cmd_args = _elastic_config_from_args(args) self.assertFalse(config.network_check) + + def test_wait_pre_check(self): + client = MasterClient.singleton_instance() + + # pre-check success + client.get_pre_check_result = MagicMock(return_value=True) + wait_pre_check() + + # pre-check fail + client.get_pre_check_result = MagicMock(return_value=False) + + def set_pre_check_success(): + time_to_set_success = time.time() + while True: + if time.time() - time_to_set_success > 1: + client.get_pre_check_result = MagicMock(return_value=True) + break + time.sleep(0.1) + + start = time.time() + threading.Thread(target=set_pre_check_success).start() + wait_pre_check() + self.assertTrue(time.time() - start > 0.5) diff --git a/dlrover/trainer/torch/elastic_run.py b/dlrover/trainer/torch/elastic_run.py index a60645041..2495b3626 100644 --- a/dlrover/trainer/torch/elastic_run.py +++ b/dlrover/trainer/torch/elastic_run.py @@ -109,6 +109,7 @@ from dlrover.python.common import comm, env_utils from dlrover.python.common.constants import ( Accelerators, + JobConstant, NodeEnv, NodeErrorMessage, TrainingExceptionLevel, @@ -255,6 +256,7 @@ def __init__( def __call__(self, *args): if self._use_dlrover_launch: + wait_pre_check() return launch_agent(self._config, self._entrypoint, list(args)) else: return torch_launch_agent( @@ -262,6 +264,23 @@ def __call__(self, *args): ) +def wait_pre_check(): + """Wait master's pre-check result.""" + client = MasterClient.singleton_instance() + wait_secs = JobConstant.PRE_CHECK_WAIT_SECS + + while True: + if client.get_pre_check_result(): + logger.info("Pre check passed.") + break + else: + logger.info( + "Pre check not passed yet, " + f"wait for another {wait_secs}s..." 
+ ) + time.sleep(wait_secs) + + def _launch_dlrover_local_master(master_addr, job_name, node_num): """Launch a subprocess to run the DLRover master.""" logger.info(f"Start dlrover master with addr {master_addr}") diff --git a/docs/deployment/argument.md b/docs/deployment/argument.md index e728fdbae..d1bb1a529 100644 --- a/docs/deployment/argument.md +++ b/docs/deployment/argument.md @@ -14,6 +14,7 @@ when training with DLRover. | pending_timeout | The timeout value of pending. | No | integer(unit: second) | 900 | \>=0 | | pending_fail_strategy | The fail strategy for pending case. | No | integer | 1 | -1: disabled
0: skip<br>1: verify necessary parts
2: verify all parts | | service_type | The type of master service. | No | string | grpc | grpc,http | +| pre_check | Whether to enable pre-check before training. | No | bool | True | True: enable pre-check False: disable pre-check | ## 2. Training Arguments diff --git a/docs/design/training-pre-check.md b/docs/design/training-pre-check.md new file mode 100644 index 000000000..fede12efb --- /dev/null +++ b/docs/design/training-pre-check.md @@ -0,0 +1,112 @@ +# Training Pre Check Procedure Introduction + +## Background + +As the scale of training production expands, the factors that may cause +runtime exceptions related to the environment during training are gradually +increasing. Although the current DLRover adds a so-called node-check before +training (which runs a small-scale training task to detect any abnormal +scenarios), it primarily observes availability from the training dimension +perspective. Therefore, it is not easy to specifically diagnose problematic +nodes and classify issues. + +So we are seeking a more advanced and precise check, +conducting certain related inspections right when all nodes are in place +(or not yet in place) to identify issues as early as possible and implement +fault tolerance or blockages in advance. The checks mentioned earlier rely more +on external services, as they require specialized and more precise validation +of components such as containers and networks. Traditional training frameworks +do not and should not involve these components. + +Considering that the actual +checks are highly related to the specific deployment environment and heavily +rely on external capabilities, there are currently no actual check +implementations added on the open-source side, only the exposure of interfaces +and the implementation of processes. This article only discusses the core +design of this aspect. + + +## Target + +- Support configurable multiple types of serial synchronous check before + training starts. (primary) +- Support limited checks during training execution. (secondary) + +## Design + +- The existing training start process: + +Current Training Start Process + +- The training start process combined with pre-check: + +New Training Start Process + +## Interface + +### PreCheckOperator +The PreCheckOperator is the core component for executing pre-checks. Multiple +implemented operators are driven by the DiagnosisManager to run before each +training session starts, as illustrated in the code below: + +```python +class PreCheckOperator(ABC): + + @classmethod + def get_retry_interval_secs(cls) -> int: + """The retry interval seconds, can be overridden in subclasses.""" + return 5 + + @classmethod + def get_retry_limit_times(cls) -> int: + """ + The retry limit times, can be overridden in subclasses. + The job starting procedure will be abort if result still not ok after + several retry times. + """ + return 3 + + @abstractmethod + def check(self) -> PreCheckResult: + """The abstraction of the main check procedure.""" + pass + + @abstractmethod + def recover(self): + """The abstraction of the procedure if check failed.""" + pass + + @abstractmethod + def get_failed_action(self) -> DiagnosisAction: + """The abstraction of the action when operator check failed.""" + pass +``` + +The execution process of each PreCheckOperator is roughly as follows: + +Pre Check Operator Process + +### PreCheckResult + +The PreCheckResult represents the result of the Operator's check. 
It includes +a numeric result to indicate the general outcome of the check +(similar to an error code), a string description, and a set to represent the +abnormal nodes. It is defined as follows: + +```python +@dataclass +class PreCheckResult(object): + + # The default success result is 0. The other result code(>0) should be + # defined by different pre-check operator it's self. + result: int = 0 + + # The simple description info for the result. + result_msg: str = "" + + # Abnormal nodes' id. + abnormal_nodes: List[int] = field(default_factory=list) + + def is_success(self): + return self.result == 0 +``` diff --git a/docs/figures/current_start_process.png b/docs/figures/current_start_process.png new file mode 100644 index 000000000..c7e6253eb Binary files /dev/null and b/docs/figures/current_start_process.png differ diff --git a/docs/figures/new_start_process.png b/docs/figures/new_start_process.png new file mode 100644 index 000000000..b41851ac8 Binary files /dev/null and b/docs/figures/new_start_process.png differ diff --git a/docs/figures/pre_check_op.png b/docs/figures/pre_check_op.png new file mode 100644 index 000000000..ccc7079fa Binary files /dev/null and b/docs/figures/pre_check_op.png differ diff --git a/go/elasticjob/pkg/controllers/master/master.go b/go/elasticjob/pkg/controllers/master/master.go index 8f4e121dc..470da56db 100644 --- a/go/elasticjob/pkg/controllers/master/master.go +++ b/go/elasticjob/pkg/controllers/master/master.go @@ -49,6 +49,7 @@ const ( pendingTimeoutArg = "pending_timeout" pendingFailStrategyArg = "pending_fail_strategy" serviceType = "service_type" + preCheckArg = "pre_check" ) // Manager generates a master pod object. @@ -239,7 +240,7 @@ func (m *Manager) StopRunningPods( } func getMasterArguments() []string { - return []string{pendingTimeoutArg, pendingFailStrategyArg, serviceType} + return []string{pendingTimeoutArg, pendingFailStrategyArg, serviceType, preCheckArg} } // NewMasterTemplateToJob sets configurations to the master template of a job.
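As the design doc above notes, no concrete check implementation ships on the open-source side, only the interfaces and the process. For illustration, a deployment-specific operator might look like the sketch below. The class name `NodeLivenessPreCheckOperator`, its constructor arguments, and the injected `probe` callable are assumptions made up for this example; only `PreCheckOperator`, `PreCheckResult`, `NodeAction`, `NodeStatus`, and `DiagnosisActionType` come from code touched by this patch.

```python
# Illustrative sketch only -- not part of this patch. The class name,
# constructor signature, and the injected "probe" callable are hypothetical.
from typing import Callable, List

from dlrover.python.common.constants import NodeStatus
from dlrover.python.diagnosis.common.constants import DiagnosisActionType
from dlrover.python.diagnosis.common.diagnosis_action import (
    DiagnosisAction,
    NodeAction,
)
from dlrover.python.master.diagnosis.precheck_operator import (
    PreCheckOperator,
    PreCheckResult,
)


class NodeLivenessPreCheckOperator(PreCheckOperator):
    """Hypothetical: verify every expected node answers a liveness probe."""

    def __init__(self, node_ids: List[int], probe: Callable[[int], bool]):
        # The probe is injected so the sketch stays environment-agnostic;
        # a real operator would call deployment-specific services here.
        self._node_ids = node_ids
        self._probe = probe

    @classmethod
    def get_retry_interval_secs(cls) -> int:
        return 10

    @classmethod
    def get_retry_limit_times(cls) -> int:
        return 2

    def check(self) -> PreCheckResult:
        abnormal = [i for i in self._node_ids if not self._probe(i)]
        if not abnormal:
            # result == 0 means success
            return PreCheckResult()
        return PreCheckResult(
            result=1,
            result_msg=f"{len(abnormal)} node(s) failed the liveness probe",
            abnormal_nodes=abnormal,
        )

    def recover(self):
        # A pure probe has nothing to recover; a real operator might restart
        # agents or re-schedule the abnormal pods before the next retry.
        pass

    def get_failed_action(self) -> DiagnosisAction:
        result = self.check()
        failed_id = result.abnormal_nodes[0] if result.abnormal_nodes else -1
        return NodeAction(
            node_id=failed_id,
            node_status=NodeStatus.FAILED,
            reason="pre-check liveness probe failed",
            action_type=DiagnosisActionType.MASTER_RELAUNCH_WORKER,
        )
```

To take effect, such an operator would have to be returned from `DiagnosisManager.get_pre_check_operators()`, which in this patch returns only `NoPreCheckOperator()`; the `test_pre_check` case in `test_diagnosis_manager.py` shows the same wiring via `MagicMock`. Calling `check()` again inside `get_failed_action()` keeps the sketch short; a real operator would more likely cache the last result.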