diff --git a/setup.py b/setup.py index aafecf4..c480f32 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,13 @@ def get_and_set_version(): packages=['torchacc'] + ['torchacc.' + \ pkg for pkg in find_packages('torchacc')], + # add console_scripts + entry_points={ + 'console_scripts': [ + 'consolidate_and_reshard_ckpts = torchacc.utils.consolidate_and_reshard_ckpts:main', + ], + }, + # Add _ prefix to the names of temporary build dirs options={'build': {'build_base': '_build'}, }, zip_safe=True, diff --git a/tests/standalone/consolidate_and_reshard_ckpts.py b/tests/standalone/consolidate_and_reshard_ckpts.py index cd39562..2e8a8d1 100644 --- a/tests/standalone/consolidate_and_reshard_ckpts.py +++ b/tests/standalone/consolidate_and_reshard_ckpts.py @@ -202,17 +202,17 @@ def main(args): ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-model.pth", reshard_num=reshard_num, - save_model=True, + save_model=False, ) - print(f"model consolidate and reshard to path:{ckpt_dir}") + print(f"model consolidate and reshard done.") optim_reshard_dicts, _ = consolidate_and_reshard_optim_dict( ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-optim.pth", reshard_num=reshard_num, - save_optimizer=True, + save_optimizer=False, ) - print(f"optimizer consolidate and reshard to path:{ckpt_dir}") + print(f"optimizer consolidate and reshard done.") # compare shard model and optimizer if reshard_num == fsdp_num: diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index 060eca1..f120a36 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -436,54 +436,57 @@ def reshard_optim_dict(consolidate_optim_dict, shard_optim, layer_name_lists, return shard_optim_state_dict_list, shard_metadata_list -def consolidate_and_reshard_model_dict(ckpt_dir, - ckpt_name="", - reshard_num=1, - save_path="", - save_model=True): +def consolidate_and_reshard_fsdp_model_dict(ckpt_dir, + ckpt_name, + save_dir="", + save_name="", + reshard_num=1, + save_model=True): 
""" Consolidate the sharded FSDP checkpoints into a single model checkpoint. Then reshard the FSDP model according to the reshard_num. Args: ckpt_dir (str): - The dir to FSDP checkpoint files from all ranks - ckpt_name (str, Optional): - The name_pattern to FSDP checkpoint files from all ranks. Files matching the - pattern ``ckpt_dir + ckpt_name`` will be loaded. The each + The dir to all FSDP shard model checkpoint files. + ckpt_name (str): + The name_pattern to all FSDP shard model checkpoint files. Files matching the + pattern ``ckpt_dir + ckpt_name`` will be loaded. Each checkpoint file is assumed to be a dict with a "model" key containing the FSDP model's ``model.state_dict()`` and a "shard_metadata" key containing the FSDP model's ``model.get_shard_metadata()``. + save_dir (str): + The save dir for consolidate or reshard model checkpoints. + save_name (str, Optional): + The name_pattern for consolidate or reshard model checkpoints. + For reshard checkpoints name pattern: ``rank*-of-*-model.pth`` + The final save_path is save_dir + save_name. reshard_num (int, Optional): - Reshard the fsdp model with reshard_num. If set to 1, we don't need to do + Reshard the fsdp model by reshard_num. If set to 1, we don't need to do resharding. - save_path (str, Optional): - the save path to the consolidated model checkpoint file (if - ``save_model`` is ``True``). The checkpoint file is a dict with a - "model" key containing the consolidated model state dict. save_model (str, Optional): - if ``True``, the model checkpoint will be saved to - ``save_path`` (or ``ckpt_dir + "consolidated_model.pth"`` if - ``save_path`` is empty). + if ``True``, the model checkpoint will be saved to ``save_dir + save_name``. Returns: model_state_dict: the consolidated model state dict or reshard model state dict list. - shard_meta_list: the reshard metadatalist. The consolidated model return None. + shard_meta_list: the reshard metadatalist. For consolidated model, return None. 
""" - checkpoints = load_checkpoints(ckpt_dir, ckpt_name) full_state_dict = consolidate_sharded_model_checkpoints( ckpt_dir, checkpoints) if reshard_num == 1: if save_model: - actual_save_path = save_path if save_path else os.path.join( - ckpt_dir, "consolidated_optimizer.pth") + if not save_dir or not save_name: + raise ValueError("save_dir and save_name should not be None!") + actual_save_path = os.path.join(save_dir, save_name) + save_checkpoints(full_state_dict, checkpoints[0]['shard_metadata'], actual_save_path, 'model') return full_state_dict, None + # load layer_info file_path = os.path.join(ckpt_dir, "layer_info.pickle") layer_info = [] @@ -491,19 +494,23 @@ def consolidate_and_reshard_model_dict(ckpt_dir, with open(file_path, 'rb') as f: layer_info = pickle.load(f) except FileNotFoundError: - print(f"please consolidate model first!") + raise NotImplementedError("please consolidate model first!") model_state_dict_list, shard_metadata_list = reshard_model_dict( full_state_dict, checkpoints[0], layer_info[0], reshard_num) if save_model: - if save_path == "": - save_path = ckpt_dir - - actual_save_path = [ - os.path.join(save_path, f"rank-{rank}-of-{reshard_num}-model.pth") - for rank in range(reshard_num) - ] + if not save_dir or not save_name: + raise ValueError("save_dir and save_name should not be None!") + + actual_save_path = [] + for idx in range(reshard_num): + save_name_ = re.sub( + r'\*', + lambda m: str(idx) if m.group(0) == '*' else str(reshard_num), + save_name, + count=2) + actual_save_path.append(os.path.join(save_dir, save_name_)) save_checkpoints(model_state_dict_list, shard_metadata_list, actual_save_path, 'model') @@ -511,18 +518,39 @@ def consolidate_and_reshard_model_dict(ckpt_dir, return model_state_dict_list, shard_metadata_list -def consolidate_and_reshard_optim_dict(ckpt_dir, - ckpt_name="", - reshard_num=1, - save_path="", - save_optimizer=True): +def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir, + ckpt_name, + save_dir="", + 
save_name="", + reshard_num=1, + save_optimizer=True): """ Consolidate the sharded FSDP checkpoints into a single optimizer checkpoint. Then reshard the FSDP optimizer according to the reshard_num. - + Args: + ckpt_dir (str): + The dir to all FSDP shard optimizer checkpoint files. + ckpt_name (str): + The name_pattern to all FSDP shard optimizer checkpoint files. Files matching the + pattern ``ckpt_dir + ckpt_name`` will be loaded. Each + checkpoint file is assumed to be a dict with a "optimizer" key + containing the FSDP optimizer's ``optimizer.state_dict()`` and a + "shard_metadata" key containing the FSDP model's + ``model.get_shard_metadata()``. + save_dir (str, Optional): + The save dir for consolidate or reshard optimizer checkpoints. + save_name (str, Optional): + The name_pattern for consolidate or reshard optimizer checkpoints. + For reshard checkpoints name pattern: `rank*-of-*-optimizer.pth` + The final save_path is save_dir + save_name. + reshard_num (int, Optional): + Reshard the fsdp optimizer by reshard_num. If set to 1, we don't need to do + resharding. + save_optimizer (bool, Optional): + if ``True``, the optimizer checkpoint will be saved to ``save_dir + save_name``. Returns: optim_state_dict: the consolidated optim state dict or reshard optim state dict list - shard_meta_list: the reshard metadatalist. The consolidated optim return None. + shard_meta_list: the reshard metadata list. For consolidated optim, return None. 
""" # load checkpoints checkpoints = load_checkpoints(ckpt_dir, ckpt_name) @@ -539,12 +567,12 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, full_optim_state_dict = consolidate_sharded_optimizer_checkpoints( ckpt_dir, checkpoints, layer_info) - actual_save_path = None - if reshard_num == 1: if save_optimizer: - actual_save_path = save_path if save_path else os.path.join( - ckpt_dir, "consolidated_optimizer.pth") + if not save_dir or not save_name: + raise ValueError("save_dir and save_name should not be None!") + actual_save_path = os.path.join(save_dir, save_name) + save_checkpoints(full_optim_state_dict, checkpoints[0]['shard_metadata'], actual_save_path, 'optimizer') @@ -555,14 +583,17 @@ def consolidate_and_reshard_optim_dict(ckpt_dir, full_optim_state_dict, checkpoints[0], layer_info[0], reshard_num) if save_optimizer: - if save_path == "": - save_path = ckpt_dir - - actual_save_path = [ - os.path.join(save_path, - f"rank-{rank}-of-{reshard_num}-optimizer.pth") - for rank in range(reshard_num) - ] + if not save_dir or not save_name: + raise ValueError("save_dir and save_name should not be None!") + + actual_save_path = [] + for idx in range(reshard_num): + save_name_ = re.sub( + r'\*', + lambda m: str(idx) if m.group(0) == '*' else str(reshard_num), + save_name, + count=2) + actual_save_path.append(os.path.join(save_dir, save_name_)) save_checkpoints(optim_state_dict_list, shard_metadata_list, actual_save_path, 'optimizer') diff --git a/torchacc/utils/consolidate_and_reshard_ckpts.py b/torchacc/utils/consolidate_and_reshard_ckpts.py index de74333..8fcf9ae 100644 --- a/torchacc/utils/consolidate_and_reshard_ckpts.py +++ b/torchacc/utils/consolidate_and_reshard_ckpts.py @@ -1,10 +1,11 @@ from argparse import ArgumentParser -from torchacc.dist.state_dict_utils import (consolidate_and_reshard_model_dict, - consolidate_and_reshard_optim_dict) +from torchacc.dist.state_dict_utils import ( + consolidate_and_reshard_fsdp_model_dict, + 
consolidate_and_reshard_fsdp_optim_dict) -MODEL_NAME_PATTERN = "rank*-of-*-model.pth" -OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" +DEFAULT_MODEL_NAME_PATTERN = "rank*-of-*-model.pth" +DEFAULT_OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" def main(): @@ -14,24 +15,25 @@ def main(): type=str, required=True, help=( - f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " + f"The dir of the XLA FSDP checkpoint files to be consolidated and resharded. " f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." - f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," - f"the default pattern is {OPTIM_NAME_PATTERN}"), + f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) parser.add_argument( "--ckpt_name", type=str, default="", help=( - f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " + f"The name pattern of the XLA FSDP checkpoint files to be consolidated and resharded. " f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." - f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," - f"the default pattern is {OPTIM_NAME_PATTERN}"), + f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) parser.add_argument( "--ckpt_type", type=str, + choices=["model", "optimizer"], default="model", help=( "The type of checkpoint to consolidate, you can choose model or optimizer. Please consolidate model fisrt, and then consolidate optimizer." @@ -43,32 +45,62 @@ def main(): default=1, help=( "We now support the reshard of XLA FSDP checkpoint according to the reshard_num." 
- )) + ), + ) parser.add_argument( - "--save_path", + "--save_dir", type=str, default="", help=( - f"The save path of the output state dict " - f"(default consolidate_path is ``ckpt_dir + model/optimizer_consolidated.pth``)" - f"If you need to reshard the checkpoint, please only pass the save_dir(default is ckpt_dir)," - f"we will save the file in path ``save_path + {MODEL_NAME_PATTERN}/{OPTIM_NAME_PATTERN}``" + f"The save dir of the output checkpoint files, the default value will be set to arg: ckpt_dir." + f"Files will be saved in path: ``save_dir + save_name``." + f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``." + f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." ), ) + parser.add_argument( + "--save_name", + type=str, + default="", + help=( + f"The save name pattern of the output checkpoint files, the default value is {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}." + f"Files will be saved in path: ``save_dir + save_name``." + f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``" + f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." + f"For reshard checkpoints, please use the same name pattern as {DEFAULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}." 
+ ), + ) + args = parser.parse_args() - assert args.ckpt_type in ['model', 'optimizer' - ], ('the ckpt_type should be model or optimizer') if args.ckpt_type == "model": if args.ckpt_name == "": - args.ckpt_name = MODEL_NAME_PATTERN - consolidate_and_reshard_model_dict(args.ckpt_dir, args.ckpt_name, - args.reshard_num, args.save_path) + args.ckpt_name = DEFAULT_MODEL_NAME_PATTERN + if args.save_dir == "": + args.save_dir = args.ckpt_dir + if args.save_name == "": + if args.reshard_num == 1: + args.save_name = "model_consolidated.pth" + else: + args.save_name = DEFAULT_MODEL_NAME_PATTERN + + consolidate_and_reshard_fsdp_model_dict(args.ckpt_dir, args.ckpt_name, + args.save_dir, args.save_name, + args.reshard_num) else: if args.ckpt_name == "": - args.ckpt_name = OPTIM_NAME_PATTERN - consolidate_and_reshard_optim_dict(args.ckpt_dir, args.ckpt_name, - args.reshard_num, args.save_path) + args.ckpt_name = DEFAULT_OPTIM_NAME_PATTERN + if args.save_dir == "": + args.save_dir = args.ckpt_dir + if args.save_name == "": + if args.reshard_num == 1: + args.save_name = "optimizer_consolidated.pth" + else: + args.save_name = DEFAULT_OPTIM_NAME_PATTERN + + consolidate_and_reshard_fsdp_optim_dict(args.ckpt_dir, args.ckpt_name, + args.save_dir, args.save_name, + args.reshard_num) if __name__ == "__main__":