From 857f0358598a1386ca86aedc0b4fe8941c082be6 Mon Sep 17 00:00:00 2001
From: "jimmy.qin"
Date: Mon, 28 Oct 2024 15:48:13 +0800
Subject: [PATCH] Fix: Use return instead of break to correctly exit loop on path existence check

---
 dlrover/go/operator/Makefile                              | 2 +-
 dlrover/python/elastic_agent/torch/ckpt_saver.py          | 7 ++++---
 .../trainer/torch/flash_checkpoint/megatron_dist_ckpt.py  | 8 +++++---
 3 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/dlrover/go/operator/Makefile b/dlrover/go/operator/Makefile
index 17b18a4a1..900e0ca9f 100644
--- a/dlrover/go/operator/Makefile
+++ b/dlrover/go/operator/Makefile
@@ -40,7 +40,7 @@ help: ## Display this help.
 
 .PHONY: manifests
 manifests: controller-gen ## Generate WebhookConfiguration, ClusterRole and CustomResourceDefinition objects.
-	$(CONTROLLER_GEN) rbac:roleName=manager-role crd webhook paths="./..." output:crd:artifacts:config=config/crd/bases
+	$(CONTROLLER_GEN) rbac:roleName=manager-role crd:generateEmbeddedObjectMeta=true webhook paths="./..." output:crd:artifacts:config=config/crd/bases
 
 .PHONY: generate
 generate: controller-gen ## Generate code containing DeepCopy, DeepCopyInto, and DeepCopyObject method implementations.
diff --git a/dlrover/python/elastic_agent/torch/ckpt_saver.py b/dlrover/python/elastic_agent/torch/ckpt_saver.py
index 3c5211982..eea82a984 100644
--- a/dlrover/python/elastic_agent/torch/ckpt_saver.py
+++ b/dlrover/python/elastic_agent/torch/ckpt_saver.py
@@ -625,7 +625,7 @@ def _save_shard(
                 f"The step {step} in event is no equal "
                 f"to step {config.step} in memory."
             )
-            return
+            return False
 
         logger.info(
             f"Saves the checkpoint shard {local_shard_id} "
@@ -659,7 +659,7 @@ def _dist_make_dir(self, path, timeout=30):
         else:
             for _ in range(timeout):
                 if self.storage.exists(path):
-                    break
+                    return
                 time.sleep(1)
             logger.warning(
                 f"Worker {self._node_rank} can't find path {path} "
@@ -914,7 +914,7 @@ def save_step_checkpoint(self, step: int):
                     f"Fail to save checkpoint shared {i} for step {step}"
                 )
 
-        if success_count == self.local_shard_num:
+        if success_count == len(futures):
             write_success = True
             self._latest_step = step
 
@@ -923,6 +923,7 @@
                 f"Rank {self._node_rank} save checkpoint failed for "
                 f"step {step}"
             )
+            self._writing_storage = False
             return
 
         # commit checkpoint
diff --git a/dlrover/trainer/torch/flash_checkpoint/megatron_dist_ckpt.py b/dlrover/trainer/torch/flash_checkpoint/megatron_dist_ckpt.py
index 2034356af..262bce8f3 100644
--- a/dlrover/trainer/torch/flash_checkpoint/megatron_dist_ckpt.py
+++ b/dlrover/trainer/torch/flash_checkpoint/megatron_dist_ckpt.py
@@ -26,6 +26,7 @@
 try:
     from megatron.core import mpu, tensor_parallel
     from megatron.core.optimizer.optimizer import ChainedOptimizer
+    from megatron.core.num_microbatches_calculator import update_num_microbatches
     from megatron.training import get_args
     from megatron.training.checkpointing import (
         check_checkpoint_args,
@@ -37,13 +38,15 @@
         get_rng_state,
         read_metadata,
         set_checkpoint_version,
-        update_num_microbatches,
     )
     from megatron.training.utils import print_rank_0, unwrap_model
 except ImportError:
     # Keep back compatibility with Megatron-LM.
     try:
-        from megatron import get_args
+        from megatron import (
+            get_args,
+            update_num_microbatches,
+        )
         from megatron.checkpointing import (
             check_checkpoint_args,
             find_checkpoint_rank_0,
@@ -54,7 +57,6 @@
             get_rng_state,
             read_metadata,
             set_checkpoint_version,
-            update_num_microbatches,
         )
         from megatron.optimizer.optimizer import ChainedOptimizer
         from megatron.utils import print_rank_0, unwrap_model
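
Note on the break -> return change in _dist_make_dir: with break, leaving the retry
loop early still falls through to the logger.warning call, so a spurious warning is
logged even when the path already exists; returning exits the method as soon as the
path is found, and the warning only fires after the full timeout. Below is a minimal
standalone sketch of the intended behavior, not the dlrover implementation; the
wait_for_dir name and the storage argument are hypothetical stand-ins for the saver's
storage backend.

    import logging
    import time

    logger = logging.getLogger(__name__)


    def wait_for_dir(storage, path, timeout=30):
        """Poll until `path` exists in `storage`, warning only on timeout."""
        for _ in range(timeout):
            if storage.exists(path):
                # Early return: the warning below is skipped once the path is
                # visible. A bare `break` would fall through to it.
                return
            time.sleep(1)
        logger.warning("can't find path %s after %d seconds", path, timeout)

An alternative would be a for/else clause, whose else block runs only when the loop
exhausts the timeout without breaking; the patch opts for an early return instead.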