Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New precheck procedure to enhance stability. #1453

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ coverage:
flag_coverage_not_uploaded_behavior: include
patch:
default:
target: 80%
target: 85%
threshold: 3%
removed_code_behavior: fully_covered_patch
project:
Expand Down
6 changes: 6 additions & 0 deletions dlrover/python/common/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ class HeartBeat(Message):
timestamp: int = 0


@dataclass
class PreCheckRequest(Message):
    """Request message used to query the training pre-check result."""

    # Time of the request; defaults to 0 when the sender does not set it.
    timestamp: int = 0
    # Kind of pre-check being queried; "INITIAL" is the only value used
    # by default.
    type: str = "INITIAL"


@dataclass
class DatasetShardParams(Message):
batch_size: int = 0
Expand Down
4 changes: 4 additions & 0 deletions dlrover/python/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,12 @@ class JobConstant(object):
# sleep 5s before next node check round
NODE_CHECK_NEXT_ROUND_TIMEOUT = 5

# default interval seconds for loop in training agent
TRAINING_AGENT_LOOP_DEFAULT_INTERVAL = 15

# interval seconds for pre-check waiting
PRE_CHECK_WAIT_SECS = 5


class Accelerators(object):
NVIDIA_GPU = "nvidia.com/gpu"
Expand Down
2 changes: 2 additions & 0 deletions dlrover/python/common/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class DefaultValues(object):
GPU_NUM_PER_NODE = 8
NPU_NUM_PER_NODE = 16
MAX_METRIC_REC = 30
PRE_CHECK_ENABLE = True


class Context(Singleton):
Expand Down Expand Up @@ -107,6 +108,7 @@ def __init__(self):
# The strategy of 'hang detection':
# 0: log only; 1: notify; 2: with fault tolerance
self.hang_detection = DefaultValues.HANG_DETECTION
self.pre_check_enable = DefaultValues.PRE_CHECK_ENABLE
self.gpu_per_node = DefaultValues.GPU_NUM_PER_NODE
self.npu_per_node = DefaultValues.NPU_NUM_PER_NODE
self.max_metric_records = DefaultValues.MAX_METRIC_REC
Expand Down
4 changes: 4 additions & 0 deletions dlrover/python/diagnosis/common/diagnosis_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ def clear(self):
with self._lock:
self._actions.clear()

def len(self):
    """Return the number of queued actions summed over every entry.

    The count is taken while holding the queue lock so that it is
    consistent with concurrent modifications.
    """
    with self._lock:
        total = 0
        for queued in self._actions.values():
            total += len(queued)
        return total

