Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Training hang detection based on XPU Timer metric. #1288

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
import re
import sys
from typing import Dict, List, Tuple

from dlrover.python.common.log import default_logger as logger
from dlrover.python.diagnosis.common.constants import DiagnosisDataType
from dlrover.python.diagnosis.common.diagnosis_data import DiagnosisData
from dlrover.python.diagnosis.common.inference_chain import (
Expand Down Expand Up @@ -46,7 +49,14 @@ def is_compatible(self, inference: Inference) -> bool:
return False

def infer(self, inferences: List[Inference]) -> List[Inference]:
if not self.data_manager:
if (
not self.data_manager
or not self.data_manager.with_runtime_context()
):
logger.info(
"Skip training-hang inference for there is "
"no diagnosis data reference."
)
return [
Inference(
name=InferenceName.TRAINING,
Expand All @@ -60,6 +70,7 @@ def infer(self, inferences: List[Inference]) -> List[Inference]:
)

if diagnosis_data and self.is_hang(diagnosis_data):
logger.warning("Training might hanged.")
return [
Inference(
name=InferenceName.TRAINING,
Expand All @@ -77,17 +88,100 @@ def infer(self, inferences: List[Inference]) -> List[Inference]:
]

def is_hang(self, diagnosis_data: List[DiagnosisData]):
    """Decide whether the training job looks hanged.

    Each diagnosis datum's content is scanned for XPU-timer hang-metric
    lines. A worker is tagged as hanged at a given timestamp when it has
    at least one hang-metric line and every local rank's metric value
    (the last number on the line) equals 1. The per-worker (timestamp,
    hanged) series are then intersected by ``_find_hang_intersection``.

    Args:
        diagnosis_data: Collected diagnosis data entries.

    Returns:
        True if every worker appears hanged from the latest data
        backwards, False otherwise.
    """
    logger.info(
        "Hang detection start using diagnosis data, "
        f"data number: {len(diagnosis_data)}, "
        f"data size: {sys.getsizeof(diagnosis_data)}."
    )
    worker_hang_metric: Dict[int, List[Tuple[int, bool]]] = {}
    if not diagnosis_data:
        return False

    for data in diagnosis_data:
        # keep only the hang-metric lines of this datum
        each_metric = [
            line
            for line in data.data_content.splitlines()
            if line.startswith(HANG_METRIC_PREFIX)
        ]

        # if all local ranks are hanged, tag the worker as hanged; a
        # datum with no hang-metric line at all is NOT treated as hanged
        # (previously 0 == len([]) wrongly tagged such workers)
        rank_hang_size = 0
        for each_rank_metric in each_metric:
            # the metric value is the last number on the line
            match = re.search(r"(\d+)(?!.*\d)", each_rank_metric)
            if match and match.group(0) == "1":
                rank_hang_size += 1
        is_worker_hang = (
            bool(each_metric) and rank_hang_size == len(each_metric)
        )

        worker_hang_metric.setdefault(data.node_rank, []).append(
            (data.timestamp, is_worker_hang)
        )

    # hang detection rules:
    # 1. 100% worker got hang metric
    # 2. last for 10+ minutes
    # TODO(review): rule 2 is not enforced yet -- hang_last is only
    # logged below, never compared against a 10-minute threshold.
    hang_id, hang_last = self._find_hang_intersection(worker_hang_metric)
    if hang_id != -1:
        logger.info(
            f"Got hang worker: {hang_id}, " f"time last: {hang_last}"
        )
        return True

    return False

def _find_hang_intersection(
self, worker_hang_metric: Dict[int, List[Tuple[int, bool]]]
) -> Tuple[int, int]:
"""
Require all workers hang from latest and find the hang intersection.

Args:
worker_hang_metric (Dict[int, List[Tuple[int, bool]]]): Input
BalaBalaYi marked this conversation as resolved.
Show resolved Hide resolved
metric.

Returns:
The hang intersection's id and time last in tuple format.
"""

worker_hang_length_min = 0
worker_hang_id = -1

# find the intersection from latest
for worker_id, tuple_list in worker_hang_metric.items():
# sorted by timestamp
tuple_list.sort(key=lambda x: x[0])
worker_hang_length = 0

for tuple_item in reversed(tuple_list):
if tuple_item[1]:
worker_hang_length += 1
else:
break

if worker_hang_length > 0:
if worker_hang_length_min == 0:
worker_hang_length_min = worker_hang_length
worker_hang_id = worker_id
elif worker_hang_length < worker_hang_length_min:
worker_hang_length_min = worker_hang_length
worker_hang_id = worker_id
else:
# there is normal worker
return -1, -1

# get the intersection's time last
if worker_hang_id != -1 and worker_hang_length_min != 0:
hang_worker_metric = worker_hang_metric[worker_hang_id]
time_last = (
hang_worker_metric[len(hang_worker_metric) - 1][0]
- hang_worker_metric[
len(hang_worker_metric) - worker_hang_length_min
][0]
)
return worker_hang_id, time_last

return -1, -1
54 changes: 39 additions & 15 deletions dlrover/python/master/diagnosis/diagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import threading
import time
from collections import deque
from datetime import datetime, timedelta
from typing import Dict, List

Expand Down Expand Up @@ -41,9 +43,11 @@ def has_expired(timestamp: float, time_period: int) -> bool:


class DiagnosisManager:
def __init__(self, job_manager=None, expire_time_period=600):
    """Create a DiagnosisManager.

    Args:
        job_manager: Optional job manager that supplies the job's
            runtime context to diagnosis; None disables context-based
            inference.
        expire_time_period: Seconds before stored diagnosis data
            expires (previously hard-coded to 600).
    """
    self._is_observing_started = False
    self._data_manager: DiagnosisDataManager = DiagnosisDataManager(
        job_manager, expire_time_period
    )
    self._diagnostician: Diagnostician = Diagnostician(self._data_manager)

def collect_diagnosis_data(self, data: DiagnosisData):
Expand Down Expand Up @@ -94,6 +98,9 @@ def _diagnose_failures(self):
if not self._is_observing_started:
logger.info("Stop to diagnose failures for observing.")
break
logger.info(
f"Diagnosis data size: {self._data_manager.get_data_size()}."
)

observed_problems = self._diagnostician.observe_training()
for problem in observed_problems:
Expand All @@ -107,36 +114,53 @@ def _diagnose_failures(self):


class DiagnosisDataManager:
    """Thread-safe, bounded store of diagnosis data grouped by data type."""

    def __init__(self, job_manager=None, expire_time_period=600):
        """
        Args:
            job_manager: Optional job manager giving access to the job's
                runtime context; None means no runtime context.
            expire_time_period: Seconds after which stored data expires.
        """
        # data_type -> deque of DiagnosisData, ordered oldest -> newest
        self._diagnosis_data: Dict[str, deque] = {}
        self.expire_time_period = expire_time_period
        self._job_manager = job_manager
        self._lock = threading.Lock()

    @property
    def job_manager(self):
        return self._job_manager

    @property
    def data(self):
        return self._diagnosis_data

    def with_runtime_context(self) -> bool:
        # runtime context is only reachable through the job manager
        return self.job_manager is not None

    def store_data(self, data: "DiagnosisData"):
        """Append one datum and drop expired entries of the same type."""
        data_type = data.data_type
        with self._lock:
            if data_type not in self.data:
                # bound the queue so a producer burst cannot grow unbounded
                self.data[data_type] = deque(maxlen=100000)
            self.data[data_type].append(data)
            self._clean_diagnosis_data(data_type)

    def get_data(self, data_type: str) -> List["DiagnosisData"]:
        """Return a snapshot list of the stored data for ``data_type``."""
        with self._lock:
            if data_type not in self.data:
                return []
            return list(self.data[data_type])

    def get_data_size(self):
        # NOTE: sys.getsizeof is shallow -- it measures only the dict
        # object itself, not the deques or the data they contain.
        return sys.getsizeof(self.data)

    def _clean_diagnosis_data(self, data_type: str):
        """Drop expired entries from the front; caller holds the lock."""
        data = self.data.get(data_type)
        if not data:
            return
        # ``data`` is a deque, which does not support slicing: the old
        # ``data[n:]`` cleanup raised TypeError. Pop expired items from
        # the left instead (entries are ordered oldest first).
        while data and has_expired(data[0].timestamp, self.expire_time_period):
            data.popleft()


class Diagnostician:
Expand Down
2 changes: 1 addition & 1 deletion dlrover/python/master/dist_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def __init__(
error_monitor
),
}
self.diagnosis_manager = DiagnosisManager()
self.diagnosis_manager = DiagnosisManager(self.job_manager)
self.job_metric_collector = self._create_metric_collector_if_needed(
args
)
Expand Down
12 changes: 12 additions & 0 deletions dlrover/python/master/node/dist_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,18 @@ def collect_node_heart_beat(self, node_type, node_id, timestamp):
def update_node_required_info_callback(self):
    """Push the currently required node info down to the worker manager."""
    required = self._nodes_required
    self._worker_manager.update_node_required_info(required)

def get_node_required_info(self):
    """Return the required node info tuple (min-nodes, max-nodes, timeout)."""
    required_info = self._nodes_required
    return required_info

def get_total_node_num_by_type(self, node_type):
    """Return how many nodes of ``node_type`` the job currently tracks.

    Returns 0 when no nodes are tracked at all or when ``node_type`` is
    unknown (the previous implementation raised KeyError for an unknown
    type while returning 0 for the empty case).
    """
    if not self._job_nodes:
        return 0

    return len(self._job_nodes.get(node_type, {}))

def get_job_strategy(self):
    """Expose the distribution strategy configured in the job arguments."""
    job_args = self._job_args
    return job_args.distribution_strategy


def create_job_manager(args: JobArgs, speed_monitor) -> DistributedJobManager:
critical_worker_index = get_critical_worker_index(args)
Expand Down
2 changes: 1 addition & 1 deletion dlrover/python/master/node/job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
self._error_monitor: ErrorMonitor = error_monitor

self._job_nodes: Dict[str, Dict[int, Node]] = {}
self._nodes_required = (0, 0, 0)
self._nodes_required = (0, 0, 0) # (min-nodes, max-nodes, timeout)

self._training_node_config = TrainingNodeConfig(external_config)

Expand Down
Loading
Loading