Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Refactor diagnosis manager #1302

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions dlrover/python/common/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,14 @@ class ElasticRunConfig(Message):
@dataclass
class SucceededRequest(Message):
    """Field-less message; it carries no payload of its own."""


@dataclass
class DiagnosisAction(Message):
    """Message describing one diagnosis action as two string fields."""

    # Both fields default to the empty string.
    # NOTE(review): presumably the action's class name and its serialized
    # form -- confirm against the sender/receiver code.
    action_cls: str = ""
    action_content: str = ""


@dataclass
class HeartbeatResponse(Message):
    """Message holding a list of ``DiagnosisAction`` entries."""

    # default_factory gives every response its own empty list.
    diagnosis_actions: List[DiagnosisAction] = field(default_factory=list)
5 changes: 5 additions & 0 deletions dlrover/python/common/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,8 @@ def to_dict(self):
d.pop("config_resource", None)
d.pop("used_resource", None)
return d

def update_from_node(self, node):
    """Copy every instance attribute of ``node`` onto this object.

    Nothing happens when ``node`` compares equal to ``self``. The copy is
    shallow: attribute values are shared with ``node``, not duplicated.
    """
    if not (self == node):
        self.__dict__.update(node.__dict__)
45 changes: 0 additions & 45 deletions dlrover/python/common/worker.py

This file was deleted.

11 changes: 10 additions & 1 deletion dlrover/python/diagnosis/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class InferenceConfigKey(object):
class DiagnosisConstant(object):
    """Numeric constants shared by the diagnosis components."""

    # Intervals; the names state the unit is seconds.
    MASTER_DIAGNOSIS_OBSERVING_INTERVAL_SECS = 180
    AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS = 60

    # Negative sentinel values.
    # NOTE(review): presumably reserved instance ids that cannot clash with
    # real (non-negative) node ids -- confirm against callers.
    MASTER = -1
    ANY_INSTANCE = -2
    LOCAL_INSTANCE = -3


class DiagnosisDataType(object):
Expand All @@ -32,7 +35,13 @@ class DiagnosisDataType(object):
XPU_TIMER_METRIC = "XPU_TIMER_METRIC"


class DiagnosisAction(object):
class DiagnosisActionConstants(object):
    """String and numeric constants describing diagnosis actions."""

    # Names of concrete actions.
    NO_ACTION = "no_action"
    RESTART_WORKER = "restart_worker"
    RELAUNCH_WORKER = "relaunch_worker"

    # Wildcard action type.
    ACTION_TYPE_ANY = "any"

    # Default lifetime of an action.
    # NOTE(review): presumably seconds (five minutes) -- confirm against
    # the has_expired helper.
    ACTION_EXPIRED_TIME_PERIOD = 5 * 60

    # Action type for node-level actions.
    TYPE_NODE = "node"
138 changes: 138 additions & 0 deletions dlrover/python/diagnosis/common/diagnosis_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright 2024 The DLRover Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import threading
from abc import ABCMeta
from datetime import datetime
from typing import Dict, List

from dlrover.python.common.constants import NodeType
from dlrover.python.common.log import default_logger as logger
from dlrover.python.diagnosis.common.constants import (
DiagnosisActionConstants,
DiagnosisConstant,
)
from dlrover.python.util.time_util import has_expired


class DiagnosisAction(metaclass=ABCMeta):
    """Base record for an action produced by diagnosis.

    All state lives in plain instance attributes so that ``to_json`` can
    serialize ``__dict__`` directly.

    Args:
        action_type: Category of the action.
        instance: Identifier of the instance the action is addressed to.
        timestamp: Creation time in whole seconds since the epoch. ``0``
            means "stamp with the current time".
        expired_time_period: Lifetime passed to ``has_expired``. ``0``
            means "use
            ``DiagnosisActionConstants.ACTION_EXPIRED_TIME_PERIOD``".
    """

    def __init__(
        self,
        action_type: str,
        instance: int,
        timestamp=0,
        expired_time_period=0,
    ):
        self.action_type = action_type
        self.instance = instance
        if timestamp == 0:
            self.timestamp = int(round(datetime.now().timestamp()))
        else:
            self.timestamp = timestamp

        if expired_time_period == 0:
            self.expired_time_period = (
                DiagnosisActionConstants.ACTION_EXPIRED_TIME_PERIOD
            )
        else:
            self.expired_time_period = expired_time_period

    def __repr__(self) -> str:
        # Readable form for log lines; the default object repr only
        # shows a memory address.
        return f"{type(self).__name__}({self.__dict__})"

    def has_expired(self) -> bool:
        """Return whether the action has outlived its lifetime."""
        # Inside a method the bare name resolves to the module-level
        # ``has_expired`` helper, not to this method.
        return has_expired(self.timestamp, self.expired_time_period)

    def to_json(self):
        """Serialize all instance attributes to a JSON object string.

        Leading underscores are stripped from attribute names.
        """
        data = {k.lstrip("_"): v for k, v in self.__dict__.items()}
        return json.dumps(data)

    @classmethod
    def from_json(cls, json_data):
        """Build an instance of ``cls`` from a ``to_json`` string.

        ``to_json`` dumps every attribute, including ones a subclass
        derives itself (e.g. a fixed ``action_type``) and therefore does
        not accept as a constructor argument. Keys the constructor cannot
        take are dropped so the round trip does not raise ``TypeError``.
        """
        import inspect

        data = json.loads(json_data)
        params = inspect.signature(cls).parameters
        takes_var_kwargs = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        if not takes_var_kwargs:
            data = {k: v for k, v in data.items() if k in params}
        return cls(**data)


