Skip to content

Commit

Permalink
Job context implementation (#1319)
Browse files Browse the repository at this point in the history
* split job context impl

* optimization and ut fix

* fix

* fix

* fix ut

* fix ut

* fix ut

* fix ut

* merge PR 1317

* fix ut

* fix ut

* sync

* Refactor node event report and report succeeded. (#1315)

* refactor node event report

* fix ut

* merge check report

* lint

* refactor node succeeded field into a more common status field

* fix ut

* optimize ut

* no updating if already succeeded

* fix ut

(cherry picked from commit d9dc1c1)

* fix

* fix ut

* lint

* lint

* fix

* optimization

* more ut

* optimize action queue

* optimize action response

* add ut

* optimized

* optimized

---------

Co-authored-by: BO SANG <[email protected]>
Co-authored-by: chentianyi.cty <[email protected]>
Co-authored-by: Tianyi Chen <[email protected]>
  • Loading branch information
4 people authored Nov 7, 2024
1 parent d529925 commit a38cf82
Show file tree
Hide file tree
Showing 36 changed files with 1,161 additions and 390 deletions.
11 changes: 11 additions & 0 deletions dlrover/python/common/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,3 +506,14 @@ class ElasticRunConfigRequest(Message):
@dataclass
class ElasticRunConfig(Message):
    """Message carrying elastic-run settings as string key/value pairs."""

    # A fresh empty dict is created for every instance.
    configs: Dict[str, str] = field(default_factory=dict)


@dataclass
class DiagnosisAction(Message):
    """Message form of a diagnosis action.

    NOTE(review): presumably ``action_cls`` names the action class and
    ``action_content`` holds its serialized payload -- confirm with the
    sender/receiver code.
    """

    action_cls: str = ""
    action_content: str = ""


@dataclass
class HeartbeatResponse(Message):
    """Heartbeat reply message that carries one diagnosis action."""

    # Each response gets its own default-constructed DiagnosisAction.
    action: DiagnosisAction = field(default_factory=DiagnosisAction)
5 changes: 5 additions & 0 deletions dlrover/python/common/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,8 @@ def to_dict(self):
d.pop("config_resource", None)
d.pop("used_resource", None)
return d

def update_from_node(self, node):
    """Overwrite this object's attributes with those of ``node``.

    Nothing happens when ``node`` compares equal to ``self``. Otherwise
    every entry of ``node.__dict__`` is copied by reference (shallow) into
    ``self.__dict__``; attributes that exist only on ``self`` are kept.
    """
    if not (self == node):
        self.__dict__.update(node.__dict__)
16 changes: 14 additions & 2 deletions dlrover/python/diagnosis/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ class InferenceConfigKey(object):
class DiagnosisConstant:
    """Numeric constants shared by the diagnosis components."""

    # Polling intervals, in seconds (per the ``_SECS`` suffix).
    MASTER_DIAGNOSIS_OBSERVING_INTERVAL_SECS = 3 * 60
    AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS = 60

    # Reserved instance ids (all negative).
    MASTER = -1
    ANY_INSTANCE = -2
    LOCAL_INSTANCE = -3

    # Fallback expiration period for an action.
    # NOTE(review): unit is not stated here -- confirm whether this is
    # seconds or milliseconds against the code that consumes it.
    ACTION_EXPIRED_TIME_PERIOD_DEFAULT = 5 * 60


class DiagnosisDataType(object):
Expand All @@ -32,7 +36,15 @@ class DiagnosisDataType(object):
XPU_TIMER_METRIC = "XPU_TIMER_METRIC"


class DiagnosisAction(object):
NO_ACTION = "no_action"
class DiagnosisActionType:
    """String identifiers for the kinds of diagnosis action."""

    # -- common --
    NONE = "no_action"
    ANY = "any_action"

    # -- master operation --
    MASTER_RELAUNCH_WORKER = "master_relaunch_worker"
    EVENT = "event"

    # -- node operation --
    RESTART_WORKER = "restart_worker"
    RELAUNCH_WORKER = "relaunch_worker"
215 changes: 215 additions & 0 deletions dlrover/python/diagnosis/common/diagnosis_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Copyright 2024 The DLRover Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import queue
import threading
from abc import ABCMeta
from datetime import datetime
from queue import Queue
from typing import Dict, Optional

from dlrover.python.common.log import default_logger as logger
from dlrover.python.diagnosis.common.constants import (
    DiagnosisActionType,
    DiagnosisConstant,
)
from dlrover.python.util.time_util import has_expired


class DiagnosisAction(metaclass=ABCMeta):
    """
    Operation to be done after diagnostician.

    Args:
        action_type (DiagnosisActionType): Type of action.
        instance (Optional): Instance id (node id). Defaults to -1 (Master).
        timestamp (Optional[float]): Creation time of the action, a POSIX
            timestamp as returned by ``datetime.timestamp()``. ``0`` (the
            default) means "use the current time".
        expired_time_period (Optional): Expiration period. Unit: ms.
            Defaults to 60 seconds. ``0`` selects
            ``DiagnosisConstant.ACTION_EXPIRED_TIME_PERIOD_DEFAULT``.
    """

    def __init__(
        self,
        action_type=DiagnosisActionType.NONE,
        instance=DiagnosisConstant.MASTER,
        timestamp=0,
        expired_time_period=60 * 1000,
    ):
        self._action_type = action_type
        self._instance = instance

        # 0 is the "not provided" sentinel for both time-related arguments.
        self._timestamp: float = (
            datetime.now().timestamp() if timestamp == 0 else timestamp
        )
        self._expired_time_period = (
            DiagnosisConstant.ACTION_EXPIRED_TIME_PERIOD_DEFAULT
            if expired_time_period == 0
            else expired_time_period
        )

    def __repr__(self):
        # Readable form for log messages: class name plus every attribute
        # (without the private underscore prefix).
        attrs = ", ".join(
            f"{name.lstrip('_')}={value!r}"
            for name, value in self.__dict__.items()
        )
        return f"{type(self).__name__}({attrs})"

    @property
    def action_type(self):
        return self._action_type

    @property
    def instance(self):
        return self._instance

    @property
    def timestamp(self):
        return self._timestamp

    @property
    def expired_timestamp(self):
        # NOTE(review): the timestamp comes from datetime.timestamp()
        # (seconds) while the period is documented in milliseconds, so this
        # sum looks unit-inconsistent -- confirm the intended unit before
        # relying on this value.
        return self._timestamp + self._expired_time_period

    def is_expired(self) -> bool:
        """Return whether ``has_expired`` reports this action as expired."""
        return has_expired(self._timestamp, self._expired_time_period)

    def to_json(self):
        """Serialize all instance attributes to a JSON object string.

        Keys are the attribute names with leading underscores removed.
        """
        data = {k.lstrip("_"): v for k, v in self.__dict__.items()}
        return json.dumps(data)

    @classmethod
    def from_json(cls, json_data):
        """Build an action of type ``cls`` from a ``to_json`` string.

        The decoded keys are first passed to the constructor as keyword
        arguments. Subclasses whose constructor does not accept every
        serialized key (e.g. ``action_type``/``instance``) would make that
        call raise ``TypeError``; in that case the attributes are restored
        directly, which is the exact inverse of ``to_json``.

        Raises:
            TypeError: if the JSON document is not an object.
        """
        data = json.loads(json_data)
        if not isinstance(data, dict):
            raise TypeError(
                f"Expected a JSON object, got {type(data).__name__}."
            )
        try:
            return cls(**data)
        except TypeError:
            action = cls.__new__(cls)
            action.__dict__.update(
                {f"_{key}": value for key, value in data.items()}
            )
            return action


class NoAction(DiagnosisAction):
    """Action of type ``NONE`` bound to the fixed instance id -999."""

    def __init__(self):
        super().__init__(DiagnosisActionType.NONE, -999)


class EventAction(DiagnosisAction):
    """Output the specified event.

    Args:
        event_type: Type of the event.
        event_instance: Instance the event refers to.
        event_action: Action described by the event.
        event_msg: Message of the event.
        event_labels: Labels attached to the event. When omitted, each
            action gets its own new empty dict.
        timestamp: Forwarded to ``DiagnosisAction``.
        expired_time_period: Forwarded to ``DiagnosisAction``.
    """

    def __init__(
        self,
        event_type: str = "",
        event_instance: str = "",
        event_action: str = "",
        event_msg: str = "",
        event_labels: Optional[Dict[str, str]] = None,
        timestamp=0,
        expired_time_period=0,
    ):
        super().__init__(
            DiagnosisActionType.EVENT,
            timestamp=timestamp,
            expired_time_period=expired_time_period,
        )
        self._event_type = event_type
        self._event_instance = event_instance
        self._event_action = event_action
        self._event_msg = event_msg
        # A literal ``{}`` default would be one dict shared by every action
        # created without labels; build a fresh one per instance instead.
        self._event_labels = {} if event_labels is None else event_labels

    @property
    def event_type(self):
        return self._event_type

    @property
    def event_instance(self):
        return self._event_instance

    @property
    def event_action(self):
        return self._event_action

    @property
    def event_msg(self):
        return self._event_msg

    @property
    def event_labels(self):
        return self._event_labels


class NodeAction(DiagnosisAction):
    """Action addressed to one node, carrying a node status and a reason.

    The node id is stored as the action's ``instance``.
    """

    def __init__(
        self,
        node_id,
        node_status,
        reason,
        action_type=DiagnosisActionType.MASTER_RELAUNCH_WORKER,
        timestamp=0,
        expired_time_period=0,
    ):
        super().__init__(
            action_type=action_type,
            instance=node_id,
            timestamp=timestamp,
            expired_time_period=expired_time_period,
        )
        self._node_status = node_status
        self._reason = reason

    @property
    def node_id(self):
        # Alias of ``instance``.
        return self.instance

    @property
    def node_status(self):
        return self._node_status

    @property
    def reason(self):
        return self._reason


def is_same_action(action1: DiagnosisAction, action2: DiagnosisAction) -> bool:
    """Report whether two actions are the same.

    Placeholder: no comparison is implemented, so every pair is reported
    as different.
    """
    same = False
    return same


class DiagnosisActionQueue:
    """Pending diagnosis actions, kept in one bounded FIFO queue per instance.

    Both public methods run under a single lock.
    """

    def __init__(self):
        self._actions: Dict[int, Queue[DiagnosisAction]] = {}
        self._lock = threading.Lock()

    def add_action(self, new_action: DiagnosisAction):
        """Enqueue ``new_action`` for its instance.

        Each instance queue holds at most 10 actions; when it is full the
        new action is dropped and a warning is logged.
        """
        with self._lock:
            instance = new_action.instance
            if instance not in self._actions:
                self._actions[instance] = Queue(maxsize=10)
            ins_actions = self._actions[instance]
            try:
                # Non-blocking on purpose: consumers need the lock held
                # here, so waiting for free space could never succeed and
                # would only stall every other caller.
                ins_actions.put_nowait(new_action)
            except queue.Full:
                logger.warning(
                    f"Diagnosis actions for {instance} is full, "
                    f"skip action: {new_action}."
                )
            else:
                logger.info(f"New diagnosis action {new_action}")

    def next_action(
        self,
        instance=DiagnosisConstant.LOCAL_INSTANCE,
    ) -> DiagnosisAction:
        """Pop the oldest non-expired action queued for ``instance``.

        Expired actions are discarded (and logged) on the way. When nothing
        usable is queued, a default ``DiagnosisAction()`` is returned.
        """
        with self._lock:
            pending = self._actions.get(instance)
            while pending is not None and not pending.empty():
                action = pending.get_nowait()
                if not action.is_expired():
                    return action
                logger.info(f"Skip expired diagnosis action: {action}.")
            return DiagnosisAction()
6 changes: 3 additions & 3 deletions dlrover/python/diagnosis/inferencechain/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@

from typing import List

from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction
from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction
from dlrover.python.diagnosis.common.inference_chain import Inference


def coordinate_inferences(observations: List[Inference]) -> DiagnoseAction:
return DiagnoseAction()
def coordinate_inferences(observations: List[Inference]) -> DiagnosisAction:
    """Produce the action for the given observations.

    Currently ``observations`` is not inspected; a default-constructed
    ``DiagnosisAction`` is always returned.
    """
    action = DiagnosisAction()
    return action
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from torch.distributed.elastic.agent.server.api import RunResult, WorkerSpec
from torch.distributed.elastic.agent.server import RunResult, WorkerSpec


class WorkerContext:
class AgentContext(object):
def __init__(
self,
worker_spec: WorkerSpec,
Expand Down
32 changes: 16 additions & 16 deletions dlrover/python/elastic_agent/diagnosis/diagnosis_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
from dlrover.python.common.error import ProcessError
from dlrover.python.common.log import default_logger as logger
from dlrover.python.common.singleton import Singleton
from dlrover.python.common.worker import WorkerContext
from dlrover.python.diagnosis.common.constants import (
DiagnosisAction,
DiagnosisActionType,
DiagnosisConstant,
InferenceConfigKey,
)
from dlrover.python.diagnosis.common.diagnose_action import DiagnoseAction
from dlrover.python.diagnosis.common.diagnosis_action import DiagnosisAction
from dlrover.python.diagnosis.common.diagnosis_data import WorkerTrainingMetric
from dlrover.python.diagnosis.common.inference_chain import (
Inference,
Expand All @@ -50,6 +49,7 @@
get_worker_diagnosis_operators,
get_worker_observe_operators,
)
from dlrover.python.elastic_agent.context import AgentContext
from dlrover.python.elastic_agent.master_client import MasterClient


Expand Down Expand Up @@ -105,7 +105,7 @@ def _observe(self) -> List[Inference]:

def _diagnose_observations(
self, observations: List[Inference]
) -> DiagnoseAction:
) -> DiagnosisAction:
conclusions: List[Inference] = []
for ob in observations:
ic = InferenceChain([ob], self._diagnosis_operators)
Expand Down Expand Up @@ -133,9 +133,9 @@ def _periodically_diagnosis(self):
DiagnosisConstant.AGENT_PERIODICALLY_DIAGNOSIS_INTERVAL_SECS
)

def diagnose_training_failure(self, worker_context: WorkerContext) -> str:
def diagnose_training_failure(self, agent_context: AgentContext) -> str:
self._report_failure_to_master(
worker_context.run_result.failures, worker_context.restart_count
agent_context.run_result.failures, agent_context.restart_count
)
# check if the node is failed
inference = Inference(
Expand All @@ -156,25 +156,25 @@ def diagnose_training_failure(self, worker_context: WorkerContext) -> str:
)
failure_node = is_inference_included(infer_results, failure_inf)

if worker_context.remaining_failovers > 0 and not failure_node:
if agent_context.remaining_failovers > 0 and not failure_node:
logger.info(
f"[{worker_context.worker_spec.role}] Worker group "
f"{worker_context.run_result.state.name}, "
f"[{agent_context.worker_spec.role}] Worker group "
f"{agent_context.run_result.state.name}, "
f"is failure node: {failure_node},"
f"{worker_context.remaining_failovers}/"
f"{worker_context.worker_spec.max_restarts} "
f"{agent_context.remaining_failovers}/"
f"{agent_context.worker_spec.max_restarts} "
f"attempts left; will restart worker group."
)
return DiagnosisAction.RESTART_WORKER
return DiagnosisActionType.RESTART_WORKER
else:
logger.info(
f"[{worker_context.worker_spec.role}] Worker group "
f"{worker_context.run_result.state.name}, "
f"[{agent_context.worker_spec.role}] Worker group "
f"{agent_context.run_result.state.name}, "
f"is failure node: {failure_node}, "
f"no attempts({worker_context.worker_spec.max_restarts}) "
f"no attempts({agent_context.worker_spec.max_restarts}) "
"left; will relaunch."
)
return DiagnosisAction.RELAUNCH_WORKER
return DiagnosisActionType.RELAUNCH_WORKER

def _report_failure_to_master(
self, failures: Dict[int, ProcessFailure], restart_count: int
Expand Down
Loading

0 comments on commit a38cf82

Please sign in to comment.