@@ -2899,7 +2899,7 @@ def _get_all_ckpt_names(self, checkpoints_path, tag):
2899
2899
2900
2900
ckpt_files = glob .glob (ckpt_file_pattern )
2901
2901
ckpt_files .sort ()
2902
- return ckpt_files
2902
+ return ckpt_files , ckpt_file_pattern
2903
2903
2904
2904
def load_checkpoint (self ,
2905
2905
load_dir ,
@@ -2923,7 +2923,7 @@ def load_checkpoint(self,
2923
2923
2924
2924
Returns:
2925
2925
A tuple of ``load_path`` and ``client_state``.
2926
- *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
2926
+ *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed or loading a HF based UCP
2927
2927
*``client_state``: State dictionary used for loading required training states in the client code.
2928
2928
2929
2929
Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right
@@ -2962,6 +2962,12 @@ def load_checkpoint(self,
2962
2962
custom_load_fn = custom_load_fn )
2963
2963
2964
2964
load_zero_checkpoint = load_path is not None and (self .zero_optimization () or self .bfloat16_enabled ())
2965
+ # import pdb; pdb.set_trace()
2966
+ if self .load_universal_checkpoint ():
2967
+ ucp_ckpt_folder = os .path .join (load_dir , tag )
2968
+ # UCP load can ignore '*mp' files or '*model_states.pt' but ucp_ckpt_folder must exist
2969
+ load_zero_checkpoint = os .path .isdir (ucp_ckpt_folder )
2970
+
2965
2971
if load_zero_checkpoint :
2966
2972
if (load_optimizer_states and not load_module_only ) or self .load_universal_checkpoint ():
2967
2973
success = self ._load_zero_checkpoint (load_dir , tag , load_optimizer_states = load_optimizer_states )
@@ -3002,7 +3008,11 @@ def _load_checkpoint(self,
3002
3008
3003
3009
from deepspeed .runtime .state_dict_factory import SDLoaderFactory
3004
3010
3005
- ckpt_list = self ._get_all_ckpt_names (load_dir , tag )
3011
+ ckpt_list , ckpt_file_pattern = self ._get_all_ckpt_names (load_dir , tag )
3012
+ if self .load_universal_checkpoint () and len (ckpt_list ) == 0 :
3013
+ logger .warning (f"Unable to find { ckpt_file_pattern } files in UCP folder { load_dir } " )
3014
+ return None , {}
3015
+
3006
3016
sd_loader = SDLoaderFactory .get_sd_loader (ckpt_list , checkpoint_engine = self .checkpoint_engine )
3007
3017
3008
3018
is_pipe_parallel = isinstance (self .module , PipelineModule )
0 commit comments