class DiagnosisNodeAction(DiagnosisAction):
    """Diagnosis action of type ``DiagnosisActionConstants.TYPE_NODE``.

    Adds two attributes to the base record: ``action`` (the action name)
    and ``node_type``.
    """

    def __init__(
        self,
        timestamp=0,
        expired_time_period=0,
        action="",
        node_type=NodeType.WORKER,
        instance=DiagnosisConstant.LOCAL_INSTANCE,
    ):
        # The action type is fixed for this subclass, so callers never
        # pass it in.
        super().__init__(
            action_type=DiagnosisActionConstants.TYPE_NODE,
            instance=instance,
            timestamp=timestamp,
            expired_time_period=expired_time_period,
        )
        self.action = action
        self.node_type = node_type

    def update_action(self, action: str):
        """Replace the stored action name."""
        self.action = action


def is_same_action(action1: DiagnosisAction, action2: DiagnosisAction) -> bool:
    """Report whether two actions are duplicates of each other.

    Currently a stub: it returns ``False`` for every pair.
    """
    return False


class DiagnosisActionQueue:
    """Per-instance lists of pending diagnosis actions.

    Every access to the underlying dict is made while holding
    ``self._lock``.
    """

    def __init__(self):
        # instance id -> actions waiting for that instance
        self._actions: Dict[int, List[DiagnosisAction]] = {}
        self._lock = threading.Lock()

    def add_action(self, new_action: DiagnosisAction):
        """Append ``new_action`` to the list of its target instance.

        The action is dropped if ``is_same_action`` reports that an
        equivalent one is already queued for that instance.
        """
        with self._lock:
            ins_actions = self._actions.setdefault(new_action.instance, [])
            if any(is_same_action(new_action, a) for a in ins_actions):
                return
            logger.info(f"enqueue action {new_action}")
            ins_actions.append(new_action)

    def _remove_expired_actions(self):
        """Drop every queued action whose ``has_expired()`` is true."""
        with self._lock:
            # Snapshot the items so values can be rebound while looping.
            for instance, action_queue in list(self._actions.items()):
                actions = []
                for action in action_queue:
                    if not action.has_expired():
                        actions.append(action)
                    else:
                        logger.info(f"Action {action} has expired")
                self._actions[instance] = actions

    def next_actions(
        self,
        instance=DiagnosisConstant.LOCAL_INSTANCE,
        action_type=DiagnosisActionConstants.ACTION_TYPE_ANY,
    ) -> List[DiagnosisAction]:
        """Remove and return the pending actions for ``instance``.

        Args:
            instance: Instance id whose actions are requested.
            action_type: Only actions of this type are returned;
                ``DiagnosisActionConstants.ACTION_TYPE_ANY`` returns
                actions of every type.

        Returns:
            The dequeued actions in insertion order. Actions that do not
            match stay queued. Expired actions are purged first.
        """
        self._remove_expired_actions()
        with self._lock:
            if instance not in self._actions:
                return []
            deque_actions = []
            remain_actions = []
            # The wildcard is ACTION_TYPE_ANY. Testing against TYPE_NODE
            # made the default call return nothing and a "node" request
            # return every action regardless of type.
            match_all = action_type == DiagnosisActionConstants.ACTION_TYPE_ANY
            for action in self._actions[instance]:
                if match_all or action_type == action.action_type:
                    deque_actions.append(action)
                else:
                    remain_actions.append(action)
            self._actions[instance] = remain_actions
            return deque_actions
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

from typing import List

from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction
from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction
from dlrover.python.diagnosis.common.inference_chain import Inference


