
Commit

make ut running
skydoorkai committed Jul 24, 2023
1 parent 8fea297 commit a117bd0
Showing 10 changed files with 1,180 additions and 9 deletions.
3 changes: 1 addition & 2 deletions atorch/atorch/auto/model_context.py
@@ -27,7 +27,6 @@
rank,
)
from atorch.utils.graph_transform_utils import map_aggregate
-from atorch.utils.version import torch_version

try:
from pippy.IR import LossWrapper
@@ -463,7 +462,7 @@ def adjust_wrappers(self):
fairscale_zero2_wrapper_exist = "zero2" in self.post_wrappers
fsdp_wrapper_exist = "fsdp" in self.pre_wrappers or "zero2" in self.pre_wrappers
tensor_parallel_wrapper_exist = "tp" in self.pre_wrappers
-ckpt_wrapper_exist = "checkpoint" in self.post_wrappers
+# ckpt_wrapper_exist = "checkpoint" in self.post_wrappers

# remove ddp wrapper when using zero2
if ddp_wrapper_exist and (fairscale_zero2_wrapper_exist or fsdp_wrapper_exist):
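Context for the hunk above: it only comments out the now-unused ckpt_wrapper_exist flag; the guard that follows it drops the DDP wrapper whenever a ZeRO2/FSDP wrapper is present, since both would try to manage gradient synchronization. A minimal sketch of that pruning, assuming the wrappers are kept in name-keyed dicts (the pop call is an illustration, not the commit's exact code):

    # Hedged sketch: ZeRO/FSDP already handles gradient communication, so a
    # DDP wrapper on top would be redundant. The dict layout is an assumption.
    if ddp_wrapper_exist and (fairscale_zero2_wrapper_exist or fsdp_wrapper_exist):
        self.post_wrappers.pop("ddp", None)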
Binary file removed atorch/atorch/distributed/.distributed.py.swp
2 changes: 1 addition & 1 deletion atorch/atorch/distributed/distributed.py
@@ -608,7 +608,7 @@ def init_distributed(
return False
elif backend == "accl":
try:
-            import torch_accl # noqa: F401
+            import torch_accl  # noqa: F401
except ImportError:
logger.error("import torch_accl failed")
return False
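The hunk above only reformats the guarded torch_accl import; behavior is unchanged: when the import fails, init_distributed logs an error and returns False rather than raising. A hedged usage sketch (passing backend as a keyword is an assumption about the signature):

    from atorch.distributed.distributed import init_distributed

    # init_distributed reports backend-setup failure through its return value.
    if not init_distributed(backend="accl"):
        raise RuntimeError("accl backend unavailable; is torch_accl installed?")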
1 change: 1 addition & 0 deletions atorch/atorch/fault_tolerance/__init__.py
@@ -0,0 +1 @@
from atorch.fault_tolerance.hanging_detector import HangingDetector
133 changes: 133 additions & 0 deletions atorch/atorch/fault_tolerance/api.py
@@ -0,0 +1,133 @@
import uuid

import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic import events, metrics
from torch.distributed.elastic.agent.server.api import WorkerSpec
from torch.distributed.elastic.multiprocessing import SignalException
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.launcher.api import _get_addr_and_port, _get_entrypoint_name, elastic_launch
from torch.distributed.run import config_from_args

from atorch.fault_tolerance.custom_agent import LocalDetectHangingAgent

logger = get_logger()


def run(args):
if args.standalone:
args.rdzv_backend = "c10d"
args.rdzv_endpoint = "localhost:29400"
args.rdzv_id = str(uuid.uuid4())
logger.info(
f"\n**************************************\n"
f"Rendezvous info:\n"
f"--rdzv_backend={args.rdzv_backend} "
f"--rdzv_endpoint={args.rdzv_endpoint} "
f"--rdzv_id={args.rdzv_id}\n"
f"**************************************\n"
)

config, cmd, cmd_args = config_from_args(args)
fault_tolerant_launch(
config=config,
entrypoint=cmd,
)(*cmd_args)


class fault_tolerant_launch(elastic_launch):
def __init__(self, config, entrypoint):
super().__init__(config, entrypoint)

def __call__(self, *args):
return launch_custom_agent(self._config, self._entrypoint, list(args))


def launch_custom_agent(config, entrypoint, args):
if not config.run_id:
run_id = str(uuid.uuid4().int)
logger.warning(f"config has no run_id, generated a random run_id: {run_id}")
config.run_id = run_id

entrypoint_name = _get_entrypoint_name(entrypoint, args)

logger.info(
f"Starting elastic_operator with launch configs:\n"
f" entrypoint : {entrypoint_name}\n"
f" min_nodes : {config.min_nodes}\n"
f" max_nodes : {config.max_nodes}\n"
f" nproc_per_node : {config.nproc_per_node}\n"
f" run_id : {config.run_id}\n"
f" rdzv_backend : {config.rdzv_backend}\n"
f" rdzv_endpoint : {config.rdzv_endpoint}\n"
f" rdzv_configs : {config.rdzv_configs}\n"
f" max_restarts : {config.max_restarts}\n"
f" monitor_interval : {config.monitor_interval}\n"
f" log_dir : {config.log_dir}\n"
f" metrics_cfg : {config.metrics_cfg}\n"
)

rdzv_parameters = RendezvousParameters(
backend=config.rdzv_backend,
endpoint=config.rdzv_endpoint,
run_id=config.run_id,
min_nodes=config.min_nodes,
max_nodes=config.max_nodes,
**config.rdzv_configs,
)

master_addr, master_port = _get_addr_and_port(rdzv_parameters)

spec = WorkerSpec(
role=config.role,
local_world_size=config.nproc_per_node,
entrypoint=entrypoint,
args=tuple(args),
rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
max_restarts=config.max_restarts,
monitor_interval=config.monitor_interval,
redirects=config.redirects,
tee=config.tee,
master_addr=master_addr,
master_port=master_port,
)

agent = LocalDetectHangingAgent(
spec=spec, start_method=config.start_method, log_dir=config.log_dir, rdzv_params=rdzv_parameters
)

shutdown_rdzv = True
try:
metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

result = agent.run()
# records that agent.run() has succeeded NOT that workers have succeeded
events.record(agent.get_event_succeeded())

if result.is_failed():
# ChildFailedError is treated specially by @record
# if the error files for the failed children exist
# @record will copy the first error (root cause)
# to the error file of the launcher process.
raise ChildFailedError(
name=entrypoint_name,
failures=result.failures,
)

return result.return_values
except ChildFailedError:
raise
except SignalException:
# when the agent dies with a signal do NOT shutdown the rdzv_handler
# since this closes the rendezvous on this rdzv_id permanently and
# prevents any additional scaling events
shutdown_rdzv = False
events.record(agent.get_event_failed())
raise
except Exception:
events.record(agent.get_event_failed())
raise
finally:
if shutdown_rdzv:
spec.rdzv_handler.shutdown()
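run(args) above consumes the same argparse namespace that torch.distributed.run builds, so the fault-tolerant launcher can also be driven programmatically. A minimal sketch, assuming torch.distributed.run.parse_args is available in your PyTorch version and that train.py is a placeholder script:

    from torch.distributed.run import parse_args

    from atorch.fault_tolerance.api import run

    # Equivalent to a torchrun-style command line; --standalone makes run()
    # fill in a local c10d rendezvous, as shown in the function above.
    args = parse_args(["--standalone", "--nproc_per_node=2", "train.py"])
    run(args)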
195 changes: 195 additions & 0 deletions atorch/atorch/fault_tolerance/custom_agent.py
@@ -0,0 +1,195 @@
import os
import shutil
import socket
import time
from contextlib import closing

from torch.distributed.elastic.agent.server.api import WorkerState, _get_fq_hostname, _get_socket_with_port
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from torch.distributed.elastic.metrics import prof, put_metric
from torch.distributed.elastic.multiprocessing import start_processes
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger

from atorch.fault_tolerance.hanging_detector import RelaunchStatus

log = get_logger()


class LocalDetectHangingAgent(LocalElasticAgent):
def __init__(
self,
spec,
start_method="spawn",
exit_barrier_timeout=300,
log_dir=None,
rdzv_params=None,
):
super().__init__(spec, start_method, exit_barrier_timeout, log_dir)
self.rdzv_params = rdzv_params
self.node_world_size = self.rdzv_params.max_nodes
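        # Resolve this agent's node rank: rdzv config first, then the RANK
        # environment variable, falling back to "0".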
self.node_rank = self.rdzv_params.config.get("node_rank")
if self.node_rank is None:
self.node_rank = os.getenv("RANK")
if self.node_rank is None:
self.node_rank = "0"

@staticmethod
def _set_master_addr_port(store, master_addr, master_port, local_dir=None):
if master_port is None:
sock = _get_socket_with_port()
with closing(sock):
master_port = sock.getsockname()[1]

if master_addr is None:
if local_dir is not None:
master_addr = local_dir
else:
master_addr = os.getenv("POD_IP", socket.gethostbyname(_get_fq_hostname()))

store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8"))

def _invoke_run(self, role):
# NOTE: currently only works for a single role

spec = self._worker_group.spec
role = spec.role

log.info(f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")

self._initialize_workers(self._worker_group)
monitor_interval = spec.monitor_interval
rdzv_handler = spec.rdzv_handler

        # self._store is a PrefixStore. Here we use the underlying TCPStore of self._store.
underlying_store = rdzv_handler._state_holder._backend._store

worker_world_size = 0
workers = self._worker_group.workers
if workers:
worker_world_size = workers[0].world_size
relaunch_status = None
if worker_world_size > 0:
node_rank = int(self.node_rank)
node_world_size = int(self.node_world_size)
relaunch_status = RelaunchStatus(
"agent",
worker_world_size=worker_world_size,
agent_rank=node_rank,
agent_world_size=node_world_size,
store=underlying_store,
)
while True:
assert self._worker_group.state != WorkerState.INIT
time.sleep(monitor_interval)
run_result = self._monitor_workers(self._worker_group)
state = run_result.state
self._worker_group.state = state

put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
put_metric(f"workers.{role}.{state.name.lower()}", 1)

if state == WorkerState.SUCCEEDED:
log.info(
f"[{role}] worker group successfully finished."
f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
)
self._exit_barrier()
return run_result
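            # RelaunchStatus (fed by the hanging detector) may request a
            # relaunch even while workers still report HEALTHY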
elif relaunch_status is not None and relaunch_status.should_relaunch():
self._restart_workers(self._worker_group)
relaunch_status.step()
elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
if relaunch_status is not None and relaunch_status.should_relaunch():
self._restart_workers(self._worker_group)
relaunch_status.step()
elif self._remaining_restarts > 0:
log.info(
f"[{role}] Worker group {state.name}. "
f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
f" will restart worker group"
)
self._remaining_restarts -= 1
self._restart_workers(self._worker_group)
else:
self._stop_workers(self._worker_group)
self._worker_group.state = WorkerState.FAILED
self._exit_barrier()
return run_result
elif state == WorkerState.HEALTHY:
# membership changes do not count as retries
num_nodes_waiting = rdzv_handler.num_nodes_waiting()
group_rank = self._worker_group.group_rank
if num_nodes_waiting > 0:
log.info(
f"[{role}] Detected {num_nodes_waiting} "
f"new nodes from group_rank={group_rank}; "
f"will restart worker group"
)
self._restart_workers(self._worker_group)
else:
raise Exception(f"[{role}] Worker group in {state.name} state")

@prof
def _start_workers(self, worker_group):
spec = worker_group.spec
store = worker_group.store
assert store is not None
master_addr, master_port = super()._get_master_addr_port(store)
restart_count = spec.max_restarts - self._remaining_restarts

use_agent_store = spec.rdzv_handler.get_backend() == "static"
agent_master_addr, agent_master_port = self.rdzv_params.endpoint.split(":")

args = {}
envs = {}
for worker in worker_group.workers:
local_rank = worker.local_rank
worker_env = {
"LOCAL_RANK": str(local_rank),
"RANK": str(worker.global_rank),
"GROUP_RANK": str(worker_group.group_rank),
"ROLE_RANK": str(worker.role_rank),
"ROLE_NAME": spec.role,
"LOCAL_WORLD_SIZE": str(spec.local_world_size),
"WORLD_SIZE": str(worker.world_size),
"GROUP_WORLD_SIZE": str(worker_group.group_world_size),
"ROLE_WORLD_SIZE": str(worker.role_world_size),
"MASTER_ADDR": master_addr,
"MASTER_PORT": str(master_port),
"TORCHELASTIC_RESTART_COUNT": str(restart_count),
"TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
"TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
"TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
"NCCL_ASYNC_ERROR_HANDLING": str(1),
"TORCHELASTIC_AGENT_MASTER_ADDR": agent_master_addr,
"TORCHELASTIC_AGENT_MASTER_PORT": agent_master_port,
"NODE_RANK": self.node_rank,
}
if "OMP_NUM_THREADS" in os.environ:
worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
envs[local_rank] = worker_env
worker_args = list(spec.args)
worker_args = macros.substitute(worker_args, str(local_rank))
args[local_rank] = tuple(worker_args)

# scaling events do not count towards restarts (gets same attempt #)
# remove existing log dir if this restart is due to a scaling event
attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
shutil.rmtree(attempt_log_dir, ignore_errors=True)
os.makedirs(attempt_log_dir)

assert spec.entrypoint is not None
self._pcontext = start_processes(
name=spec.role,
entrypoint=spec.entrypoint,
args=args,
envs=envs,
log_dir=attempt_log_dir,
start_method=self._start_method,
redirects=spec.redirects,
tee=spec.tee,
)

return self._pcontext.pids()
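Compared with LocalElasticAgent._start_workers, the override above additionally exports TORCHELASTIC_AGENT_MASTER_ADDR, TORCHELASTIC_AGENT_MASTER_PORT, and NODE_RANK to every worker. A worker-side sketch of how those could be consumed; only the variable names come from this commit, and the TCPStore client connection is an assumption:

    import os

    from torch.distributed import TCPStore

    # Env vars exported by LocalDetectHangingAgent._start_workers above.
    addr = os.environ["TORCHELASTIC_AGENT_MASTER_ADDR"]
    port = int(os.environ["TORCHELASTIC_AGENT_MASTER_PORT"])

    # Connecting a client store at that endpoint is an assumption of this
    # sketch, not code from the commit.
    store = TCPStore(addr, port, is_master=False)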
(remaining 4 changed files not shown)
