diff --git a/codecov.yml b/codecov.yml
index fadfcdcc7..db4bfeaa9 100644
--- a/codecov.yml
+++ b/codecov.yml
@@ -15,7 +15,7 @@ coverage:
flag_coverage_not_uploaded_behavior: include
patch:
default:
- target: 80%
+ target: 85%
threshold: 3%
removed_code_behavior: fully_covered_patch
project:
diff --git a/dlrover/python/common/comm.py b/dlrover/python/common/comm.py
index 6e691f285..4bcc90b8c 100644
--- a/dlrover/python/common/comm.py
+++ b/dlrover/python/common/comm.py
@@ -228,6 +228,12 @@ class HeartBeat(Message):
timestamp: int = 0
+@dataclass
+class PreCheckRequest(Message):
+ timestamp: int = 0
+ type: str = "INITIAL"
+
+
@dataclass
class DatasetShardParams(Message):
batch_size: int = 0
diff --git a/dlrover/python/common/constants.py b/dlrover/python/common/constants.py
index d5babb9e9..3a510c218 100644
--- a/dlrover/python/common/constants.py
+++ b/dlrover/python/common/constants.py
@@ -385,8 +385,12 @@ class JobConstant(object):
# sleep 5s before next node check round
NODE_CHECK_NEXT_ROUND_TIMEOUT = 5
+ # default interval seconds for loop in training agent
TRAINING_AGENT_LOOP_DEFAULT_INTERVAL = 15
+ # interval seconds for pre-check waiting
+ PRE_CHECK_WAIT_SECS = 5
+
class Accelerators(object):
NVIDIA_GPU = "nvidia.com/gpu"
diff --git a/dlrover/python/common/global_context.py b/dlrover/python/common/global_context.py
index 1f29c81c0..8d9fc24d8 100644
--- a/dlrover/python/common/global_context.py
+++ b/dlrover/python/common/global_context.py
@@ -61,6 +61,7 @@ class DefaultValues(object):
GPU_NUM_PER_NODE = 8
NPU_NUM_PER_NODE = 16
MAX_METRIC_REC = 30
+ PRE_CHECK_ENABLE = True
class Context(Singleton):
@@ -107,6 +108,7 @@ def __init__(self):
# The strategy of 'hang detection':
# 0: log only; 1: notify; 2: with fault tolerance
self.hang_detection = DefaultValues.HANG_DETECTION
+ self.pre_check_enable = DefaultValues.PRE_CHECK_ENABLE
self.gpu_per_node = DefaultValues.GPU_NUM_PER_NODE
self.npu_per_node = DefaultValues.NPU_NUM_PER_NODE
self.max_metric_records = DefaultValues.MAX_METRIC_REC
diff --git a/dlrover/python/diagnosis/common/diagnosis_action.py b/dlrover/python/diagnosis/common/diagnosis_action.py
index 57b39797c..40c27fef2 100644
--- a/dlrover/python/diagnosis/common/diagnosis_action.py
+++ b/dlrover/python/diagnosis/common/diagnosis_action.py
@@ -260,6 +260,10 @@ def clear(self):
with self._lock:
self._actions.clear()
+ def len(self):
+ with self._lock:
+ return sum(len(d) for d in self._actions.values())
+
def next_action(
self,
instance=DiagnosisConstant.LOCAL_INSTANCE,
diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py
index 22346f56f..d0f8190fa 100644
--- a/dlrover/python/elastic_agent/master_client.py
+++ b/dlrover/python/elastic_agent/master_client.py
@@ -450,6 +450,11 @@ def get_elastic_run_config(self) -> Dict[str, str]:
response: comm.ElasticRunConfig = self._get(request)
return response.configs
+ def get_pre_check_result(self) -> bool:
+ request = comm.PreCheckRequest()
+        # The result is a plain boolean: True means the pre-check passed.
+ return self._get(request)
+
def report_event(
self,
event_type: str = "",
diff --git a/dlrover/python/master/args.py b/dlrover/python/master/args.py
index a72e3cd35..dc3ce4f17 100644
--- a/dlrover/python/master/args.py
+++ b/dlrover/python/master/args.py
@@ -17,6 +17,17 @@
from dlrover.python.common.log import default_logger as logger
+def str2bool(value):
+ if isinstance(value, bool):
+ return value
+ if value.lower() in {"true", "yes", "t", "y", "1"}:
+ return True
+ elif value.lower() in {"false", "no", "n", "0"}:
+ return False
+ else:
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
def add_params(parser):
parser.add_argument("--job_name", help="ElasticJob name", required=True)
parser.add_argument(
@@ -54,6 +65,13 @@ def add_params(parser):
type=str,
help="The service type of master: grpc/http.",
)
+ parser.add_argument(
+ "--pre_check",
+ "--pre_check",
+ default=DefaultValues.PRE_CHECK_ENABLE,
+ type=str2bool,
+ help="Enable pre training check or not.",
+ )
def print_args(args, exclude_args=[], groups=None):
diff --git a/dlrover/python/master/diagnosis/diagnosis.py b/dlrover/python/master/diagnosis/diagnosis.py
index c4c38457e..30f4fc50d 100644
--- a/dlrover/python/master/diagnosis/diagnosis.py
+++ b/dlrover/python/master/diagnosis/diagnosis.py
@@ -44,6 +44,9 @@ def get_observing_operators(self) -> List[InferenceOperator]:
def get_resolving_operators(self) -> List[InferenceOperator]:
return self._resolvers
+    def register_pre_check_problems(self, problems: List[Inference]):
+        self._pre_check_problems = problems
+
def register_training_problems(self, problems: List[Inference]):
self._training_problems = problems
diff --git a/dlrover/python/master/diagnosis/diagnosis_manager.py b/dlrover/python/master/diagnosis/diagnosis_manager.py
index 6df80f6c0..1d0f7d1f8 100644
--- a/dlrover/python/master/diagnosis/diagnosis_manager.py
+++ b/dlrover/python/master/diagnosis/diagnosis_manager.py
@@ -14,6 +14,7 @@
import threading
import time
+from dlrover.python.common.global_context import Context
from dlrover.python.common.log import default_logger as logger
from dlrover.python.diagnosis.common.constants import DiagnosisConstant
from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData
@@ -33,8 +34,13 @@
from dlrover.python.master.diagnosis.diagnosis_data_manager import (
DiagnosisDataManager,
)
+from dlrover.python.master.diagnosis.precheck_operator import (
+ NoPreCheckOperator,
+)
from dlrover.python.master.node.job_context import get_job_context
+_dlrover_ctx = Context.singleton_instance()
+
class DiagnosisManager:
"""
@@ -47,15 +53,80 @@ def __init__(self):
self._diagnostician: Diagnostician = Diagnostician(self._data_manager)
self._job_context = get_job_context()
+ @classmethod
+ def get_pre_check_operators(cls):
+ return [NoPreCheckOperator()]
+
def collect_diagnosis_data(self, data: DiagnosisData):
self._data_manager.store_data(data)
def pre_check(self):
- logger.info("Start Diagnosis Manager to pre-check training...")
- pass
+ if not _dlrover_ctx.pre_check_enable:
+ return
+
+ start = time.time()
+ pre_check_ops = self.get_pre_check_operators()
+ logger.info(
+ "Start to training pre-check" f"with operators: {pre_check_ops}."
+ )
+
+ for pre_check_op in pre_check_ops:
+ current_start = time.time()
+ current_op_result = None
+ pre_check_op_name = pre_check_op.__class__.__name__
+
+ try:
+ # retry loops for each operator
+ for i in range(pre_check_op.get_retry_limit_times()):
+ check_start = time.time()
+
+ # do check
+ current_op_result = pre_check_op.check()
+ logger.info(
+ f"{pre_check_op_name} "
+ f"check({i}) "
+ f"cost: {time.time()-check_start:.2f}ms, "
+ f"result: {current_op_result}"
+ )
+
+ if not current_op_result.is_success():
+ # try recover and wait
+ pre_check_op.recover()
+ time.sleep(pre_check_op.get_retry_interval_secs())
+
+ # check again after recover
+ current_op_result = pre_check_op.check()
+ else:
+ break
+ except Exception as e:
+ logger.error(f"Pre-check operator got unexpected error: {e}")
+ continue
+
+ if not current_op_result.is_success():
+ action = pre_check_op.get_failed_action()
+ self._job_context.enqueue_action(action)
+ logger.warning(
+ "Training pre-check failed "
+ f"by {pre_check_op_name} "
+ f"with result: {current_op_result}, "
+ f"cost:{time.time()-current_start:.2f}ms. "
+ f"Invoke action: {action}."
+ )
+ return
+            else:
+                logger.info(
+                    f"{pre_check_op_name} finished "
+                    f"with result: {current_op_result}, "
+                    f"cost: {time.time()-current_start:.2f}s."
+                )
+
+        self._job_context.set_pre_check_pass()
+        logger.info(
+            "Training pre-check complete, " f"cost: {time.time()-start:.2f}s."
+        )
def start_observing(self):
- logger.info("Start Diagnosis Manager to observing training...")
+ logger.info("Start to observing training...")
self._is_observing_started = True
self._diagnostician.register_training_problems(
diff --git a/dlrover/python/master/diagnosis/precheck_operator.py b/dlrover/python/master/diagnosis/precheck_operator.py
new file mode 100644
index 000000000..146a9da23
--- /dev/null
+++ b/dlrover/python/master/diagnosis/precheck_operator.py
@@ -0,0 +1,82 @@
+# Copyright 2025 The DLRover Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import List
+
+from dlrover.python.diagnosis.common.diagnosis_action import (
+ DiagnosisAction,
+ NoAction,
+)
+
+
+@dataclass
+class PreCheckResult(object):
+
+    # The default success result code is 0. Other result codes (>0) should
+    # be defined by each pre-check operator itself.
+ result: int = 0
+
+ # The simple description info for the result.
+ result_msg: str = ""
+
+ # Abnormal nodes' id.
+ abnormal_nodes: List[int] = field(default_factory=list)
+
+ def is_success(self):
+ return self.result == 0
+
+
+class PreCheckOperator(ABC):
+ @classmethod
+ def get_retry_interval_secs(cls) -> int:
+ """The retry interval seconds, can be overridden in subclasses."""
+ return 5
+
+ @classmethod
+ def get_retry_limit_times(cls) -> int:
+ """
+        The retry limit, can be overridden in subclasses. For most
+        pre-checks, the value should be > 1 (at least one retry).
+
+        The failed action will be executed if the result is still not OK
+        after all retries.
+ """
+ return 3
+
+ @abstractmethod
+ def check(self) -> PreCheckResult:
+ """The abstraction of the main check procedure."""
+ pass
+
+ @abstractmethod
+ def recover(self):
+ """The abstraction of the procedure if check failed."""
+ pass
+
+ @abstractmethod
+ def get_failed_action(self) -> DiagnosisAction:
+ """The abstraction of the action when operator check failed."""
+ pass
+
+
+class NoPreCheckOperator(PreCheckOperator):
+ def check(self):
+ return PreCheckResult()
+
+ def recover(self):
+ return
+
+ def get_failed_action(self) -> DiagnosisAction:
+ return NoAction()
diff --git a/dlrover/python/master/dist_master.py b/dlrover/python/master/dist_master.py
index c10ec853e..b003270f1 100644
--- a/dlrover/python/master/dist_master.py
+++ b/dlrover/python/master/dist_master.py
@@ -206,7 +206,6 @@ def prepare(self):
def pre_check(self):
logger.info("Pre-check before running.")
self.diagnosis_manager.pre_check()
- # TODO
def _add_node_event_callback(self):
"""Add NodeEventCallbacks for the listeners of Pod events."""
diff --git a/dlrover/python/master/node/job_context.py b/dlrover/python/master/node/job_context.py
index 816339b9d..677663fbb 100644
--- a/dlrover/python/master/node/job_context.py
+++ b/dlrover/python/master/node/job_context.py
@@ -17,6 +17,7 @@
from typing import Dict, Optional, Union
from dlrover.python.common.constants import NodeType
+from dlrover.python.common.global_context import Context
from dlrover.python.common.node import Node
from dlrover.python.common.singleton import Singleton
from dlrover.python.diagnosis.common.constants import (
@@ -27,6 +28,8 @@
DiagnosisActionQueue,
)
+_dlrover_context = Context.singleton_instance()
+
class JobContext(Singleton):
"""
@@ -38,6 +41,7 @@ def __init__(self):
self._action_queue = DiagnosisActionQueue()
self._job_nodes: Dict[str, Dict[int, Node]] = {}
self._failed_nodes: Dict[int, int] = {}
+ self._pre_check_pass = False
self._locker = threading.Lock()
def enqueue_action(self, action):
@@ -193,6 +197,14 @@ def report_failed_node(self, node_id: Union[int, str] = None):
def get_failed_node_cnt(self):
return len(self._failed_nodes)
+ def set_pre_check_pass(self):
+ self._pre_check_pass = True
+
+ def is_pre_check_pass(self):
+ if _dlrover_context.pre_check_enable:
+ return self._pre_check_pass
+ return True
+
def get_job_context() -> JobContext:
job_context = JobContext.singleton_instance()
diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py
index 3e1d16484..2b46ec325 100644
--- a/dlrover/python/master/servicer.py
+++ b/dlrover/python/master/servicer.py
@@ -49,6 +49,7 @@
RendezvousManager,
)
from dlrover.python.master.monitor.speed_monitor import SpeedMonitor
+from dlrover.python.master.node.job_context import get_job_context
from dlrover.python.master.node.job_manager import JobManager
from dlrover.python.master.node.training_node import SyncNodeTrainingPorts
from dlrover.python.master.shard.dataset_splitter import new_dataset_splitter
@@ -161,6 +162,10 @@ def get(self, request, _):
elif isinstance(req_message, comm.ElasticRunConfigRequest):
configs = self._job_manager.get_elastic_run_configs()
message = comm.ElasticRunConfig(configs=configs)
+ elif isinstance(req_message, comm.PreCheckRequest):
+ message = self._get_pre_check_result(
+ node_type, node_id, req_message
+ )
elif isinstance(req_message, comm.HeartBeat):
message = self._report_heartbeat(node_type, node_id, req_message)
@@ -664,6 +669,9 @@ def _sync_training_ports(
port=sync_ports.training_port, newport=sync_ports.next_check_port
)
+    def _get_pre_check_result(
+        self, node_type, node_id, message: comm.PreCheckRequest
+    ):
+ return get_job_context().is_pre_check_pass()
+
def _report_event(self, message: comm.Event):
if self._error_monitor:
self._error_monitor.report_event(
diff --git a/dlrover/python/master/watcher/k8s_watcher.py b/dlrover/python/master/watcher/k8s_watcher.py
index 15fce0547..2e6b8ca7d 100644
--- a/dlrover/python/master/watcher/k8s_watcher.py
+++ b/dlrover/python/master/watcher/k8s_watcher.py
@@ -106,6 +106,11 @@ def _convert_pod_event_to_node_event(event):
if metadata.deletion_timestamp:
status = NodeStatus.DELETED
+ logger.debug(
+ f"Got monitor event for pod: {pod_name}, "
+ f"node: {host_name}, ip: {host_ip}, status: {status}."
+ )
+
restart = _verify_restarting_training(evt_obj)
if restart:
logger.info(f"{evt_obj.metadata.name} need to restart.")
diff --git a/dlrover/python/tests/test_args.py b/dlrover/python/tests/test_args.py
index eb49386db..fa9860770 100644
--- a/dlrover/python/tests/test_args.py
+++ b/dlrover/python/tests/test_args.py
@@ -13,10 +13,28 @@
import unittest
-from dlrover.python.master.args import parse_master_args
+from dlrover.python.master.args import parse_master_args, str2bool
class ArgsTest(unittest.TestCase):
+ def test_str2bool(self):
+ self.assertTrue(str2bool("TRUE"))
+ self.assertTrue(str2bool("True"))
+ self.assertTrue(str2bool("true"))
+ self.assertTrue(str2bool("yes"))
+ self.assertTrue(str2bool("t"))
+ self.assertTrue(str2bool("y"))
+ self.assertTrue(str2bool("1"))
+ self.assertTrue(str2bool(True))
+
+ self.assertFalse(str2bool("FALSE"))
+ self.assertFalse(str2bool("False"))
+ self.assertFalse(str2bool("false"))
+ self.assertFalse(str2bool("no"))
+ self.assertFalse(str2bool("n"))
+ self.assertFalse(str2bool("0"))
+ self.assertFalse(str2bool(False))
+
def test_parse_master_args(self):
original_args = [
"--job_name",
@@ -30,6 +48,7 @@ def test_parse_master_args(self):
self.assertEqual(parsed_args.pending_timeout, 900)
self.assertEqual(parsed_args.pending_fail_strategy, 1)
self.assertTrue(parsed_args.service_type, "grpc")
+ self.assertTrue(parsed_args.pre_check)
original_args = [
"--job_name",
@@ -42,8 +61,11 @@ def test_parse_master_args(self):
"2",
"--service_type",
"http",
+ "--pre_check",
+ "false",
]
parsed_args = parse_master_args(original_args)
self.assertEqual(parsed_args.pending_timeout, 600)
self.assertEqual(parsed_args.pending_fail_strategy, 2)
self.assertTrue(parsed_args.service_type, "http")
+ self.assertFalse(parsed_args.pre_check)
diff --git a/dlrover/python/tests/test_diagnosis_manager.py b/dlrover/python/tests/test_diagnosis_manager.py
index 4a721b9e1..dbc3993a5 100644
--- a/dlrover/python/tests/test_diagnosis_manager.py
+++ b/dlrover/python/tests/test_diagnosis_manager.py
@@ -15,11 +15,17 @@
import unittest
from typing import List
from unittest import mock
+from unittest.mock import MagicMock
+from dlrover.python.common.constants import NodeStatus
from dlrover.python.diagnosis.common.constants import (
DiagnosisActionType,
DiagnosisDataType,
)
+from dlrover.python.diagnosis.common.diagnosis_action import (
+ DiagnosisAction,
+ NodeAction,
+)
from dlrover.python.diagnosis.common.diagnosis_data import (
DiagnosisData,
TrainingLog,
@@ -38,6 +44,11 @@
DiagnosisDataManager,
)
from dlrover.python.master.diagnosis.diagnosis_manager import DiagnosisManager
+from dlrover.python.master.diagnosis.precheck_operator import (
+ PreCheckOperator,
+ PreCheckResult,
+)
+from dlrover.python.master.node.job_context import get_job_context
class DiagnosisManagerTest(unittest.TestCase):
@@ -105,3 +116,37 @@ def test_diagnosis_manager(self):
# explore solutions to observed problems
action = mgr._diagnostician.resolve_problems(observed_problems)
self.assertEqual(action.action_type, DiagnosisActionType.NONE)
+
+ def test_pre_check(self):
+ job_context = get_job_context()
+ mgr = DiagnosisManager()
+ mgr.pre_check()
+ self.assertEqual(job_context._action_queue.len(), 0)
+
+ mgr.get_pre_check_operators = MagicMock(return_value=[TestOperator()])
+ mgr.pre_check()
+ self.assertTrue(isinstance(job_context.next_action(1), NodeAction))
+
+
+class TestOperator(PreCheckOperator):
+ @classmethod
+ def get_retry_interval_secs(cls) -> int:
+ return 1
+
+ @classmethod
+ def get_retry_limit_times(cls) -> int:
+ return 1
+
+ def check(self) -> PreCheckResult:
+ return PreCheckResult(1, "test", [1])
+
+ def recover(self):
+ pass
+
+ def get_failed_action(self) -> DiagnosisAction:
+ return NodeAction(
+ node_id=1,
+ node_status=NodeStatus.FAILED,
+ reason="hang",
+ action_type=DiagnosisActionType.MASTER_RELAUNCH_WORKER,
+ )
diff --git a/dlrover/python/tests/test_job_manager.py b/dlrover/python/tests/test_job_manager.py
index c451fbb58..8751a1f22 100644
--- a/dlrover/python/tests/test_job_manager.py
+++ b/dlrover/python/tests/test_job_manager.py
@@ -995,3 +995,6 @@ def test_local_job_manager(self):
worker = job_nodes[NodeType.WORKER][0]
self.assertEqual(worker.paral_config, paral_config)
job_manager.handle_training_failure(NodeType.WORKER, 3)
+
+ job_context.set_pre_check_pass()
+ self.assertTrue(job_context.is_pre_check_pass())
diff --git a/dlrover/python/tests/test_pre_check_operator.py b/dlrover/python/tests/test_pre_check_operator.py
new file mode 100644
index 000000000..1615c338e
--- /dev/null
+++ b/dlrover/python/tests/test_pre_check_operator.py
@@ -0,0 +1,39 @@
+# Copyright 2025 The DLRover Authors. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from dlrover.python.diagnosis.common.diagnosis_action import NoAction
+from dlrover.python.master.diagnosis.precheck_operator import (
+ NoPreCheckOperator,
+)
+
+
+class PreCheckOperatorTest(unittest.TestCase):
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ def test_no_pre_check_op(self):
+ op = NoPreCheckOperator()
+        self.assertTrue(op.check().is_success())
+ op.recover()
+ self.assertEqual(op.get_retry_interval_secs(), 5)
+ self.assertEqual(op.get_retry_limit_times(), 3)
+ self.assertTrue(isinstance(op.get_failed_action(), NoAction))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/dlrover/trainer/tests/torch/elastic_run_test.py b/dlrover/trainer/tests/torch/elastic_run_test.py
index de02757c9..8b11100d2 100644
--- a/dlrover/trainer/tests/torch/elastic_run_test.py
+++ b/dlrover/trainer/tests/torch/elastic_run_test.py
@@ -13,9 +13,11 @@
import socket
import telnetlib
+import threading
+import time
import unittest
from unittest import mock
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
from dlrover.python.elastic_agent.master_client import (
MasterClient,
@@ -28,6 +30,7 @@
_elastic_config_from_args,
_launch_dlrover_local_master,
parse_args,
+ wait_pre_check,
)
MC_PATH = "dlrover.python.elastic_agent.master_client.MasterClient"
@@ -92,7 +95,7 @@ def test_elastic_config_from_args(self):
self.assertEqual(config.node_unit, 4)
self.assertEqual(config.rdzv_configs["node_unit"], 4)
self.assertEqual(config.training_port, 1000)
- self.assertEqual(cmd, "/usr/local/bin/python")
+ self.assertTrue("bin/python" in cmd)
self.assertListEqual(cmd_args, ["-u", "test.py", "--batch_size", "16"])
@patch(f"{MC_PATH}.get_elastic_run_config")
@@ -136,3 +139,26 @@ def test_elastic_config_from_master_2(self):
args = parse_args(args)
config, cmd, cmd_args = _elastic_config_from_args(args)
self.assertFalse(config.network_check)
+
+ def test_wait_pre_check(self):
+ client = MasterClient.singleton_instance()
+
+ # pre-check success
+ client.get_pre_check_result = MagicMock(return_value=True)
+ wait_pre_check()
+
+ # pre-check fail
+ client.get_pre_check_result = MagicMock(return_value=False)
+
+ def set_pre_check_success():
+ time_to_set_success = time.time()
+ while True:
+ if time.time() - time_to_set_success > 1:
+ client.get_pre_check_result = MagicMock(return_value=True)
+ break
+ time.sleep(0.1)
+
+ start = time.time()
+ threading.Thread(target=set_pre_check_success).start()
+ wait_pre_check()
+ self.assertTrue(time.time() - start > 0.5)
diff --git a/dlrover/trainer/torch/elastic_run.py b/dlrover/trainer/torch/elastic_run.py
index a60645041..2495b3626 100644
--- a/dlrover/trainer/torch/elastic_run.py
+++ b/dlrover/trainer/torch/elastic_run.py
@@ -109,6 +109,7 @@
from dlrover.python.common import comm, env_utils
from dlrover.python.common.constants import (
Accelerators,
+ JobConstant,
NodeEnv,
NodeErrorMessage,
TrainingExceptionLevel,
@@ -255,6 +256,7 @@ def __init__(
def __call__(self, *args):
if self._use_dlrover_launch:
+ wait_pre_check()
return launch_agent(self._config, self._entrypoint, list(args))
else:
return torch_launch_agent(
@@ -262,6 +264,23 @@ def __call__(self, *args):
)
+def wait_pre_check():
+ """Wait master's pre-check result."""
+ client = MasterClient.singleton_instance()
+ wait_secs = JobConstant.PRE_CHECK_WAIT_SECS
+
+ while True:
+ if client.get_pre_check_result():
+ logger.info("Pre check passed.")
+ break
+ else:
+ logger.info(
+ "Pre check not passed yet, "
+ f"wait for another {wait_secs}s..."
+ )
+ time.sleep(wait_secs)
+
+
def _launch_dlrover_local_master(master_addr, job_name, node_num):
"""Launch a subprocess to run the DLRover master."""
logger.info(f"Start dlrover master with addr {master_addr}")
diff --git a/docs/deployment/argument.md b/docs/deployment/argument.md
index e728fdbae..d1bb1a529 100644
--- a/docs/deployment/argument.md
+++ b/docs/deployment/argument.md
@@ -14,6 +14,7 @@ when training with DLRover.
| pending_timeout | The timeout value of pending. | No | integer(unit: second) | 900 | \>=0 |
| pending_fail_strategy | The fail strategy for pending case. | No | integer | 1 | -1: disabled<br>0: skip<br>1: verify necessary parts<br>2: verify all parts |
| service_type | The type of master service. | No | string | grpc | grpc,http |
+| pre_check | Whether to enable pre-check before training. | No | bool | True | True: enable pre-check<br>False: disable pre-check |
## 2. Training Arguments
diff --git a/docs/design/training-pre-check.md b/docs/design/training-pre-check.md
new file mode 100644
index 000000000..fede12efb
--- /dev/null
+++ b/docs/design/training-pre-check.md
@@ -0,0 +1,112 @@
+# Training Pre-Check Procedure Introduction
+
+## Background
+
+As the scale of production training grows, environment-related factors that
+can cause runtime exceptions during training keep increasing. DLRover already
+runs a node check before training (a small-scale training task that detects
+abnormal scenarios), but that check mainly observes availability from the
+training perspective, so it is hard to pinpoint the problematic nodes and
+classify the issues.
+
+We therefore want a more precise set of checks that run as soon as the nodes
+are in place (or even before), so that issues are identified as early as
+possible and fault tolerance or blocking can be applied in advance. Such
+checks rely heavily on external services, because they require specialized
+validation of components such as containers and networks, which traditional
+training frameworks do not, and should not, touch.
+
+Because the concrete checks depend heavily on the specific deployment
+environment and on external capabilities, the open-source side currently only
+exposes the interfaces and implements the workflow; no concrete check
+implementations are included. This document discusses only the core design.
+
+
+## Target
+
+- Support multiple configurable types of serial, synchronous checks before
+  training starts. (primary)
+- Support limited checks during training execution. (secondary)
+
+## Design
+
+- The existing training start process:
+
+![current_start_process](../figures/current_start_process.png)
+
+- The training start process combined with pre-check:
+
+![new_start_process](../figures/new_start_process.png)
+
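+On the node side, each training process blocks until the master reports that
+the pre-check has passed. Below is a minimal sketch of this waiting loop,
+mirroring `wait_pre_check` in `dlrover/trainer/torch/elastic_run.py`; passing
+the master client in explicitly is a simplification, the real function uses
+the `MasterClient` singleton and `JobConstant.PRE_CHECK_WAIT_SECS`:
+
+```python
+import time
+
+PRE_CHECK_WAIT_SECS = 5  # same default as JobConstant.PRE_CHECK_WAIT_SECS
+
+
+def wait_pre_check(master_client):
+    """Block until the master reports that the pre-check has passed."""
+    while not master_client.get_pre_check_result():
+        # Not passed yet: the master is still running operators, or a check
+        # failed and is being retried/recovered.
+        time.sleep(PRE_CHECK_WAIT_SECS)
+```
+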
+## Interface
+
+### PreCheckOperator
+The PreCheckOperator is the core component for executing pre-checks. The
+implemented operators are driven by the DiagnosisManager and run before each
+training session starts. The operator interface is shown below:
+
+```python
+class PreCheckOperator(ABC):
+
+ @classmethod
+ def get_retry_interval_secs(cls) -> int:
+ """The retry interval seconds, can be overridden in subclasses."""
+ return 5
+
+ @classmethod
+ def get_retry_limit_times(cls) -> int:
+ """
+ The retry limit times, can be overridden in subclasses.
+ The job starting procedure will be abort if result still not ok after
+ several retry times.
+ """
+ return 3
+
+ @abstractmethod
+ def check(self) -> PreCheckResult:
+ """The abstraction of the main check procedure."""
+ pass
+
+ @abstractmethod
+ def recover(self):
+ """The abstraction of the procedure if check failed."""
+ pass
+
+ @abstractmethod
+ def get_failed_action(self) -> DiagnosisAction:
+ """The abstraction of the action when operator check failed."""
+ pass
+```
+
+The execution process of each PreCheckOperator is roughly as follows:
+
+![pre_check_op](../figures/pre_check_op.png)
+
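+A simplified sketch of how the DiagnosisManager may drive the registered
+operators is shown below. The retry/recover loop mirrors
+`DiagnosisManager.pre_check`; the `run_pre_check` helper name and the bare
+`job_context` parameter are only for illustration, and logging/timing are
+omitted:
+
+```python
+import time
+
+
+def run_pre_check(operators, job_context):
+    """Run every operator; stop at the first one that keeps failing."""
+    for op in operators:
+        result = None
+        for _ in range(op.get_retry_limit_times()):
+            result = op.check()
+            if result.is_success():
+                break
+            # Try to recover, wait, then check again in the next round.
+            op.recover()
+            time.sleep(op.get_retry_interval_secs())
+        if result is None or not result.is_success():
+            # Enqueue the operator's failed action and abort the pre-check.
+            job_context.enqueue_action(op.get_failed_action())
+            return
+    # All operators passed: unblock the nodes waiting on the result.
+    job_context.set_pre_check_pass()
+```
+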
+### PreCheckResult
+
+The PreCheckResult represents the result of an operator's check. It includes
+a numeric result code to indicate the general outcome of the check
+(similar to an error code), a short description, and a list of abnormal
+node ids. It is defined as follows:
+
+```python
+@dataclass
+class PreCheckResult(object):
+
+    # The default success result code is 0. Other result codes (>0) should
+    # be defined by each pre-check operator itself.
+ result: int = 0
+
+ # The simple description info for the result.
+ result_msg: str = ""
+
+ # Abnormal nodes' id.
+ abnormal_nodes: List[int] = field(default_factory=list)
+
+ def is_success(self):
+ return self.result == 0
+```
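+
+As an illustration, a hypothetical operator could report unreachable nodes
+through this result type. The `NodeReachableCheckOperator` name and its
+probing logic are made up for this example; only the `PreCheckOperator` and
+`PreCheckResult` interfaces come from the code above, and a real operator
+would usually return a relaunch or abort action instead of `NoAction`:
+
+```python
+from typing import Callable, List
+
+from dlrover.python.diagnosis.common.diagnosis_action import (
+    DiagnosisAction,
+    NoAction,
+)
+from dlrover.python.master.diagnosis.precheck_operator import (
+    PreCheckOperator,
+    PreCheckResult,
+)
+
+
+class NodeReachableCheckOperator(PreCheckOperator):
+    """Hypothetical operator that fails when any node cannot be reached."""
+
+    def __init__(self, probe_fn: Callable[[int], bool], node_ids: List[int]):
+        self._probe_fn = probe_fn  # injected probe, e.g. a ping/ssh wrapper
+        self._node_ids = node_ids
+
+    def check(self) -> PreCheckResult:
+        unreachable = [n for n in self._node_ids if not self._probe_fn(n)]
+        if not unreachable:
+            return PreCheckResult()  # result == 0 means success
+        return PreCheckResult(
+            result=1,
+            result_msg=f"{len(unreachable)} node(s) unreachable",
+            abnormal_nodes=unreachable,
+        )
+
+    def recover(self):
+        # Nothing to recover for a pure probe; a real operator might restart
+        # or reschedule the affected components here.
+        pass
+
+    def get_failed_action(self) -> DiagnosisAction:
+        return NoAction()
+```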
diff --git a/docs/figures/current_start_process.png b/docs/figures/current_start_process.png
new file mode 100644
index 000000000..c7e6253eb
Binary files /dev/null and b/docs/figures/current_start_process.png differ
diff --git a/docs/figures/new_start_process.png b/docs/figures/new_start_process.png
new file mode 100644
index 000000000..b41851ac8
Binary files /dev/null and b/docs/figures/new_start_process.png differ
diff --git a/docs/figures/pre_check_op.png b/docs/figures/pre_check_op.png
new file mode 100644
index 000000000..ccc7079fa
Binary files /dev/null and b/docs/figures/pre_check_op.png differ
diff --git a/go/elasticjob/pkg/controllers/master/master.go b/go/elasticjob/pkg/controllers/master/master.go
index 8f4e121dc..470da56db 100644
--- a/go/elasticjob/pkg/controllers/master/master.go
+++ b/go/elasticjob/pkg/controllers/master/master.go
@@ -49,6 +49,7 @@ const (
pendingTimeoutArg = "pending_timeout"
pendingFailStrategyArg = "pending_fail_strategy"
serviceType = "service_type"
+ preCheckArg = "pre_check"
)
// Manager generates a master pod object.
@@ -239,7 +240,7 @@ func (m *Manager) StopRunningPods(
}
func getMasterArguments() []string {
- return []string{pendingTimeoutArg, pendingFailStrategyArg, serviceType}
+ return []string{pendingTimeoutArg, pendingFailStrategyArg, serviceType, preCheckArg}
}
// NewMasterTemplateToJob sets configurations to the master template of a job.