def coordinate_inferences(observations: List[Inference]) -> DiagnoseAction:
return DiagnoseAction()
def coordinate_inferences(problems: List[Inference]) -> List[DiagnosisAction]:
    """Map inferred problems to diagnosis actions.

    Currently a stub: ``problems`` is ignored and a new empty list is
    returned on every call.
    """
    return list()
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@
from dlrover.python.diagnosis.inferencechain.inferenceoperator.check_failure_node_operator import ( # noqa: E501
CheckFailureNodeOperator,
)
from dlrover.python.diagnosis.inferencechain.inferenceoperator.check_training_hang_operator import ( # noqa: E501
CheckTrainingHangOperator,
)
from dlrover.python.diagnosis.inferencechain.inferenceoperator.metrics_collection_operator import ( # noqa: E501
MetricsCollectionOperator,
)
from dlrover.python.master.diagnosis.diagnosis_data_manager import (
DiagnosisDataManager,
)


def get_training_failure_operators():
Expand All @@ -29,3 +35,9 @@ def get_worker_observe_operators():

def get_worker_diagnosis_operators():
    """Return the worker diagnosis operators; currently a new empty list."""
    return list()


def get_master_observe_operators(data_mgr: DiagnosisDataManager = None):
    """Return the operators the master uses for observation.

    The list holds a single ``CheckTrainingHangOperator`` built with
    ``data_mgr``.
    """
    hang_operator = CheckTrainingHangOperator(data_mgr)
    return [hang_operator]
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List


class DiagnoseAction:
    """Accumulates action names in insertion order."""

    def __init__(self):
        self._actions: List[str] = list()

    def add_action(self, action: str):
        """Append ``action`` to the stored list."""
        self._actions += [action]
92 changes: 92 additions & 0 deletions dlrover/python/elastic_agent/common/worker_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2024 The DLRover Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from torch.distributed.elastic.agent.server.api import RunResult, WorkerSpec

from dlrover.python.common.singleton import Singleton
from dlrover.python.diagnosis.common.constants import (
DiagnosisActionConstants,
DiagnosisConstant,
)
from dlrover.python.diagnosis.common.diagnosis_action import (
DiagnosisAction,
DiagnosisActionQueue,
)


class WorkerContext(Singleton):
    """State shared about the local worker, plus its diagnosis actions.

    Holds the worker spec, the failover/restart counters, the latest run
    result and a ``DiagnosisActionQueue``.
    """

    def __init__(self):
        self._worker_spec: Optional[WorkerSpec] = None
        self._run_result: Optional[RunResult] = None
        self.remaining_failovers = 0
        self.restart_count = 0
        self._diagnose_action_queue = DiagnosisActionQueue()

    @property
    def worker_spec(self):
        """The stored worker spec, or ``None`` if none was set."""
        return self._worker_spec

    @property
    def run_result(self):
        """The stored run result, or ``None`` if none was set."""
        return self._run_result

    def to_string(self) -> str:
        """Render the context fields as a multi-line string."""
        parts = (
            "WorkerContext:\n",
            f"worker_spec: {self._worker_spec}\n",
            f"remaining_failover: {self.remaining_failovers}\n",
            f"restart_count: {self.restart_count}\n",
            f"run_result: {self._run_result}",
        )
        return "".join(parts)

    def _update_context(
        self,
        worker_spec: WorkerSpec = None,
        remaining_failovers: int = 0,
        restart_count: int = 0,
        run_result: RunResult = None,
    ):
        """Overwrite all four context fields with the given values.

        Arguments left at their defaults reset the matching field.
        """
        self._worker_spec, self._run_result = worker_spec, run_result
        self.remaining_failovers = remaining_failovers
        self.restart_count = restart_count

    def enqueue_diagnose_action(self, action: DiagnosisAction):
        """Hand ``action`` to the internal action queue."""
        self._diagnose_action_queue.add_action(action)

    def next_actions(
        self,
        instance=DiagnosisConstant.LOCAL_INSTANCE,
        action_type=DiagnosisActionConstants.ACTION_TYPE_ANY,
    ) -> List[DiagnosisAction]:
        """Return the queue's ``next_actions`` result for these filters."""
        queue = self._diagnose_action_queue
        return queue.next_actions(instance=instance, action_type=action_type)


def get_worker_context() -> WorkerContext:
    """Return the ``WorkerContext`` given by ``singleton_instance()``."""
    context = WorkerContext.singleton_instance()
    return context


def update_worker_context(
    worker_spec: WorkerSpec = None,
    remaining_failovers: int = 0,
    restart_count: int = 0,
    run_result: RunResult = None,
):
    """Overwrite the fields of the singleton worker context.

    All four values are forwarded to ``WorkerContext._update_context``;
    arguments left at their defaults reset the matching field.
    """
    get_worker_context()._update_context(
        worker_spec=worker_spec,
        remaining_failovers=remaining_failovers,
        restart_count=restart_count,
        run_result=run_result,
    )
Loading
Loading