Skip to content

Commit 915bef8

Browse files
committed
fix to make ucp load more lenient
Signed-off-by: Schwidola0607 <[email protected]>
1 parent 662a297 commit 915bef8

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

deepspeed/runtime/engine.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2899,7 +2899,7 @@ def _get_all_ckpt_names(self, checkpoints_path, tag):
28992899

29002900
ckpt_files = glob.glob(ckpt_file_pattern)
29012901
ckpt_files.sort()
2902-
return ckpt_files
2902+
return ckpt_files, ckpt_file_pattern
29032903

29042904
def load_checkpoint(self,
29052905
load_dir,
@@ -2923,7 +2923,7 @@ def load_checkpoint(self,
29232923
29242924
Returns:
29252925
A tuple of ``load_path`` and ``client_state``.
2926-
*``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
2926+
*``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed or when loading an HF-based UCP checkpoint.
29272927
*``client_state``: State dictionary used for loading required training states in the client code.
29282928
29292929
Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right
@@ -2962,6 +2962,12 @@ def load_checkpoint(self,
29622962
custom_load_fn=custom_load_fn)
29632963

29642964
load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
2965+
# import pdb; pdb.set_trace()
2966+
if self.load_universal_checkpoint():
2967+
ucp_ckpt_folder = os.path.join(load_dir, tag)
2968+
# UCP load can ignore '*mp' files or '*model_states.pt' but ucp_ckpt_folder must exist
2969+
load_zero_checkpoint = os.path.isdir(ucp_ckpt_folder)
2970+
29652971
if load_zero_checkpoint:
29662972
if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint():
29672973
success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
@@ -3002,7 +3008,11 @@ def _load_checkpoint(self,
30023008

30033009
from deepspeed.runtime.state_dict_factory import SDLoaderFactory
30043010

3005-
ckpt_list = self._get_all_ckpt_names(load_dir, tag)
3011+
ckpt_list, ckpt_file_pattern = self._get_all_ckpt_names(load_dir, tag)
3012+
if self.load_universal_checkpoint() and len(ckpt_list) == 0:
3013+
logger.warning(f"Unable to find {ckpt_file_pattern} files in UCP folder {load_dir}")
3014+
return None, {}
3015+
30063016
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine)
30073017

30083018
is_pipe_parallel = isinstance(self.module, PipelineModule)

deepspeed/runtime/zero/stage3.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2777,7 +2777,6 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
27772777
# Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
27782778
self.optimizer_swapper.purge_state()
27792779

2780-
if self.swap_optimizer:
27812780
# Touch all parameters to synchronize all buffers
27822781
timer_names = set()
27832782
self._partition_all_parameters()
@@ -2812,7 +2811,7 @@ def _load_global_state_stage3(self, sd):
28122811

28132812
def load_hp_checkpoint_state(self, folder, key):
28142813
local_rank = dist.get_local_rank()
2815-
2814+
28162815
# Load tensors from files and reshape them to flat vectors
28172816
loaded_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False)
28182817
if isinstance(loaded_state, dict):

0 commit comments

Comments (0)