def next_action(
self,
instance=DiagnosisConstant.LOCAL_INSTANCE,
Expand Down
5 changes: 5 additions & 0 deletions dlrover/python/elastic_agent/master_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,11 @@
response: comm.ElasticRunConfig = self._get(request)
return response.configs

def get_pre_check_result(self) -> bool:
    """Ask the master for the training pre-check result.

    A default ``comm.PreCheckRequest`` is sent through ``self._get`` and
    the reply is returned unchanged.
    """
    # NOTE(review): the bool annotation assumes ``self._get`` hands back a
    # plain boolean for this request type -- confirm against the master.
    return self._get(comm.PreCheckRequest())

Check warning on line 456 in dlrover/python/elastic_agent/master_client.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/elastic_agent/master_client.py#L456

Added line #L456 was not covered by tests

def report_event(
self,
event_type: str = "",
Expand Down
18 changes: 18 additions & 0 deletions dlrover/python/master/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@
from dlrover.python.common.log import default_logger as logger


def str2bool(value):
    """Convert a command-line value to a bool (for argparse ``type=``).

    Args:
        value: Either a bool, which is returned unchanged, or a value whose
            string form is one of the accepted spellings. Matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        True for "true", "yes", "t", "y", "1"; False for "false", "no",
        "f", "n", "0".

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value

    # Go through str() so that non-string input (None, numbers) ends up as
    # an argparse error instead of an AttributeError on ``.lower()``.
    normalized = str(value).strip().lower()
    if normalized in {"true", "yes", "t", "y", "1"}:
        return True
    # "f" is accepted to mirror "t" in the truthy spellings.
    if normalized in {"false", "no", "f", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

Check warning on line 28 in dlrover/python/master/args.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/master/args.py#L28

Added line #L28 was not covered by tests


def add_params(parser):
parser.add_argument("--job_name", help="ElasticJob name", required=True)
parser.add_argument(
Expand Down Expand Up @@ -54,6 +65,13 @@
type=str,
help="The service type of master: grpc/http.",
)
parser.add_argument(
"--pre_check",
"--pre_check",
default=DefaultValues.PRE_CHECK_ENABLE,
type=str2bool,
help="Enable pre training check or not.",
)


def print_args(args, exclude_args=[], groups=None):
Expand Down
3 changes: 3 additions & 0 deletions dlrover/python/master/diagnosis/diagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
def get_resolving_operators(self) -> List[InferenceOperator]:
return self._resolvers

def register_precheck_(self, problems: List[Inference]):
    """Store ``problems`` as the diagnostician's problem list.

    Args:
        problems: The inferences to register.
    """
    # NOTE(review): this writes the same ``_training_problems`` attribute
    # as ``register_training_problems``, so calling it replaces any
    # registered training problems -- confirm this is intended.
    self._training_problems = problems

Check warning on line 48 in dlrover/python/master/diagnosis/diagnosis.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/master/diagnosis/diagnosis.py#L48

Added line #L48 was not covered by tests

def register_training_problems(self, problems: List[Inference]):
self._training_problems = problems

Expand Down
77 changes: 74 additions & 3 deletions dlrover/python/master/diagnosis/diagnosis_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import threading
import time

from dlrover.python.common.global_context import Context
from dlrover.python.common.log import default_logger as logger
from dlrover.python.diagnosis.common.constants import DiagnosisConstant
from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData
Expand All @@ -33,8 +34,13 @@
from dlrover.python.master.diagnosis.diagnosis_data_manager import (
DiagnosisDataManager,
)
from dlrover.python.master.diagnosis.precheck_operator import (
NoPreCheckOperator,
)
from dlrover.python.master.node.job_context import get_job_context

_dlrover_ctx = Context.singleton_instance()


class DiagnosisManager:
"""
Expand All @@ -47,15 +53,80 @@
self._diagnostician: Diagnostician = Diagnostician(self._data_manager)
self._job_context = get_job_context()

@classmethod
def get_pre_check_operators(cls):
    """Return the pre-check operators to run, in execution order.

    Currently this is a single ``NoPreCheckOperator`` instance; a new
    list is built on every call.
    """
    operators = [NoPreCheckOperator()]
    return operators

def collect_diagnosis_data(self, data: DiagnosisData):
self._data_manager.store_data(data)

def pre_check(self):
    """Run the pre-check operators before training starts.

    Nothing happens when pre-check is disabled in the global context.
    Otherwise every operator from ``get_pre_check_operators`` is run in
    order. Each operator gets up to ``get_retry_limit_times()`` rounds;
    after a failed round its ``recover`` is invoked, the manager sleeps
    for ``get_retry_interval_secs()`` and the check is repeated.

    If an operator still fails after its retries, its failure action is
    enqueued on the job context and the method returns without marking
    the pre-check as passed. An operator that raises is logged and
    skipped (best effort). The job context is marked as passed only after
    all operators have been processed, so an early success cannot hide a
    later failure.
    """
    if not _dlrover_ctx.pre_check_enable:
        return

    start = time.time()
    pre_check_ops = self.get_pre_check_operators()
    logger.info(
        f"Start to training pre-check with operators: {pre_check_ops}."
    )

    for pre_check_op in pre_check_ops:
        current_start = time.time()
        current_op_result = None
        pre_check_op_name = pre_check_op.__class__.__name__

        try:
            # retry loops for each operator
            for i in range(pre_check_op.get_retry_limit_times()):
                check_start = time.time()

                # do check
                current_op_result = pre_check_op.check()
                # time.time() deltas are seconds, so report them as such
                logger.info(
                    f"{pre_check_op_name} "
                    f"check({i}) "
                    f"cost: {time.time() - check_start:.2f}s, "
                    f"result: {current_op_result}"
                )

                if current_op_result.is_success():
                    break

                # try recover and wait
                pre_check_op.recover()
                time.sleep(pre_check_op.get_retry_interval_secs())

                # check again after recover
                current_op_result = pre_check_op.check()
        except Exception as e:
            logger.error(f"Pre-check operator got unexpected error: {e}")
            continue

        if current_op_result is None:
            # A retry limit below 1 means no check ran at all; there is
            # no result to judge, so skip instead of failing on None.
            logger.warning(
                f"{pre_check_op_name} ran no check "
                "because its retry limit is less than 1, skip it."
            )
            continue

        if not current_op_result.is_success():
            action = pre_check_op.get_failed_action()
            self._job_context.enqueue_action(action)
            logger.warning(
                "Training pre-check failed "
                f"by {pre_check_op_name} "
                f"with result: {current_op_result}, "
                f"cost:{time.time() - current_start:.2f}s. "
                f"Invoke action: {action}."
            )
            return

        logger.info(
            f"{pre_check_op_name} finish "
            f"with result: {current_op_result}, "
            f"cost:{time.time() - current_start:.2f}s."
        )

    # Mark the pre-check as passed only once no operator has failed;
    # doing it per operator would expose a premature "pass" to agents
    # polling the result while later operators are still running.
    self._job_context.set_pre_check_pass()
    logger.info(
        f"Training pre-check complete, cost:{time.time() - start:.2f}s."
    )

def start_observing(self):
logger.info("Start Diagnosis Manager to observing training...")
logger.info("Start to observing training...")
self._is_observing_started = True

self._diagnostician.register_training_problems(
Expand Down
82 changes: 82 additions & 0 deletions dlrover/python/master/diagnosis/precheck_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2025 The DLRover Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List

from dlrover.python.diagnosis.common.diagnosis_action import (
DiagnosisAction,
NoAction,
)


@dataclass
class PreCheckResult:
    """Outcome of one pre-check operator run.

    Attributes are documented inline; ``is_success`` is the only
    behaviour attached to the record.
    """

    # Result code. 0 (the default) means success; any other code is
    # defined by the individual pre-check operator itself.
    result: int = 0

    # Short human-readable description of the result.
    result_msg: str = ""

    # Ids of the nodes found to be abnormal.
    abnormal_nodes: List[int] = field(default_factory=list)

    def is_success(self) -> bool:
        """Return True when the result code is the success code 0."""
        succeeded = self.result == 0
        return succeeded


class PreCheckOperator(ABC):
    """Interface of a single check that runs before training starts.

    Subclasses must implement ``check``, ``recover`` and
    ``get_failed_action``. The retry interval and the retry limit have
    defaults that subclasses may override.
    """

    @classmethod
    def get_retry_interval_secs(cls) -> int:
        """Seconds to wait between retries (default 5); overridable."""
        return 5

    @classmethod
    def get_retry_limit_times(cls) -> int:
        """
        Maximum number of check rounds (default 3); overridable.

        For most pre-checks the value should be greater than 1 so that at
        least one retry happens. The failed action is executed if the
        result is still not ok after all retries.
        """
        return 3

    @abstractmethod
    def check(self) -> PreCheckResult:
        """Run the main check procedure and return its result."""
        pass

    @abstractmethod
    def recover(self):
        """Run the recovery procedure used when a check has failed."""
        pass

    @abstractmethod
    def get_failed_action(self) -> DiagnosisAction:
        """Return the action to take when the operator's check failed."""
        pass

Check warning on line 71 in dlrover/python/master/diagnosis/precheck_operator.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/master/diagnosis/precheck_operator.py#L71

Added line #L71 was not covered by tests


class NoPreCheckOperator(PreCheckOperator):
    """Operator that checks nothing and always reports success."""

    def check(self):
        """Return a default ``PreCheckResult`` (result code 0)."""
        return PreCheckResult()

    def recover(self):
        """Nothing to recover; returns None."""
        return None

    def get_failed_action(self) -> DiagnosisAction:
        """Return a ``NoAction`` as the failure action."""
        return NoAction()
1 change: 0 additions & 1 deletion dlrover/python/master/dist_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ def prepare(self):
def pre_check(self):
logger.info("Pre-check before running.")
self.diagnosis_manager.pre_check()
# TODO

def _add_node_event_callback(self):
"""Add NodeEventCallbacks for the listeners of Pod events."""
Expand Down
12 changes: 12 additions & 0 deletions dlrover/python/master/node/job_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Dict, Optional, Union

from dlrover.python.common.constants import NodeType
from dlrover.python.common.global_context import Context
from dlrover.python.common.node import Node
from dlrover.python.common.singleton import Singleton
from dlrover.python.diagnosis.common.constants import (
Expand All @@ -27,6 +28,8 @@
DiagnosisActionQueue,
)

_dlrover_context = Context.singleton_instance()


class JobContext(Singleton):
"""
Expand All @@ -38,6 +41,7 @@
self._action_queue = DiagnosisActionQueue()
self._job_nodes: Dict[str, Dict[int, Node]] = {}
self._failed_nodes: Dict[int, int] = {}
self._pre_check_pass = False
self._locker = threading.Lock()

def enqueue_action(self, action):
Expand Down Expand Up @@ -193,6 +197,14 @@
def get_failed_node_cnt(self):
return len(self._failed_nodes)

def set_pre_check_pass(self):
    """Record that the training pre-check has passed."""
    passed = True
    self._pre_check_pass = passed

def is_pre_check_pass(self):
    """Tell whether the training pre-check is considered passed.

    Returns True unconditionally when pre-check is disabled in the
    global context; otherwise returns the flag recorded by
    ``set_pre_check_pass``.
    """
    if not _dlrover_context.pre_check_enable:
        return True
    return self._pre_check_pass

Check warning on line 206 in dlrover/python/master/node/job_context.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/master/node/job_context.py#L206

Added line #L206 was not covered by tests


def get_job_context() -> JobContext:
job_context = JobContext.singleton_instance()
Expand Down
8 changes: 8 additions & 0 deletions dlrover/python/master/servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
RendezvousManager,
)
from dlrover.python.master.monitor.speed_monitor import SpeedMonitor
from dlrover.python.master.node.job_context import get_job_context
from dlrover.python.master.node.job_manager import JobManager
from dlrover.python.master.node.training_node import SyncNodeTrainingPorts
from dlrover.python.master.shard.dataset_splitter import new_dataset_splitter
Expand Down Expand Up @@ -161,6 +162,10 @@
elif isinstance(req_message, comm.ElasticRunConfigRequest):
configs = self._job_manager.get_elastic_run_configs()
message = comm.ElasticRunConfig(configs=configs)
elif isinstance(req_message, comm.PreCheckRequest):
message = self._get_pre_check_result(

Check warning on line 166 in dlrover/python/master/servicer.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/master/servicer.py#L166

Added line #L166 was not covered by tests
node_type, node_id, req_message
)
elif isinstance(req_message, comm.HeartBeat):
message = self._report_heartbeat(node_type, node_id, req_message)

Expand Down Expand Up @@ -664,6 +669,9 @@
port=sync_ports.training_port, newport=sync_ports.next_check_port
)

def _get_pre_check_result(
    self, node_type, node_id, message: comm.PreCheckRequest
):
    """Return whether the job's training pre-check has passed.

    The request dispatcher calls this handler positionally with
    ``(node_type, node_id, request)``, the same argument shape it uses
    for the heartbeat handler, so all three parameters must be accepted;
    a two-parameter signature fails with ``TypeError`` on every
    ``PreCheckRequest``.

    Args:
        node_type: Type of the requesting node (currently unused).
        node_id: Id of the requesting node (currently unused).
        message: The incoming pre-check request (currently unused).

    Returns:
        The value of ``is_pre_check_pass()`` on the job context.
    """
    # NOTE(review): the dispatcher receives a plain bool here rather than
    # a comm message object -- confirm it can be sent back to the agent.
    return get_job_context().is_pre_check_pass()

Check warning on line 673 in dlrover/python/master/servicer.py

View check run for this annotation

Codecov / codecov/patch

dlrover/python/master/servicer.py#L673

Added line #L673 was not covered by tests

def _report_event(self, message: comm.Event):
if self._error_monitor:
self._error_monitor.report_event(
Expand Down
5 changes: 5 additions & 0 deletions dlrover/python/master/watcher/k8s_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def _convert_pod_event_to_node_event(event):
if metadata.deletion_timestamp:
status = NodeStatus.DELETED

logger.debug(
f"Got monitor event for pod: {pod_name}, "
f"node: {host_name}, ip: {host_ip}, status: {status}."
)

restart = _verify_restarting_training(evt_obj)
if restart:
logger.info(f"{evt_obj.metadata.name} need to restart.")
Expand Down
Loading
Loading