
Commit

Merge branch 'master' into update_acc
skydoorkai committed Jul 24, 2023
2 parents c2bbcb2 + 2d86add commit f09a647
Showing 10 changed files with 97 additions and 75 deletions.
2 changes: 1 addition & 1 deletion dlrover/proto/elastic_training.proto
@@ -165,7 +165,7 @@ message NodeMeta {
string type = 1;
string addr = 2;
int32 memory = 3;
- int32 cpu = 4;
+ float cpu = 4;
int32 gpu = 5;
string gpu_type = 6;
int32 id = 7;
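
With `cpu` widened from int32 to float, fractional CPU allocations can now travel through NodeMeta without truncation, which is what the servicer change below relies on when it stops casting to int. An illustrative sketch (not part of this diff; the generated-module import path is assumed):

```python
# Sketch only: a fractional CPU value fits the new float field.
from dlrover.proto import elastic_training_pb2  # assumed import path for the generated module

meta = elastic_training_pb2.NodeMeta()
meta.cpu = 0.5       # fractional cores are representable now that the field is float
meta.memory = 4096   # memory remains int32
print(meta.cpu)      # 0.5
```
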
85 changes: 53 additions & 32 deletions dlrover/python/elastic_agent/torch/training.py
@@ -71,10 +71,17 @@ class ProcessError:


class MasterRendezvousHandler(RendezvousHandler):
- def __init__(self, name, rank_id, rdzv_params: RendezvousParameters):
+ def __init__(
+ self,
+ name,
+ rank_id,
+ rdzv_params: RendezvousParameters,
+ local_world_size,
+ ):
self._name = name
self._rank_id = rank_id
self._rdzv_params = rdzv_params
+ self._local_world_size = local_world_size
self.join_timeout = int(rdzv_params.get("join_timeout", 600))
self._client = GlobalMasterClient.MASTER_CLIENT
self._store = MasterKVStore(self._name, timedelta(seconds=60))
@@ -98,16 +105,16 @@ def set_closed(self):
"""Marks the rendezvous as closed."""
pass

- def join_rendezvous(self, local_world_size):
+ def _join_rendezvous(self):
"""The node join a rendezvous by sending its
ID and local world size.
"""
round = self._client.join_rendezvous(
- self._rank_id, local_world_size, rdzv_name=self._name
+ self._rank_id, self._local_world_size, rdzv_name=self._name
)
return round

- def next_rendezvous(self, round):
+ def next_rendezvous(self):
"""The handler will peroidically query the world from the master until
the world is not empty. The world is a dictionary like
like {0: 8, 1: 8, 2: 8} where the key is the node ID and the value is
@@ -121,18 +128,28 @@ def next_rendezvous(self, round):
f"rendezvous '{self._name}' with timeout {self.join_timeout}."
)
logger.info(msg)
+ round = self._join_rendezvous()
while True:
group, world = self._client.get_comm_world(
self._name, self._rank_id
)
+ world = dict(sorted(world.items()))
if world:
- break
- if time.time() - start_join > self.join_timeout:
+ if self._rank_id in world:
+ break
+ else:
+ logger.info(
+ "The node is not in the world "
+ "and waits for more nodes."
+ )
+ time.sleep(60)
+ start_join = time.time()
+ continue
+ elif time.time() - start_join > self.join_timeout:
raise TimeoutError(
f"Timeout {self.join_timeout}s to complete next rendezous."
)
time.sleep(3)
- world = dict(sorted(world.items()))
rank = list(world.keys()).index(self._rank_id)
world_size = len(world)
logger.info(
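
For reference, the world returned by the master maps node IDs to local world sizes, and a node's group rank is its position among the sorted node IDs. An illustrative sketch of that bookkeeping with made-up values (not part of this diff):

```python
# Sketch only: deriving rank and world size from the master's reply.
world = {2: 8, 0: 8, 1: 8}                 # hypothetical reply: node ID -> local world size
world = dict(sorted(world.items()))        # sorted by node ID, as the handler does
rank_id = 1                                # this node's ID (assumed)
rank = list(world.keys()).index(rank_id)   # -> 1
world_size = len(world)                    # -> 3
```
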
@@ -217,8 +234,7 @@ def _rendezvous(self, worker_group: WorkerGroup) -> None:
"""

spec = worker_group.spec
- round = spec.rdzv_handler.join_rendezvous(spec.local_world_size)
- store, world = spec.rdzv_handler.next_rendezvous(round)
+ store, world = spec.rdzv_handler.next_rendezvous()
self._store = store
group_world_size = len(world)
group_rank = list(world.keys()).index(self._rank_id)
@@ -464,6 +480,7 @@ def launch_agent(
RendezvousName.ELASTIC_TRAINING,
rank_id,
rdzv_parameters,
+ local_world_size=config.nproc_per_node,
)
spec = WorkerSpec(
role=config.role,
@@ -525,23 +542,21 @@ def launch_agent(
spec.rdzv_handler.shutdown()


- class NcclCheckElasticAgent(ElasticTrainingAgent):
+ class NetworkCheckElasticAgent(ElasticTrainingAgent):
"""
An implementation of :py:class:`torchelastic.agent.server.ElasticAgent`
- that handles host-local workers. This agent will run 3 round allgather
- to check network. We show the detail with 4 nodes to check network.
- Round 1: all nodes join a communication world {0:8, 1:8, 2:8, 3:8}
- where the key is the node id and the value is the local world size
- of the node. The check passes if allgather of all nodes is succeed.
- Otherwise, the round 2 starts.
- Round 2: the manager splits nodes into groups and each group contains
- two nodes, like [{0:8, 1:8},{2:8, 3:8}]. The node in each group will
- execute allgather independently and report its result to the manager.
- For example, the result is {0:False, 1:False, 2:True, 3:True}.
- Round 3: the manager will group the abnormal node with a normal node like
- [{0:8, 2:8}, {1:8, 2:8}]. Then, the node executes allgather again.
- If the result is {0:True, 1:False, 2:False, 3:True}, the network of
- node-1 if not available.
+ that handles host-local workers. This agent will run 2 rounds allgather
+ to check network available.
+ Round 0: the job master splits nodes into groups and each group contains
+ two nodes. The node in each group will execute an allgather task and
+ report its result to the master. For example, a job has 4 nodes and
+ groups are [{0, 1}, {2, 3}]. Assuming that the allgather task in the
+ 1st group fails, the result is {0:False, 1:False, 2:True, 3:True}
+ where the node 0, 1 are abnormal.
+ Round 1: the master will group the abnormal node with a normal node like
+ [{0, 2}, {1, 3}]. Then, the node executes an allgather task again.
+ If the result is {0:True, 1:False, 2:False, 3:True}, the node-1
+ breakdowns.
"""

def __init__(
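
An illustrative sketch of how the two rounds described in the docstring isolate a faulty node (the result dicts reuse the docstring's example values; the selection rule is a simplification, not the agent's exact code):

```python
# Sketch only: a node that fails the allgather in both rounds is treated as faulty.
round0 = {0: False, 1: False, 2: True, 3: True}   # groups [{0, 1}, {2, 3}]
round1 = {0: True, 1: False, 2: False, 3: True}   # abnormal nodes regrouped with normal ones
faulty = [n for n in round0 if not round0[n] and not round1[n]]
print(faulty)  # [1] -> node 1 is considered to have a broken network
```
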
@@ -564,7 +579,7 @@ def __init__(
log_dir,
)
self._log_dir = log_dir or tempfile.mkdtemp(prefix="network_check_")
- self._max_check_round = 3
+ self._check_round = 2

def run(self, role: str = DEFAULT_ROLE) -> bool:
spec = self._worker_group.spec
@@ -575,7 +590,7 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
f"{spec.get_entrypoint_name()}"
)
success = False
- for i in range(self._max_check_round):
+ for i in range(self._check_round):
result = self._run_network_check(spec.monitor_interval)
logger.info(f"Network check round {i} is {result}")
status = NodeStatus.SUCCEEDED if result else NodeStatus.FAILED
@@ -585,11 +600,16 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
self._stop_workers(self._worker_group)
if network_ready:
return True
- elif i == 0 and self._worker_group.group_world_size <= 2:
- logger.error(
- "Fail to check network when there are only 2 nodes."
- )
- raise RuntimeError("The node network is breakdown.")
+ else:
+ total_worker_num = len(self._client.get_running_nodes())
+ # If the number of nodes <= 2, we cannot determine which node
+ # breakdowns because there is no normal node in the job to
+ # execute allgather tasks with the two nodes.
+ if total_worker_num <= 2:
+ logger.error(
+ "Fail to check network when there are only 2 nodes."
+ )
+ raise RuntimeError("The node network is breakdown.")
time.sleep(1)
if not success:
self._client.report_failures(NodeErrorMessage.NETWORKER_ERROR)
@@ -663,6 +683,7 @@ def network_check(
RendezvousName.NETWORK_CHECK,
rank_id,
rdzv_parameters,
+ local_world_size=config.nproc_per_node,
)
spec = WorkerSpec(
role=config.role,
Expand All @@ -675,7 +696,7 @@ def network_check(
master_addr=master_addr,
)

- agent = NcclCheckElasticAgent(
+ agent = NetworkCheckElasticAgent(
rank_id=rank_id,
config=config,
entrypoint=entrypoint,
36 changes: 17 additions & 19 deletions dlrover/python/master/elastic_training/rdzv_manager.py
@@ -119,7 +119,10 @@ def _check_rdzv_completed(self):
for i in node_ids:
self._rdzv_nodes[i] = self._waiting_nodes[i]
self._latest_rdzv_nodes = list(self._rdzv_nodes.keys())
- self._waiting_nodes = dict()
+ self._waiting_nodes = dict(
+ set(self._waiting_nodes.items())
+ - set(self._rdzv_nodes.items())
+ )
self._lastcall_time = 0
logger.info(
f"Completed {self._rdzv_round} round "
@@ -162,7 +165,7 @@ def num_nodes_waiting(self):
"""The elastic agent will restart training processes if it
find the number of waiting nodes is not zero. The manager
will notify all nodes to restart training processes immediately if
- ab existing node re-joins the next round rendezvous.
+ an existing node re-joins the next round rendezvous.
If there are new nodes, the master notifies all nodes to re-join
the next round rendezvous only when the number of waiting nodes
is bigger than the number unit of nodes.
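
An illustrative sketch of the restart rule this docstring describes (the helper and the node_unit name are hypothetical, not the manager's API):

```python
# Sketch only: restart at once if a known node re-joins; otherwise wait until
# enough new nodes are waiting to fill a whole node unit.
def should_notify_restart(waiting_nodes, latest_rdzv_nodes, node_unit):
    if any(node in latest_rdzv_nodes for node in waiting_nodes):
        return True                          # an existing node re-joined
    return len(waiting_nodes) >= node_unit   # enough new nodes joined
```
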
@@ -231,9 +234,6 @@ def get_comm_world(self, rank_id):
rdzv_completed = self._check_rdzv_completed()
if rdzv_completed:
self._rdzv_round += 1

- if rank_id not in self._rdzv_nodes:
- return self._rdzv_round, {}
return self._rdzv_round, self._rdzv_nodes

def report_network_check_result(self, node_id, normal):
@@ -242,17 +242,13 @@ def report_network_check_result(self, node_id, normal):

class NetworkCheckRendezvousManager(RendezvousManager):
"""NcclCheckRendezvousManager runs on the DLRover master. The task
- to check network contains 3 round to execute allgather on all nodes.
+ to check network contains 2 round to execute allgather on all nodes.
We show the detail to check network assuming there are 4 nodes.
- Round 1: all nodes join a communication world {0:8, 1:8, 2:8, 3:8}
- where the key is the node id and the value is the local world size
- of the node. The check passes if allgather of all nodes is succeed.
- Otherwise, the round 2 starts.
- Round 2: the manager splits nodes into groups and each group contains
+ Round 0: the manager splits nodes into groups and each group contains
two nodes, like [{0:8, 1:8},{2:8, 3:8}]. The node in each group will
execute allgather independently and report its result to the manager.
For example, the result is {0:False, 1:False, 2:True, 3:True}.
- Round 3: the manager will group the abnormal node with a normal node like
+ Round 1: the manager will group the abnormal node with a normal node like
[{0:8, 2:8}, {1:8, 2:8}]. Then, the node executes allgather again.
If the result is {0:True, 1:False, 2:False, 3:True}, the network of
node-1 if not available.
@@ -264,6 +260,7 @@ def __init__(self):
self._node_status: Dict[int, bool] = {}
self._reported_nodes = set()
self._node_groups: List[Dict[int, int]] = []
+ self._check_round = 2

def get_comm_world(self, rank_id):
"""Return the communication world if a round rendezvous is completed.
@@ -278,15 +275,15 @@ def get_comm_world(self, rank_id):
f"Round {self._rdzv_round} "
f"node group: {self._node_groups}"
)
- if self._rdzv_round % 3 == 0:
+ if self._rdzv_round % 2 == 0:
self._node_status = {}
self._reported_nodes = set()
self._rdzv_round += 1

for i, group in enumerate(self._node_groups):
if rank_id in group:
return i, group
- return 0, {}
+ return 0, self._rdzv_nodes

def _group_nodes(self, round):
"""Group nodes into goups.
@@ -296,11 +293,9 @@ def _group_nodes(self, round):
Round 1: group the abnormal node with a normal node like
[{0:8, 2:8}, {1:8, 2:8}].
"""
- round = round % 3
+ round = round % self._check_round
node_groups: List[Dict[int, int]] = []
if round == 0:
- node_groups.append(self._rdzv_nodes)
- elif round == 1:
group = {}
for node_id, local_world_size in self._rdzv_nodes.items():
group[node_id] = local_world_size
@@ -312,7 +307,7 @@ def _group_nodes(self, round):
node_groups[-1].update(group)
else:
node_groups.append(group)
- elif round == 2:
+ elif round == 1:
abnormal_nodes = []
normal_nodes = []
for node_id, status in self._node_status.items():
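
An illustrative sketch of the grouping described in the class docstring (simplified; the real _group_nodes shown above also folds a leftover node into the last group and, per the docstring example, may reuse a normal node when abnormal nodes outnumber normal ones):

```python
# Sketch only: round 0 pairs nodes by ID; round 1 pairs each abnormal node
# with a normal one. Values are local world sizes, as in self._rdzv_nodes.
rdzv_nodes = {0: 8, 1: 8, 2: 8, 3: 8}
ids = list(rdzv_nodes)

# Round 0 -> [{0: 8, 1: 8}, {2: 8, 3: 8}]
round0_groups = [
    {i: rdzv_nodes[i] for i in ids[k:k + 2]} for k in range(0, len(ids), 2)
]

# Round 1 -> [{0: 8, 2: 8}, {1: 8, 3: 8}] given the round-0 results
node_status = {0: False, 1: False, 2: True, 3: True}
abnormal = [i for i, ok in node_status.items() if not ok]
normal = [i for i, ok in node_status.items() if ok]
round1_groups = [
    {a: rdzv_nodes[a], n: rdzv_nodes[n]} for a, n in zip(abnormal, normal)
]
```
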
@@ -378,7 +373,10 @@ def network_check_success(self):
list(self._node_status.values())
)
if success:
- self._rdzv_round = math.ceil(self._rdzv_round / 3) * 3
+ self._rdzv_round = (
+ math.ceil(self._rdzv_round / self._check_round)
+ * self._check_round
+ )
else:
reason = NetworkFailureReason.NODE_FAILURE
return success, reason
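
The ceiling expression above advances the rendezvous round to the next multiple of _check_round once the check succeeds. A quick check of the arithmetic:

```python
# Sketch only: rounds 1 and 2 both advance to 2, rounds 3 and 4 to 4, and so on.
import math

check_round = 2
for rdzv_round in (1, 2, 3, 4):
    print(math.ceil(rdzv_round / check_round) * check_round)  # 2, 2, 4, 4
```
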
1 change: 0 additions & 1 deletion dlrover/python/master/master.py
@@ -151,7 +151,6 @@ def run(self):
while True:
if self._stop_requested:
break
- self._remove_not_participated_workers()
if self.job_manager and self.job_manager.all_workers_exited():
if self.job_manager.pend_without_workers():
time.sleep(30)
2 changes: 1 addition & 1 deletion dlrover/python/master/node/job_auto_scaler.py
@@ -247,7 +247,7 @@ def __init__(
self._scaler = node_scaler
self._workers = job_nodes[NodeType.WORKER]
self._autoscaling_started = False
- self._scale_interval = 30
+ self._scale_interval = 1800

def start_auto_scaling(self):
"""Start auto-scaling nodes of a job"""
2 changes: 1 addition & 1 deletion dlrover/python/master/node/job_manager.py
@@ -560,7 +560,7 @@ def get_running_nodes(self):
nodes = self._chief_manager.get_running_nodes()
nodes.extend(self._worker_manager.get_running_nodes())
nodes.extend(self._evaluator_manager.get_running_nodes())
- nodes.extend(self._ps_manager.get_training_ps_cluster())
+ nodes.extend(self._ps_manager.get_running_nodes())
return nodes

def get_running_workers(self):
4 changes: 2 additions & 2 deletions dlrover/python/master/servicer.py
@@ -328,15 +328,15 @@ def query_ps_nodes(self, request, _):
ps_meta = res.ps_nodes.add()
ps_meta.type = NodeType.PS
ps_meta.addr = ps.service_addr
- ps_meta.cpu = int(ps.config_resource.cpu)
+ ps_meta.cpu = ps.config_resource.cpu
ps_meta.memory = int(ps.config_resource.memory)
logger.info("PS nodes : %s", res)
res.new_ps_ready = ready
res.ps_failure = ps_failure
return res

def query_running_nodes(self, request, _):
- nodes: List[Node] = self._job_manager.get_all_running_nodes()
+ nodes: List[Node] = self._job_manager.get_running_nodes()
res = elastic_training_pb2.RunningNodes()
for node in nodes:
meta = elastic_training_pb2.NodeMeta()
7 changes: 4 additions & 3 deletions dlrover/python/tests/test_elastic_training_agent.py
@@ -49,6 +49,7 @@ def setUp(self) -> None:
RendezvousName.ELASTIC_TRAINING,
node_id,
rdzv_parameters,
+ local_world_size=self.config.nproc_per_node,
)

self.spec = WorkerSpec(
@@ -75,9 +76,9 @@ def test_rank0_rendzevous(self):
start_method=self.config.start_method,
log_dir=self.config.log_dir,
)
- self.rdzv_handler.join_rendezvous(8)
+ self.rdzv_handler._join_rendezvous()
self.rdzv_handler._client.join_rendezvous(1, 8)
- _, world = self.rdzv_handler.next_rendezvous(0)
+ _, world = self.rdzv_handler.next_rendezvous()
self.assertDictEqual(world, {0: 8, 1: 8})

worker_group = agent._worker_group
@@ -108,7 +109,7 @@ def test_rank1_rendzevous(self):
store.set("MASTER_PORT", "12345".encode())
self.rdzv_handler._client.join_rendezvous(1, 8)
self.rdzv_handler._client.join_rendezvous(0, 8)
- _, world = self.rdzv_handler.next_rendezvous(0)
+ _, world = self.rdzv_handler.next_rendezvous()
self.assertDictEqual(world, {0: 8, 1: 8})
worker_group = agent._worker_group
agent._rendezvous(agent._worker_group)
