diff --git a/dlrover/proto/elastic_training.proto b/dlrover/proto/elastic_training.proto
index 7f479a871..9bffee350 100644
--- a/dlrover/proto/elastic_training.proto
+++ b/dlrover/proto/elastic_training.proto
@@ -165,7 +165,7 @@ message NodeMeta {
   string type = 1;
   string addr = 2;
   int32 memory = 3;
-  int32 cpu = 4;
+  float cpu = 4;
   int32 gpu = 5;
   string gpu_type = 6;
   int32 id = 7;
diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py
index 9ece00ad3..7326cfd9a 100644
--- a/dlrover/python/elastic_agent/torch/training.py
+++ b/dlrover/python/elastic_agent/torch/training.py
@@ -542,23 +542,21 @@ def launch_agent(
         spec.rdzv_handler.shutdown()
 
 
-class NcclCheckElasticAgent(ElasticTrainingAgent):
+class NetworkCheckElasticAgent(ElasticTrainingAgent):
     """
     An implementation of :py:class:`torchelastic.agent.server.ElasticAgent`
-    that handles host-local workers. This agent will run 3 round allgather
-    to check network. We show the detail with 4 nodes to check network.
-    Round 1: all nodes join a communication world {0:8, 1:8, 2:8, 3:8}
-    where the key is the node id and the value is the local world size
-    of the node. The check passes if allgather of all nodes is succeed.
-    Otherwise, the round 2 starts.
-    Round 2: the manager splits nodes into groups and each group contains
-    two nodes, like [{0:8, 1:8},{2:8, 3:8}]. The node in each group will
-    execute allgather independently and report its result to the manager.
-    For example, the result is {0:False, 1:False, 2:True, 3:True}.
-    Round 3: the manager will group the abnormal node with a normal node like
-    [{0:8, 2:8}, {1:8, 2:8}]. Then, the node executes allgather again.
-    If the result is {0:True, 1:False, 2:False, 3:True}, the network of
-    node-1 if not available.
+    that handles host-local workers. This agent will run 2 rounds of
+    allgather to check network availability.
+    Round 0: the job master splits nodes into groups and each group contains
+    two nodes. The nodes in each group execute an allgather task and report
+    the result to the master. For example, a job has 4 nodes and the groups
+    are [{0, 1}, {2, 3}]. Assuming that the allgather task of the 1st group
+    fails, the result is {0:False, 1:False, 2:True, 3:True}, where nodes
+    0 and 1 are abnormal.
+    Round 1: the master will group each abnormal node with a normal node,
+    like [{0, 2}, {1, 3}]. Then, each node executes an allgather task again.
+    If the result is {0:True, 1:False, 2:False, 3:True}, the network of
+    node-1 is not available.
     """
 
     def __init__(
@@ -581,7 +579,7 @@ def __init__(
             log_dir,
         )
         self._log_dir = log_dir or tempfile.mkdtemp(prefix="network_check_")
-        self._max_check_round = 3
+        self._check_round = 2
 
     def run(self, role: str = DEFAULT_ROLE) -> bool:
         spec = self._worker_group.spec
@@ -592,7 +590,7 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
             f"{spec.get_entrypoint_name()}"
         )
         success = False
-        for i in range(self._max_check_round):
+        for i in range(self._check_round):
             result = self._run_network_check(spec.monitor_interval)
             logger.info(f"Network check round {i} is {result}")
             status = NodeStatus.SUCCEEDED if result else NodeStatus.FAILED
@@ -602,11 +600,16 @@ def run(self, role: str = DEFAULT_ROLE) -> bool:
                 self._stop_workers(self._worker_group)
             if network_ready:
                 return True
-            elif i == 0 and self._worker_group.group_world_size <= 2:
-                logger.error(
-                    "Fail to check network when there are only 2 nodes."
-                )
-                raise RuntimeError("The node network is breakdown.")
+            else:
+                total_worker_num = len(self._client.get_running_nodes())
+                # If the number of nodes is <= 2, we cannot determine which
+                # node breaks down because there is no other normal node in
+                # the job to execute allgather tasks with the two nodes.
+                if total_worker_num <= 2:
+                    logger.error(
+                        "Fail to check network when there are only 2 nodes."
+                    )
+                    raise RuntimeError("The node network is breakdown.")
             time.sleep(1)
         if not success:
             self._client.report_failures(NodeErrorMessage.NETWORKER_ERROR)
@@ -693,7 +696,7 @@ def network_check(
         master_addr=master_addr,
     )
 
-    agent = NcclCheckElasticAgent(
+    agent = NetworkCheckElasticAgent(
         rank_id=rank_id,
         config=config,
         entrypoint=entrypoint,
diff --git a/dlrover/python/master/elastic_training/rdzv_manager.py b/dlrover/python/master/elastic_training/rdzv_manager.py
index 96da90385..f52192af4 100644
--- a/dlrover/python/master/elastic_training/rdzv_manager.py
+++ b/dlrover/python/master/elastic_training/rdzv_manager.py
@@ -242,17 +242,13 @@ def report_network_check_result(self, node_id, normal):
 
 class NetworkCheckRendezvousManager(RendezvousManager):
     """NcclCheckRendezvousManager runs on the DLRover master. The task
-    to check network contains 3 round to execute allgather on all nodes.
+    to check network contains 2 rounds to execute allgather on all nodes.
     We show the detail to check network assuming there are 4 nodes.
-    Round 1: all nodes join a communication world {0:8, 1:8, 2:8, 3:8}
-    where the key is the node id and the value is the local world size
-    of the node. The check passes if allgather of all nodes is succeed.
-    Otherwise, the round 2 starts.
-    Round 2: the manager splits nodes into groups and each group contains
+    Round 0: the manager splits nodes into groups and each group contains
     two nodes, like [{0:8, 1:8},{2:8, 3:8}]. The node in each group will
     execute allgather independently and report its result to the manager.
     For example, the result is {0:False, 1:False, 2:True, 3:True}.
-    Round 3: the manager will group the abnormal node with a normal node like
+    Round 1: the manager will group the abnormal node with a normal node like
     [{0:8, 2:8}, {1:8, 2:8}]. Then, the node executes allgather again.
     If the result is {0:True, 1:False, 2:False, 3:True}, the network of
     node-1 if not available.
@@ -264,6 +260,7 @@ def __init__(self):
         self._node_status: Dict[int, bool] = {}
         self._reported_nodes = set()
         self._node_groups: List[Dict[int, int]] = []
+        self._check_round = 2
 
     def get_comm_world(self, rank_id):
         """Return the communication world if a round rendezvous is completed.
@@ -278,7 +275,7 @@ def get_comm_world(self, rank_id):
                 f"Round {self._rdzv_round} "
                 f"node group: {self._node_groups}"
             )
-            if self._rdzv_round % 3 == 0:
+            if self._rdzv_round % 2 == 0:
                 self._node_status = {}
                 self._reported_nodes = set()
             self._rdzv_round += 1
@@ -296,11 +293,9 @@ def _group_nodes(self, round):
         Round 1: group the abnormal node with a normal node like
         [{0:8, 2:8}, {1:8, 2:8}].
         """
-        round = round % 3
+        round = round % self._check_round
         node_groups: List[Dict[int, int]] = []
         if round == 0:
-            node_groups.append(self._rdzv_nodes)
-        elif round == 1:
             group = {}
             for node_id, local_world_size in self._rdzv_nodes.items():
                 group[node_id] = local_world_size
@@ -312,7 +307,7 @@ def _group_nodes(self, round):
                     node_groups[-1].update(group)
                 else:
                     node_groups.append(group)
-        elif round == 2:
+        elif round == 1:
             abnormal_nodes = []
             normal_nodes = []
             for node_id, status in self._node_status.items():
@@ -378,7 +373,10 @@ def network_check_success(self):
                 list(self._node_status.values())
             )
             if success:
-                self._rdzv_round = math.ceil(self._rdzv_round / 3) * 3
+                self._rdzv_round = (
+                    math.ceil(self._rdzv_round / self._check_round)
+                    * self._check_round
+                )
             else:
                 reason = NetworkFailureReason.NODE_FAILURE
         return success, reason
diff --git a/dlrover/python/master/node/job_manager.py b/dlrover/python/master/node/job_manager.py
index 666966de2..fbe7e8717 100644
--- a/dlrover/python/master/node/job_manager.py
+++ b/dlrover/python/master/node/job_manager.py
@@ -560,7 +560,7 @@ def get_running_nodes(self):
         nodes = self._chief_manager.get_running_nodes()
         nodes.extend(self._worker_manager.get_running_nodes())
         nodes.extend(self._evaluator_manager.get_running_nodes())
-        nodes.extend(self._ps_manager.get_training_ps_cluster())
+        nodes.extend(self._ps_manager.get_running_nodes())
         return nodes
 
     def get_running_workers(self):
diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py
index 986e909c6..df96f29de 100644
--- a/dlrover/python/master/servicer.py
+++ b/dlrover/python/master/servicer.py
@@ -328,7 +328,7 @@ def query_ps_nodes(self, request, _):
             ps_meta = res.ps_nodes.add()
             ps_meta.type = NodeType.PS
             ps_meta.addr = ps.service_addr
-            ps_meta.cpu = int(ps.config_resource.cpu)
+            ps_meta.cpu = ps.config_resource.cpu
             ps_meta.memory = int(ps.config_resource.memory)
         logger.info("PS nodes : %s", res)
         res.new_ps_ready = ready
@@ -336,7 +336,7 @@ def query_ps_nodes(self, request, _):
         return res
 
     def query_running_nodes(self, request, _):
-        nodes: List[Node] = self._job_manager.get_all_running_nodes()
+        nodes: List[Node] = self._job_manager.get_running_nodes()
         res = elastic_training_pb2.RunningNodes()
         for node in nodes:
             meta = elastic_training_pb2.NodeMeta()
diff --git a/dlrover/python/tests/test_rdzv_manager.py b/dlrover/python/tests/test_rdzv_manager.py
index 2a5add0fd..e98ed3b00 100644
--- a/dlrover/python/tests/test_rdzv_manager.py
+++ b/dlrover/python/tests/test_rdzv_manager.py
@@ -134,28 +134,31 @@ def test_network_check_rdzv(self):
         self.assertEqual(group, 0)
         self.assertEqual(len(rdzv_manager._waiting_nodes), 0)
         self.assertEqual(len(rdzv_manager._rdzv_nodes), 4)
-        self.assertDictEqual(world, {0: 8, 1: 8, 2: 8, 3: 8})
-        for i in range(4):
-            round = rdzv_manager.join_rendezvous(i, 8)
-        self.assertEqual(round, 1)
-        group, world = rdzv_manager.get_comm_world(0)
         self.assertDictEqual(world, {0: 8, 1: 8})
-        group, world = rdzv_manager.get_comm_world(1)
         self.assertEqual(group, 0)
-        self.assertDictEqual(world, {0: 8, 1: 8})
         group, world = rdzv_manager.get_comm_world(2)
         self.assertDictEqual(world, {2: 8, 3: 8})
         self.assertEqual(group, 1)
         for i in range(3):
             rdzv_manager.report_network_check_result(i, True)
         rdzv_manager.report_network_check_result(3, False)
+
+        for i in range(4):
+            round = rdzv_manager.join_rendezvous(i, 8)
+        self.assertEqual(round, 1)
+        group, world = rdzv_manager.get_comm_world(0)
+        self.assertDictEqual(world, {3: 8, 0: 8})
+        group, world = rdzv_manager.get_comm_world(1)
+        self.assertDictEqual(world, {1: 8, 2: 8})
+        self.assertEqual(group, 1)
         success, _ = rdzv_manager.network_check_success()
         self.assertFalse(success)
+
         for i in range(4):
             round = rdzv_manager.join_rendezvous(i, 8)
         self.assertEqual(round, 2)
         group, world = rdzv_manager.get_comm_world(3)
-        self.assertDictEqual(world, {0: 8, 3: 8})
+        self.assertDictEqual(world, {2: 8, 3: 8})
         _, reason = rdzv_manager.network_check_success()
         self.assertEqual(reason, NetworkFailureReason.WAITING_NODE)
         for i in range(3):
diff --git a/dlrover/python/tests/test_servicer.py b/dlrover/python/tests/test_servicer.py
index aa3b6a060..80f08e232 100644
--- a/dlrover/python/tests/test_servicer.py
+++ b/dlrover/python/tests/test_servicer.py
@@ -134,7 +134,7 @@ def test_metric_service(self):
         self.job_metric_collector._report_runtime_stats()
         self.assertEqual(len(reporter._runtime_stats), 2)
         self.assertEqual(reporter._runtime_stats[0].global_step, 1100)
-        self.assertEqual(len(reporter._runtime_stats[0].running_nodes), 4)
+        self.assertEqual(len(reporter._runtime_stats[0].running_nodes), 2)
         request.timestamp = ts + 20
         request.global_step = 2100
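
Note: the two-round grouping described in the docstrings above can be sketched in plain Python. This is an illustrative sketch only, not code from this patch; the helper group_nodes and its arguments rdzv_nodes (node id -> local world size) and node_status (node id -> result of the previous round) are hypothetical names that mirror, in simplified form, what NetworkCheckRendezvousManager._group_nodes does.

# Illustrative sketch (not from the patch): how the two network-check rounds
# pair nodes, simplified from NetworkCheckRendezvousManager._group_nodes.
from typing import Dict, List


def group_nodes(
    check_round: int,
    rdzv_nodes: Dict[int, int],    # node id -> local world size (assumed name)
    node_status: Dict[int, bool],  # node id -> previous round result (assumed name)
) -> List[Dict[int, int]]:
    node_groups: List[Dict[int, int]] = []
    if check_round % 2 == 0:
        # Round 0: pair nodes two by two, e.g. {0, 1} and {2, 3}.
        group: Dict[int, int] = {}
        for node_id, local_world_size in rdzv_nodes.items():
            group[node_id] = local_world_size
            if len(group) == 2:
                node_groups.append(group)
                group = {}
        if len(group) == 1:
            # An odd leftover node joins the last group.
            if node_groups:
                node_groups[-1].update(group)
            else:
                node_groups.append(group)
    else:
        # Round 1: pair each abnormal node with a normal node, e.g. {0, 2}
        # and {1, 3}, so a failing group points at the abnormal node.
        abnormal = [n for n, ok in node_status.items() if not ok]
        normal = [n for n, ok in node_status.items() if ok]
        for bad, good in zip(abnormal, normal):
            node_groups.append({bad: rdzv_nodes[bad], good: rdzv_nodes[good]})
        # Remaining normal nodes form their own group.
        used = set(abnormal[: len(normal)] + normal[: len(abnormal)])
        leftover = {n: rdzv_nodes[n] for n in rdzv_nodes if n not in used}
        if leftover:
            node_groups.append(leftover)
    return node_groups


nodes = {0: 8, 1: 8, 2: 8, 3: 8}
print(group_nodes(0, nodes, {}))
# [{0: 8, 1: 8}, {2: 8, 3: 8}]
print(group_nodes(1, nodes, {0: False, 1: False, 2: True, 3: True}))
# [{0: 8, 2: 8}, {1: 8, 3: 8}]
print(group_nodes(1, nodes, {0: True, 1: True, 2: True, 3: False}))
# [{3: 8, 0: 8}, {1: 8, 2: 8}]

The last call reproduces the scenario in test_network_check_rdzv above: after node 3 fails in round 0, it is re-paired with node 0 while nodes 1 and 2 form the remaining group, so a second failure isolates the broken node.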