Remove the round to test allgather with all nodes (#512)
* Remove the round to test allgather with all nodes

* Fix test cases

* Replace a magic number with a variable

* Fix docstring per review comments

* Format code
workingloong authored Jul 24, 2023
1 parent 357b499 commit 2d86add
Showing 7 changed files with 53 additions and 49 deletions.
2 changes: 1 addition & 1 deletion dlrover/proto/elastic_training.proto
@@ -165,7 +165,7 @@ message NodeMeta {
string type = 1;
string addr = 2;
int32 memory = 3;
int32 cpu = 4;
float cpu = 4;
int32 gpu = 5;
string gpu_type = 6;
int32 id = 7;
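Note: with cpu now declared as a float, fractional CPU values survive the report instead of being truncated to an integer. Below is a minimal usage sketch with the generated Python bindings; the import path is assumed from how servicer.py references elastic_training_pb2.

```python
# Hypothetical usage sketch: float cpu values are preserved end to end.
from dlrover.proto import elastic_training_pb2  # import path assumed

meta = elastic_training_pb2.NodeMeta()
meta.cpu = 0.5       # with int32, int(0.5) would have collapsed this to 0
meta.memory = 4096   # memory stays an int32 field
print(meta.cpu)      # 0.5
```

The matching change in servicer.py (dropping the int() cast) keeps the fractional value intact when PS node metadata is reported.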
49 changes: 26 additions & 23 deletions dlrover/python/elastic_agent/torch/training.py
@@ -542,23 +542,21 @@ def launch_agent(
spec.rdzv_handler.shutdown()


class NcclCheckElasticAgent(ElasticTrainingAgent):
class NetworkCheckElasticAgent(ElasticTrainingAgent):
"""
An implementation of :py:class:`torchelastic.agent.server.ElasticAgent`
that handles host-local workers. This agent will run 3 round allgather
to check network. We show the detail with 4 nodes to check network.
Round 1: all nodes join a communication world {0:8, 1:8, 2:8, 3:8}
where the key is the node id and the value is the local world size
of the node. The check passes if allgather of all nodes is succeed.
Otherwise, the round 2 starts.
Round 2: the manager splits nodes into groups and each group contains
two nodes, like [{0:8, 1:8},{2:8, 3:8}]. The node in each group will
execute allgather independently and report its result to the manager.
For example, the result is {0:False, 1:False, 2:True, 3:True}.
Round 3: the manager will group the abnormal node with a normal node like
[{0:8, 2:8}, {1:8, 2:8}]. Then, the node executes allgather again.
If the result is {0:True, 1:False, 2:False, 3:True}, the network of
node-1 if not available.
that handles host-local workers. This agent will run 2 rounds of allgather
to check network availability.
Round 0: the job master splits nodes into groups and each group contains
two nodes. The nodes in each group execute an allgather task and
report the result to the master. For example, a job has 4 nodes and
the groups are [{0, 1}, {2, 3}]. Assuming the allgather task in the
1st group fails, the result is {0:False, 1:False, 2:True, 3:True},
where nodes 0 and 1 are abnormal.
Round 1: the master groups each abnormal node with a normal node, like
[{0, 2}, {1, 3}]. Then, the nodes execute an allgather task again.
If the result is {0:True, 1:False, 2:False, 3:True}, the network of
node-1 is broken.
"""

def __init__(
@@ -581,7 +579,7 @@ def __init__(
log_dir,
)
self._log_dir = log_dir or tempfile.mkdtemp(prefix="network_check_")
self._max_check_round = 3
self._check_round = 2

def run(self, role: str = DEFAULT_ROLE) -> bool:
spec = self._worker_group.spec
@@ -592,7 +590,7 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
f"{spec.get_entrypoint_name()}"
)
success = False
for i in range(self._max_check_round):
for i in range(self._check_round):
result = self._run_network_check(spec.monitor_interval)
logger.info(f"Network check round {i} is {result}")
status = NodeStatus.SUCCEEDED if result else NodeStatus.FAILED
@@ -602,11 +600,16 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
self._stop_workers(self._worker_group)
if network_ready:
return True
elif i == 0 and self._worker_group.group_world_size <= 2:
logger.error(
"Fail to check network when there are only 2 nodes."
)
raise RuntimeError("The node network is breakdown.")
else:
total_worker_num = len(self._client.get_running_nodes())
# If the number of nodes <= 2, we cannot determine which node
# is broken because there is no other normal node in the job to
# execute allgather tasks with the two nodes.
if total_worker_num <= 2:
logger.error(
"Fail to check network when there are only 2 nodes."
)
raise RuntimeError("The node network is breakdown.")
time.sleep(1)
if not success:
self._client.report_failures(NodeErrorMessage.NETWORKER_ERROR)
@@ -693,7 +696,7 @@ def network_check(
master_addr=master_addr,
)

agent = NcclCheckElasticAgent(
agent = NetworkCheckElasticAgent(
rank_id=rank_id,
config=config,
entrypoint=entrypoint,
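To make the two rounds described in the new NetworkCheckElasticAgent docstring concrete, here is a small standalone sketch of the grouping and fault-isolation idea. It is an illustration under simplifying assumptions, not DLRover's implementation: the helper names are made up, and a group's allgather is assumed to fail whenever the group contains a faulty node.

```python
from typing import Dict, List, Set


def round0_groups(node_ids: List[int]) -> List[List[int]]:
    """Round 0: split nodes into consecutive pairs, e.g. [0, 1, 2, 3] -> [[0, 1], [2, 3]]."""
    return [node_ids[i:i + 2] for i in range(0, len(node_ids), 2)]


def round1_groups(status: Dict[int, bool]) -> List[List[int]]:
    """Round 1: pair every node that failed round 0 with a node that passed."""
    abnormal = [n for n, ok in status.items() if not ok]
    normal = [n for n, ok in status.items() if ok]
    return [list(pair) for pair in zip(abnormal, normal)]


def simulate_check(node_ids: List[int], faulty: Set[int]) -> Dict[int, bool]:
    """Simulate both rounds; a group's allgather fails iff it contains a faulty node."""
    status: Dict[int, bool] = {}
    for group in round0_groups(node_ids):
        ok = faulty.isdisjoint(group)
        for n in group:
            status[n] = ok
    if all(status.values()):
        return status  # every node passed round 0, no second round needed
    for group in round1_groups(status):
        ok = faulty.isdisjoint(group)
        for n in group:
            status[n] = ok
    return status


# Four nodes with node 1 faulty: round 0 marks the pair {0, 1} abnormal,
# round 1 re-pairs 0 with 2 and 1 with 3, so node 0 is cleared while the
# group containing node 1 keeps failing.
print(simulate_check([0, 1, 2, 3], faulty={1}))  # {0: True, 1: False, 2: True, 3: False}
```

In this toy model node 1 fails in both rounds while every other node passes at least once, which is how a faulty node can be isolated once the job has more than two running nodes.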
24 changes: 11 additions & 13 deletions dlrover/python/master/elastic_training/rdzv_manager.py
@@ -242,17 +242,13 @@ def report_network_check_result(self, node_id, normal):

class NetworkCheckRendezvousManager(RendezvousManager):
"""NcclCheckRendezvousManager runs on the DLRover master. The task
to check network contains 3 round to execute allgather on all nodes.
to check the network consists of 2 rounds of allgather across all nodes.
We show the details of the check assuming there are 4 nodes.
Round 1: all nodes join a communication world {0:8, 1:8, 2:8, 3:8}
where the key is the node id and the value is the local world size
of the node. The check passes if allgather of all nodes is succeed.
Otherwise, the round 2 starts.
Round 2: the manager splits nodes into groups and each group contains
Round 0: the manager splits nodes into groups and each group contains
two nodes, like [{0:8, 1:8},{2:8, 3:8}]. The node in each group will
execute allgather independently and report its result to the manager.
For example, the result is {0:False, 1:False, 2:True, 3:True}.
Round 3: the manager will group the abnormal node with a normal node like
Round 1: the manager will group the abnormal node with a normal node like
[{0:8, 2:8}, {1:8, 2:8}]. Then, the node executes allgather again.
If the result is {0:True, 1:False, 2:False, 3:True}, the network of
node-1 is not available.
@@ -264,6 +260,7 @@ def __init__(self):
self._node_status: Dict[int, bool] = {}
self._reported_nodes = set()
self._node_groups: List[Dict[int, int]] = []
self._check_round = 2

def get_comm_world(self, rank_id):
"""Return the communication world if a round rendezvous is completed.
@@ -278,7 +275,7 @@ def get_comm_world(self, rank_id):
f"Round {self._rdzv_round} "
f"node group: {self._node_groups}"
)
if self._rdzv_round % 3 == 0:
if self._rdzv_round % 2 == 0:
self._node_status = {}
self._reported_nodes = set()
self._rdzv_round += 1
@@ -296,11 +293,9 @@ def _group_nodes(self, round):
Round 1: group the abnormal node with a normal node like
[{0:8, 2:8}, {1:8, 2:8}].
"""
round = round % 3
round = round % self._check_round
node_groups: List[Dict[int, int]] = []
if round == 0:
node_groups.append(self._rdzv_nodes)
elif round == 1:
group = {}
for node_id, local_world_size in self._rdzv_nodes.items():
group[node_id] = local_world_size
@@ -312,7 +307,7 @@ def _group_nodes(self, round):
node_groups[-1].update(group)
else:
node_groups.append(group)
elif round == 2:
elif round == 1:
abnormal_nodes = []
normal_nodes = []
for node_id, status in self._node_status.items():
@@ -378,7 +373,10 @@ def network_check_success(self):
list(self._node_status.values())
)
if success:
self._rdzv_round = math.ceil(self._rdzv_round / 3) * 3
self._rdzv_round = (
math.ceil(self._rdzv_round / self._check_round)
* self._check_round
)
else:
reason = NetworkFailureReason.NODE_FAILURE
return success, reason
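The multi-line expression above rounds _rdzv_round up to the next multiple of _check_round once a check succeeds, so the next network check starts a fresh 2-round cycle. A quick standalone sanity check of that arithmetic:

```python
import math

check_round = 2
for rdzv_round in range(5):
    aligned = math.ceil(rdzv_round / check_round) * check_round
    print(rdzv_round, "->", aligned)  # 0 -> 0, 1 -> 2, 2 -> 2, 3 -> 4, 4 -> 4
```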
2 changes: 1 addition & 1 deletion dlrover/python/master/node/job_manager.py
@@ -560,7 +560,7 @@ def get_running_nodes(self):
nodes = self._chief_manager.get_running_nodes()
nodes.extend(self._worker_manager.get_running_nodes())
nodes.extend(self._evaluator_manager.get_running_nodes())
nodes.extend(self._ps_manager.get_training_ps_cluster())
nodes.extend(self._ps_manager.get_running_nodes())
return nodes

def get_running_workers(self):
4 changes: 2 additions & 2 deletions dlrover/python/master/servicer.py
@@ -328,15 +328,15 @@ def query_ps_nodes(self, request, _):
ps_meta = res.ps_nodes.add()
ps_meta.type = NodeType.PS
ps_meta.addr = ps.service_addr
ps_meta.cpu = int(ps.config_resource.cpu)
ps_meta.cpu = ps.config_resource.cpu
ps_meta.memory = int(ps.config_resource.memory)
logger.info("PS nodes : %s", res)
res.new_ps_ready = ready
res.ps_failure = ps_failure
return res

def query_running_nodes(self, request, _):
nodes: List[Node] = self._job_manager.get_all_running_nodes()
nodes: List[Node] = self._job_manager.get_running_nodes()
res = elastic_training_pb2.RunningNodes()
for node in nodes:
meta = elastic_training_pb2.NodeMeta()
19 changes: 11 additions & 8 deletions dlrover/python/tests/test_rdzv_manager.py
@@ -134,28 +134,31 @@ def test_network_check_rdzv(self):
self.assertEqual(group, 0)
self.assertEqual(len(rdzv_manager._waiting_nodes), 0)
self.assertEqual(len(rdzv_manager._rdzv_nodes), 4)
self.assertDictEqual(world, {0: 8, 1: 8, 2: 8, 3: 8})
for i in range(4):
round = rdzv_manager.join_rendezvous(i, 8)
self.assertEqual(round, 1)
group, world = rdzv_manager.get_comm_world(0)
self.assertDictEqual(world, {0: 8, 1: 8})
group, world = rdzv_manager.get_comm_world(1)
self.assertEqual(group, 0)
self.assertDictEqual(world, {0: 8, 1: 8})
group, world = rdzv_manager.get_comm_world(2)
self.assertDictEqual(world, {2: 8, 3: 8})
self.assertEqual(group, 1)
for i in range(3):
rdzv_manager.report_network_check_result(i, True)
rdzv_manager.report_network_check_result(3, False)

for i in range(4):
round = rdzv_manager.join_rendezvous(i, 8)
self.assertEqual(round, 1)
group, world = rdzv_manager.get_comm_world(0)
self.assertDictEqual(world, {3: 8, 0: 8})
group, world = rdzv_manager.get_comm_world(1)
self.assertDictEqual(world, {1: 8, 2: 8})
self.assertEqual(group, 1)
success, _ = rdzv_manager.network_check_success()
self.assertFalse(success)

for i in range(4):
round = rdzv_manager.join_rendezvous(i, 8)
self.assertEqual(round, 2)
group, world = rdzv_manager.get_comm_world(3)
self.assertDictEqual(world, {0: 8, 3: 8})
self.assertDictEqual(world, {2: 8, 3: 8})
_, reason = rdzv_manager.network_check_success()
self.assertEqual(reason, NetworkFailureReason.WAITING_NODE)
for i in range(3):
2 changes: 1 addition & 1 deletion dlrover/python/tests/test_servicer.py
@@ -134,7 +134,7 @@ def test_metric_service(self):
self.job_metric_collector._report_runtime_stats()
self.assertEqual(len(reporter._runtime_stats), 2)
self.assertEqual(reporter._runtime_stats[0].global_step, 1100)
self.assertEqual(len(reporter._runtime_stats[0].running_nodes), 4)
self.assertEqual(len(reporter._runtime_stats[0].running_nodes), 2)

request.timestamp = ts + 20
request.global_step = 2100
