diff --git a/tests/standalone/consolidate_and_reshard_ckpts.py b/tests/standalone/consolidate_and_reshard_ckpts.py index 2e8a8d1..3d09360 100644 --- a/tests/standalone/consolidate_and_reshard_ckpts.py +++ b/tests/standalone/consolidate_and_reshard_ckpts.py @@ -4,9 +4,9 @@ import torch import torch_xla.core.xla_model as xm import torchacc as ta -from torchacc.dist.state_dict_utils import (consolidate_and_reshard_model_dict, - consolidate_and_reshard_optim_dict, - load_checkpoints) +from torchacc.dist.state_dict_utils import ( + consolidate_and_reshard_fsdp_model_dict, + consolidate_and_reshard_fsdp_optim_dict, load_checkpoints) from utils import EchoDataset, set_seed @@ -198,7 +198,7 @@ def main(args): # rank 0 do consolidate and reshard: if ta.dist.local_rank() == 0: # consolidate and reshard model and optimizer - model_reshard_dicts, _ = consolidate_and_reshard_model_dict( + model_reshard_dicts, _ = consolidate_and_reshard_fsdp_model_dict( ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-model.pth", reshard_num=reshard_num, @@ -206,7 +206,7 @@ def main(args): ) print(f"model consolidate and reshard done.") - optim_reshard_dicts, _ = consolidate_and_reshard_optim_dict( + optim_reshard_dicts, _ = consolidate_and_reshard_fsdp_optim_dict( ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-optim.pth", reshard_num=reshard_num, @@ -217,7 +217,7 @@ def main(args): # compare shard model and optimizer if reshard_num == fsdp_num: model_shard_dicts = load_checkpoints( - kpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-model.pth") + ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-model.pth") optim_shard_dicts = load_checkpoints( ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-optim.pth") @@ -242,28 +242,28 @@ def main(args): parser.add_argument("--bf16", action="store_true", default=False) parser.add_argument("--backend", type=str, default="lazy") - MODEL_NAME_PATTERN = "rank*-of-*-model.pth" - OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" + DEFAULT_MODEL_NAME_PATTERN = "rank*-of-*-model.pth" + DEFAULT_OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" # ckpt arguments parser.add_argument( "--ckpt_dir", type=str, required=True, help=( - f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." - f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," - f"the default pattern is {OPTIM_NAME_PATTERN}"), + f"The name dir of the XLA FSDP checkpoint files to be consolidated and reshard. " + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load." + f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) parser.add_argument( "--ckpt_name", type=str, default="", help=( - f"The name pattern of the XLA FSDP checkpoint files to be consolidated. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." - f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer," - f"the default pattern is {OPTIM_NAME_PATTERN}"), + f"The name pattern of the XLA FSDP checkpoint files to be consolidated and reshard. " + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load." + f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," + f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) parser.add_argument( "--reshard_num", @@ -273,14 +273,26 @@ def main(args): "We now support the reshard of XLA FSDP checkpoint according to the reshard_num." )) parser.add_argument( - "--save_path", + "--save_dir", type=str, default="", help=( - f"The save path of the output state dict " - f"(default consolidate_path is ``ckpt_dir + model/optimizer_consolidated.pth``)" - f"If you need to reshard the checkpoint, please only pass the save_dir(default is ckpt_dir)," - f"we will save the file in path ``save_path + {MODEL_NAME_PATTERN}/{OPTIM_NAME_PATTERN}``" + f"The save dir of the output checkpoint files, the default value will be set to arg: ckpt_dir." + f"Files will be saved in path: ``save_dir + save_name``." + f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``." + f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." + ), + ) + parser.add_argument( + "--save_name", + type=str, + default="", + help=( + f"The save name pattern of the output checkpoint files, the default value is {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}." + f"Files will be saved in path: ``save_dir + save_name`.`" + f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``" + f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." + f"For reshard checkpoints, please use the same name patthern as {DEFAULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}." ), ) diff --git a/torchacc/dist/state_dict_utils.py b/torchacc/dist/state_dict_utils.py index f120a36..7090f09 100644 --- a/torchacc/dist/state_dict_utils.py +++ b/torchacc/dist/state_dict_utils.py @@ -2,6 +2,7 @@ import os import pickle import threading +import re from collections import OrderedDict from glob import glob from typing import Dict @@ -505,12 +506,10 @@ def consolidate_and_reshard_fsdp_model_dict(ckpt_dir, actual_save_path = [] for idx in range(reshard_num): - save_name_ = re.sub( - r'\*', - lambda m: str(idx) if m.group(0) == '*' else str(reshard_num), - save_name, - count=2) - actual_save_path.append(os.path.join(save_dir, save_name_)) + # replace the two '*' + save_name_temp = save_name.replace('*', str(idx), 1) + save_name_temp = save_name_temp.replace('*', str(reshard_num), 1) + actual_save_path.append(os.path.join(save_dir, save_name_temp)) save_checkpoints(model_state_dict_list, shard_metadata_list, actual_save_path, 'model') @@ -588,12 +587,10 @@ def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir, actual_save_path = [] for idx in range(reshard_num): - save_name_ = re.sub( - r'\*', - lambda m: str(idx) if m.group(0) == '*' else str(reshard_num), - save_name, - count=2) - actual_save_path.append(os.path.join(save_dir, save_name_)) + # replace the two '*' + save_name_temp = save_name.replace('*', str(idx), 1) + save_name_temp = save_name_temp.replace('*', str(reshard_num), 1) + actual_save_path.append(os.path.join(save_dir, save_name_temp)) save_checkpoints(optim_state_dict_list, shard_metadata_list, actual_save_path, 'optimizer') diff --git a/torchacc/utils/consolidate_and_reshard_ckpts.py b/torchacc/utils/consolidate_and_reshard_ckpts.py index 8fcf9ae..112a49f 100644 --- a/torchacc/utils/consolidate_and_reshard_ckpts.py +++ b/torchacc/utils/consolidate_and_reshard_ckpts.py @@ -4,8 +4,8 @@ consolidate_and_reshard_fsdp_model_dict, consolidate_and_reshard_fsdp_optim_dict) -DEFAULT_MODEL_NAME_PATTERN = "rank*-of-*-model.pth" -DEFAULT_OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth" +DEFAULT_MODEL_NAME_PATTERN = "rank-*-of-*-model.pth" +DEFAULT_OPTIM_NAME_PATTERN = "rank-*-of-*-optimizer.pth" def main(): @@ -16,7 +16,7 @@ def main(): required=True, help=( f"The name dir of the XLA FSDP checkpoint files to be consolidated and reshard. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load." f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) @@ -26,7 +26,7 @@ def main(): default="", help=( f"The name pattern of the XLA FSDP checkpoint files to be consolidated and reshard. " - f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded." + f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load." f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer," f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"), ) @@ -55,7 +55,7 @@ def main(): f"The save dir of the output checkpoint files, the default value will be set to arg: ckpt_dir." f"Files will be saved in path: ``save_dir + save_name``." f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``." - f"For reshard checkpoints, the default path is: ``save_dir + {DEFULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." + f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." ), ) parser.add_argument( @@ -63,11 +63,11 @@ def main(): type=str, default="", help=( - f"The save name pattern of the output checkpoint files, the default value is {DEFULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}." + f"The save name pattern of the output checkpoint files, the default value is {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}." f"Files will be saved in path: ``save_dir + save_name`.`" f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``" - f"For reshard checkpoints, the default path is: ``save_dir + {DEFULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." - f"For reshard checkpoints, please use the same name patthern as {DEFULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}." + f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``." + f"For reshard checkpoints, please use the same name patthern as {DEFAULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}." ), ) @@ -76,10 +76,10 @@ def main(): if args.ckpt_type == "model": if args.ckpt_name == "": args.ckpt_name = DEFULT_MODEL_NAME_PATTERN - if args.save_path == "": - args.save_path = args.ckpt_dir + if args.save_dir == "": + args.save_dir = args.ckpt_dir if args.save_name == "": - if args.reshard_name == 1: + if args.reshard_num == 1: args.save_name = "model_consolidated.pth" else: args.save_name = DEFAULT_MODEL_NAME_PATTERN @@ -89,15 +89,17 @@ def main(): args.reshard_num) else: if args.ckpt_name == "": - args.ckpt_name = DEFULT_MODEL_NAME_PATTERN - if args.save_path == "": - args.save_path = args.ckpt_dir + args.ckpt_name = DEFULT_OPTIM_NAME_PATTERN + if args.save_dir == "": + args.save_dir = args.ckpt_dir if args.save_name == "": - if args.reshard_name == 1: + if args.reshard_num == 1: args.save_name = "optimizer_consolidated.pth" else: args.save_name = DEFAULT_OPTIM_NAME_PATTERN - + print(args.ckpt_dir) + print(args.save_dir) + print(args.save_name) consolidate_and_reshard_fsdp_optim_dict(args.ckpt_dir, args.ckpt_name, args.save_dir, args.save_name, args.reshard_num)