[WIP] Training hang detection based on XPU Timer metric. #1288

Open · wants to merge 9 commits into base: master
4 changes: 4 additions & 0 deletions dlrover/python/common/global_context.py
@@ -53,6 +53,7 @@ class DefaultValues(object):
SEC_TO_CHANGE_PS = 3600 # 1h
SEC_TO_WAIT_FAILED_PS = 600 # 10min
HANG_CPU_USAGE_RATE = 0.05
HANG_DETECTION = 1


class Context(Singleton):
@@ -95,6 +96,9 @@ def __init__(self):
self.is_tfv1_ps = False
self.master_port = None
self.relaunch_always = False
# The strategy of 'hang detection':
# 0: log only; 1: notify; 2: with fault tolerance
self.hang_detection = DefaultValues.HANG_DETECTION

def set_params_from_brain(self):
self.train_speed_record_num = self.get_param_value_from_brain(
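The new hang_detection value selects how the master reacts once a hang is confirmed: 0 logs only, 1 notifies, 2 adds fault tolerance. A minimal, self-contained sketch of dispatching on such a three-level strategy is below; notify_hang_event and relaunch_hanged_workers are hypothetical stubs, since this diff still leaves both branches as placeholders.

import logging

logger = logging.getLogger(__name__)

def notify_hang_event(worker_id: int, duration: int):
    # Hypothetical stub: a real implementation would emit a diagnosis event.
    logger.info(f"notify: worker {worker_id} has hanged for {duration}s")

def relaunch_hanged_workers(worker_id: int):
    # Hypothetical stub: a real implementation would trigger a relaunch action.
    logger.info(f"relaunch requested for worker {worker_id}")

def handle_confirmed_hang(hang_detection: int, worker_id: int, duration: int):
    # hang_detection follows the new Context.hang_detection semantics:
    # 0: log only; 1: notify; 2: with fault tolerance
    logger.warning(f"Worker {worker_id} hang confirmed, lasted {duration}s.")
    if hang_detection == 1:
        notify_hang_event(worker_id, duration)
    elif hang_detection == 2:
        relaunch_hanged_workers(worker_id)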
2 changes: 1 addition & 1 deletion dlrover/python/diagnosis/common/constants.py
@@ -24,7 +24,7 @@ class InferenceConfigKey(object):
class DiagnosisConstant(object):
MASTER_DIAGNOSIS_OBSERVING_INTERVAL_SECS = 180
AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS = 60
MASTER = -1
MASTER_INSTANCE = -1
ANY_INSTANCE = -2
LOCAL_INSTANCE = -3
ACTION_EXPIRED_TIME_PERIOD_DEFAULT = 60 * 5
2 changes: 1 addition & 1 deletion dlrover/python/diagnosis/common/diagnosis_action.py
@@ -43,7 +43,7 @@ class DiagnosisAction(metaclass=ABCMeta):
def __init__(
self,
action_type=DiagnosisActionType.NONE,
instance=DiagnosisConstant.MASTER,
instance=DiagnosisConstant.MASTER_INSTANCE,
timestamp=0,
expired_time_period=60 * 1000,
):
dlrover/python/diagnosis/inferencechain/inferenceoperator/check_training_hang_operator.py
@@ -11,8 +11,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
import re
import sys
from typing import Dict, List, Tuple

from dlrover.python.common.global_context import Context
from dlrover.python.common.log import default_logger as logger
from dlrover.python.diagnosis.common.constants import DiagnosisDataType
from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData
from dlrover.python.diagnosis.common.inference_chain import (
@@ -24,6 +28,7 @@
)

HANG_METRIC_PREFIX = "XPU_TIMER_COMMON_HANG"
_dlrover_ctx = Context.singleton_instance()


class CheckTrainingHangOperator(InferenceOperator):
@@ -46,7 +51,14 @@ def is_compatible(self, inference: Inference) -> bool:
return False

def infer(self, inferences: List[Inference]) -> List[Inference]:
if not self.data_manager:
if (
not self.data_manager
or not self.data_manager.with_runtime_context()
):
logger.info(
"Skip training-hang inference for there is "
"no diagnosis data reference."
)
return [
Inference(
name=InferenceName.TRAINING,
@@ -60,6 +72,7 @@ def infer(self, inferences: List[Inference]) -> List[Inference]:
)

if diagnosis_data and self.is_hang(diagnosis_data):
logger.warning("Training might hanged.")
return [
Inference(
name=InferenceName.TRAINING,
@@ -77,17 +90,112 @@ def infer(self, inferences: List[Inference]) -> List[Inference]:
]

def is_hang(self, diagnosis_data: List[DiagnosisData]):
hang_metric = []
logger.info(
"Hang detection start using diagnosis data, "
f"data number: {len(diagnosis_data)}, "
f"data size: {sys.getsizeof(diagnosis_data)}."
)
worker_hang_metric: Dict[int, List[Tuple[int, bool]]] = {}
if not diagnosis_data:
return False

for data in diagnosis_data:
# filter hang metric
each_metric = [
line
for line in data.data_content.splitlines()
if line.startswith(HANG_METRIC_PREFIX)
]
hang_metric.append(each_metric)

# TODO: implement the judgement
# if all local ranks are hanging, tag the worker as hanging
rank_hang_size = 0
is_worker_hang = False
for each_rank_metric in each_metric:
match = re.search(r"(\d+)(?!.*\d)", each_rank_metric)
if match and match.group(0) == "1":
rank_hang_size += 1
if rank_hang_size == len(each_metric):
is_worker_hang = True

if data.node_rank not in worker_hang_metric:
worker_hang_metric[data.node_rank] = []
worker_hang_metric[data.node_rank].append(
(data.timestamp, is_worker_hang)
)

# hang detection rules:
# 1. 100% of workers report the hang metric
# 2. the hang lasts for 5+ minutes
hang_id, hang_last = self._find_hang_intersection(worker_hang_metric)
hang_last_threshold = self._get_hang_time_last_threshold()
if hang_id != -1 and hang_last > hang_last_threshold:
logger.info(
f"Got hang worker: {hang_id}, time last: {hang_last}, "
f"threshold: {hang_last_threshold}"
)
if _dlrover_ctx.hang_detection == 1:
# trigger event action
pass
elif _dlrover_ctx.hang_detection == 2:
# trigger relaunch action
pass
return True

return False

def _get_hang_time_last_threshold(self):
# set to 5 minutes for now (in seconds)
return 5 * 60

def _find_hang_intersection(
self, worker_hang_metric: Dict[int, List[Tuple[int, bool]]]
) -> Tuple[int, int]:
"""
Require all workers to be hanging at the latest samples and find the hang intersection.

Args:
worker_hang_metric (Dict[int, List[Tuple[int, bool]]]): Input
metric in format: node_id: [(timestamp, is_hang), ...]

Returns:
The hanging worker's id and the hang duration, as a tuple.
"""

worker_hang_length_min = 0
worker_hang_id = -1

# find the intersection from latest
for worker_id, tuple_list in worker_hang_metric.items():
# sort by timestamp
tuple_list.sort(key=lambda x: x[0])
worker_hang_length = 0

for tuple_item in reversed(tuple_list):
if tuple_item[1]:
worker_hang_length += 1
else:
break

if worker_hang_length > 0:
if worker_hang_length_min == 0:
worker_hang_length_min = worker_hang_length
worker_hang_id = worker_id
elif worker_hang_length < worker_hang_length_min:
worker_hang_length_min = worker_hang_length
worker_hang_id = worker_id
else:
# there is a normal (non-hanging) worker
return -1, -1

# get the intersection's time last
if worker_hang_id != -1 and worker_hang_length_min != 0:
hang_worker_metric = worker_hang_metric[worker_hang_id]
time_last = (
hang_worker_metric[len(hang_worker_metric) - 1][0]
- hang_worker_metric[
len(hang_worker_metric) - worker_hang_length_min
][0]
)
return worker_hang_id, time_last

return -1, -1
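The operator's two rules, restated: every worker must report the hang metric, and the shared hang window must last longer than the threshold. The standalone walk-through below applies both steps to made-up data; the metric line format and the sample timestamps are assumptions for illustration, not the exact xpu-timer output.

import re

HANG_METRIC_PREFIX = "XPU_TIMER_COMMON_HANG"

# Hypothetical payload from one node: one metric line per local rank, ending
# with 1 (hang) or 0 (normal).
data_content = "\n".join([
    f"{HANG_METRIC_PREFIX}{{rank=0}} 1",
    f"{HANG_METRIC_PREFIX}{{rank=1}} 1",
])

hang_lines = [
    line for line in data_content.splitlines()
    if line.startswith(HANG_METRIC_PREFIX)
]
# r"(\d+)(?!.*\d)" grabs the last number on the line, i.e. the hang flag.
flags = [re.search(r"(\d+)(?!.*\d)", line).group(0) for line in hang_lines]
node_is_hang = flags.count("1") == len(flags)
assert node_is_hang  # every local rank on this node reports a hang

# Per-node series of (timestamp, is_node_hang) samples, as built by is_hang().
worker_hang_metric = {
    0: [(100, False), (160, True), (220, True), (280, True)],
    1: [(100, True), (160, True), (220, True), (280, True)],
}
# _find_hang_intersection() walks each series backwards from the latest sample
# and keeps the shortest all-hang suffix common to every node: 3 samples here,
# limited by node 0. Its duration is 280 - 160 = 120 seconds, which is below
# the 5-minute threshold, so is_hang() would still return False for this data.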
7 changes: 7 additions & 0 deletions dlrover/python/master/args.py
@@ -86,6 +86,13 @@ def _build_master_args_parser():
type=pos_int,
help="The number of nodes",
)
parser.add_argument(
"--hang_detection",
default=1,
type=pos_int,
help="The strategy of 'hang detection', "
"0: log only; 1: notify; 2: with fault tolerance",
)
add_params(parser)
return parser

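A standalone sketch of how the new flag parses, using plain argparse with int in place of DLRover's pos_int validator:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hang_detection",
    default=1,
    type=int,  # the real parser uses the pos_int validator
    help="The strategy of 'hang detection', "
    "0: log only; 1: notify; 2: with fault tolerance",
)
args = parser.parse_args(["--hang_detection", "2"])
assert args.hang_detection == 2
# The master would carry this value onto Context.hang_detection (wiring not shown in this diff).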
56 changes: 37 additions & 19 deletions dlrover/python/master/diagnosis/diagnosis.py
@@ -11,9 +11,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import threading
import time
from collections import deque
from datetime import datetime, timedelta
from itertools import islice
from typing import Dict, List

from dlrover.python.common.log import default_logger as logger
@@ -32,6 +35,7 @@
from dlrover.python.diagnosis.inferencechain.inferenceoperator.check_training_hang_operator import ( # noqa: E501
CheckTrainingHangOperator,
)
from dlrover.python.master.node.job_context import get_job_context


def has_expired(timestamp: float, time_period: int) -> bool:
Expand All @@ -43,7 +47,7 @@ def has_expired(timestamp: float, time_period: int) -> bool:
class DiagnosisManager:
def __init__(self):
self._is_observing_started = False
self._data_manager: DiagnosisDataManager = DiagnosisDataManager(600)
self._data_manager: DiagnosisDataManager = DiagnosisDataManager()
self._diagnostician: Diagnostician = Diagnostician(self._data_manager)

def collect_diagnosis_data(self, data: DiagnosisData):
@@ -72,8 +76,8 @@ def start_observing(self):

try:
thread = threading.Thread(
target=self._diagnose_failures(),
name="diagnose_failures",
target=self._diagnose_failures,
name="failure_diagnosis",
daemon=True,
)
thread.start()
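Note the fix to the thread target above: the old code passed self._diagnose_failures() (the call's result), which ran the loop synchronously and handed None to the thread; the new code passes the callable itself. A minimal illustration of the difference:

import threading

def loop():
    print("running in background")

# Wrong: loop() executes immediately on the caller's thread, and Thread
# receives its return value (None) as the target.
# threading.Thread(target=loop(), daemon=True).start()

# Right: hand the callable to the thread so it runs in the background.
threading.Thread(target=loop, name="failure_diagnosis", daemon=True).start()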
@@ -94,10 +98,14 @@ def _diagnose_failures(self):
if not self._is_observing_started:
logger.info("Stop to diagnose failures for observing.")
break
logger.info(
"Current diagnosis "
f"data size: {self._data_manager.get_data_size()}."
)

observed_problems = self._diagnostician.observe_training()
for problem in observed_problems:
logger.info(f"observed problems: {problem}")
logger.info(f"Observe problem in diagnosing: {problem}")
root_causes = self._diagnostician.diagnose_failure(problem)
for root_cause in root_causes:
logger.info(f"identify root cause: {root_cause}")
@@ -107,36 +115,46 @@


class DiagnosisDataManager:
def __init__(self, expire_time_period):
self.diagnosis_data: Dict[str, List[DiagnosisData]] = {}
def __init__(self, expire_time_period=600):
self._diagnosis_data: Dict[str, deque[DiagnosisData]] = {}
self.expire_time_period = expire_time_period
self._job_context = get_job_context()
self._lock = threading.Lock()

@property
def data(self):
return self._diagnosis_data

def store_data(self, data: DiagnosisData):
data_type = data.data_type
if data_type not in self.diagnosis_data:
logger.debug(f"{data_type} is not found in the store")
self.diagnosis_data[data_type] = []
self.diagnosis_data[data_type].append(data)
self._clean_diagnosis_data(data_type)
with self._lock:
if data_type not in self.data:
self.data[data_type] = deque(maxlen=100000)
self.data[data_type].append(data)
self._clean_diagnosis_data(data_type)

def get_data(self, data_type: str) -> List[DiagnosisData]:
if data_type not in self.diagnosis_data:
return []
return self.diagnosis_data[data_type]
with self._lock:
if data_type not in self.data:
return []
return list(self.data[data_type])

def get_data_size(self):
return sys.getsizeof(self.data)

def _clean_diagnosis_data(self, data_type: str):
if data_type not in self.diagnosis_data:
if data_type not in self.data:
return

data = self.diagnosis_data[data_type]
each_data = self.data[data_type]
n = 0
for d in data:
for d in each_data:
if has_expired(d.timestamp, self.expire_time_period):
n = n + 1
else:
break

self.diagnosis_data[data_type] = data[n:]
if n > 0:
self.data[data_type] = deque(islice(each_data, n, len(each_data)))


class Diagnostician:
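DiagnosisDataManager now keeps each data type in a bounded, lock-guarded deque and trims expired entries from the front on every store. A self-contained sketch of that expire-from-the-left behavior, using plain timestamps instead of DLRover's DiagnosisData:

import time
from collections import deque
from itertools import islice

expire_time_period = 600  # seconds, same default as DiagnosisDataManager
records = deque(maxlen=100000)
records.append((time.time() - 700, "old metric"))     # already expired
records.append((time.time() - 10, "recent metric"))   # still fresh

# Count expired entries at the front, then rebuild the deque without them,
# mirroring _clean_diagnosis_data().
n = 0
for ts, _ in records:
    if time.time() - ts > expire_time_period:
        n += 1
    else:
        break
if n > 0:
    records = deque(islice(records, n, len(records)), maxlen=100000)

assert [v for _, v in records] == ["recent metric"]

One detail worth noting: rebuilding with deque(islice(...)) only keeps the 100000-entry bound if maxlen is passed again; the sketch does so explicitly.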
21 changes: 19 additions & 2 deletions dlrover/python/master/node/dist_job_manager.py
@@ -36,6 +36,10 @@
from dlrover.python.common.grpc import ParallelConfig
from dlrover.python.common.log import default_logger as logger
from dlrover.python.common.node import Node, NodeGroupResource
from dlrover.python.diagnosis.common.constants import (
DiagnosisActionType,
DiagnosisConstant,
)
from dlrover.python.diagnosis.common.diagnosis_action import (
DiagnosisAction,
NoAction,
@@ -476,7 +480,11 @@ def _diagnose_job(self):
logger.warning(e)
detail_trace_back = traceback.format_exc()
logger.warning(detail_trace_back)
self._process_diagnosis_action(self._job_context.next_action())
self._process_diagnosis_action(
self._job_context.next_action(
instance=DiagnosisConstant.MASTER_INSTANCE
)
)
time.sleep(15)

def _get_dead_node_event(self, window_interval=900) -> List[NodeEvent]:
@@ -651,7 +659,10 @@ def _get_pod_unique_labels(self, node: Node):
}

def _process_diagnosis_action(self, action: DiagnosisAction):
pass
if not action or action.action_type == DiagnosisActionType.NONE:
return

# TODO

def _process_event(self, event: NodeEvent):
node_type = event.node.type
@@ -1236,6 +1247,12 @@ def process_reported_node_event(self, node_event: NodeEvent):

self._job_context.update_job_node(target_node)

def get_node_required_info(self):
return self._nodes_required

def get_job_strategy(self):
return self._job_args.distribution_strategy


def create_job_manager(args: JobArgs, speed_monitor) -> DistributedJobManager:
critical_worker_index = get_critical_worker_index(args)