Skip to content

Commit

Permalink
refine console scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
hanwen-sun committed Oct 16, 2024
1 parent 5504402 commit 920adbc
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 76 deletions.
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ def get_and_set_version():
packages=['torchacc'] + ['torchacc.' + \
pkg for pkg in find_packages('torchacc')],

# add console_scripts
entry_points={
'console_scripts': [
'consolidate_and_reshard_ckpts = torchacc.utils.consolidate_and_reshard_ckpts:main',
],
},

# Add _ prefix to the names of temporary build dirs
options={'build': {'build_base': '_build'}, },
zip_safe=True,
Expand Down
8 changes: 4 additions & 4 deletions tests/standalone/consolidate_and_reshard_ckpts.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,17 +202,17 @@ def main(args):
ckpt_dir=ckpt_dir,
ckpt_name=f"rank*-of-*-model.pth",
reshard_num=reshard_num,
save_model=True,
save_model=False,
)
print(f"model consolidate and reshard to path:{ckpt_dir}")
print(f"model consolidate and reshard done.")

optim_reshard_dicts, _ = consolidate_and_reshard_optim_dict(
ckpt_dir=ckpt_dir,
ckpt_name=f"rank*-of-*-optim.pth",
reshard_num=reshard_num,
save_optimizer=True,
save_optimizer=False,
)
print(f"optimizer consolidate and reshard to path:{ckpt_dir}")
print(f"optimizer consolidate and reshard done.")

# compare shard model and optimizer
if reshard_num == fsdp_num:
Expand Down
127 changes: 79 additions & 48 deletions torchacc/dist/state_dict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,93 +436,121 @@ def reshard_optim_dict(consolidate_optim_dict, shard_optim, layer_name_lists,
return shard_optim_state_dict_list, shard_metadata_list


def consolidate_and_reshard_model_dict(ckpt_dir,
ckpt_name="",
reshard_num=1,
save_path="",
save_model=True):
def consolidate_and_reshard_fsdp_model_dict(ckpt_dir,
ckpt_name,
save_dir="",
save_name="",
reshard_num=1,
save_model=True):
"""
Consolidate the sharded FSDP checkpoints into a single model checkpoint. Then
reshard the FSDP model according to the reshard_num.
Args:
ckpt_dir (str):
The dir to FSDP checkpoint files from all ranks
ckpt_name (str, Optional):
The name_pattern to FSDP checkpoint files from all ranks. Files matching the
pattern ``ckpt_dir + ckpt_name`` will be loaded. The each
The dir to all FSDP shard model checkpoint files.
ckpt_name (str):
The name_pattern to all FSDP shard model checkpoint files. Files matching the
pattern ``ckpt_dir + ckpt_name`` will be loaded. Each
checkpoint file is assumed to be a dict with a "model" key
containing the FSDP model's ``model.state_dict()`` and a
"shard_metadata" key containing the FSDP model's
``model.get_shard_metadata()``.
save_dir (str):
The save dir for consolidate or reshard model checkpoints.
save_name (str, Optional):
The name_pattern for consolidate or reshard model checkpoints.
For reshard checkpoints name pattern: ``rank*-of-*-model.pth``
The final save_path is save_dir + save_name.
reshard_num (int, Optional):
Reshard the fsdp model with reshard_num. If set to 1, we don't need to do
Reshard the fsdp model by reshard_num. If set to 1, we don't need to do
resharding.
save_path (str, Optional):
the save path to the consolidated model checkpoint file (if
``save_model`` is ``True``). The checkpoint file is a dict with a
"model" key containing the consolidated model state dict.
save_model (bool, Optional):
if ``True``, the model checkpoint will be saved to
``save_path`` (or ``ckpt_dir + "consolidated_model.pth"`` if
``save_path`` is empty).
if ``True``, the model checkpoint will be saved to ``save_dir + save_name``.
Returns:
model_state_dict: the consolidated model state dict or reshard model state dict list.
shard_meta_list: the reshard metadatalist. The consolidated model return None.
shard_meta_list: the reshard metadatalist. For consolidated model, return None.
"""

checkpoints = load_checkpoints(ckpt_dir, ckpt_name)
full_state_dict = consolidate_sharded_model_checkpoints(
ckpt_dir, checkpoints)

if reshard_num == 1:
if save_model:
actual_save_path = save_path if save_path else os.path.join(
ckpt_dir, "consolidated_optimizer.pth")
if not save_dir or not save_name:
raise ValueError("save_dir and save_name should not be None!")
actual_save_path = os.path.join(save_dir, save_name)

save_checkpoints(full_state_dict, checkpoints[0]['shard_metadata'],
actual_save_path, 'model')

return full_state_dict, None

# load layer_info
file_path = os.path.join(ckpt_dir, "layer_info.pickle")
layer_info = []
try:
with open(file_path, 'rb') as f:
layer_info = pickle.load(f)
except FileNotFoundError:
print(f"please consolidate model first!")
raise NotImplementedError("please consolidate model first!")

model_state_dict_list, shard_metadata_list = reshard_model_dict(
full_state_dict, checkpoints[0], layer_info[0], reshard_num)

if save_model:
if save_path == "":
save_path = ckpt_dir

actual_save_path = [
os.path.join(save_path, f"rank-{rank}-of-{reshard_num}-model.pth")
for rank in range(reshard_num)
]
if not save_dir or not save_name:
raise ValueError("save_dir and save_name should not be None!")

actual_save_path = []
for idx in range(reshard_num):
save_name_ = re.sub(
r'\*',
lambda m: str(idx) if m.group(0) == '*' else str(reshard_num),
save_name,
count=2)
actual_save_path.append(os.path.join(save_dir, save_name_))

save_checkpoints(model_state_dict_list, shard_metadata_list,
actual_save_path, 'model')

return model_state_dict_list, shard_metadata_list


def consolidate_and_reshard_optim_dict(ckpt_dir,
ckpt_name="",
reshard_num=1,
save_path="",
save_optimizer=True):
def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir,
ckpt_name,
save_dir="",
save_name="",
reshard_num=1,
save_optimizer=True):
"""
Consolidate the sharded FSDP checkpoints into a single optimizer checkpoint. Then
reshard the FSDP optimizer according to the reshard_num.
Args:
ckpt_dir (str):
The dir to all FSDP shard optimizer checkpoint files.
ckpt_name (str):
The name_pattern to all FSDP shard optimizer checkpoint files. Files matching the
pattern ``ckpt_dir + ckpt_name`` will be loaded. Each
checkpoint file is assumed to be a dict with a "optimizer" key
containing the FSDP optimizer's ``optimizer.state_dict()`` and a
"shard_metadata" key containing the FSDP model's
``model.get_shard_metadata()``.
save_dir (str, Optional):
The save dir for consolidate or reshard optimizer checkpoints.
save_name (str, Optional):
The name_pattern for consolidate or reshard optimizer checkpoints.
For reshard checkpoints name pattern: ``rank*-of-*-optimizer.pth``
The final save_path is save_dir + save_name.
reshard_num (int, Optional):
Reshard the fsdp optimizer by reshard_num. If set to 1, we don't need to do
resharding.
save_optimizer (bool, Optional):
if ``True``, the optimizer checkpoint will be saved to ``save_dir + save_name``.
Returns:
optim_state_dict: the consolidated optim state dict or reshard optim state dict list
shard_meta_list: the reshard metadatalist. The consolidated optim return None.
shard_meta_list: the reshard metadatalist. For consolidated optim, return None.
"""
# load checkpoints
checkpoints = load_checkpoints(ckpt_dir, ckpt_name)
Expand All @@ -539,12 +567,12 @@ def consolidate_and_reshard_optim_dict(ckpt_dir,
full_optim_state_dict = consolidate_sharded_optimizer_checkpoints(
ckpt_dir, checkpoints, layer_info)

actual_save_path = None

if reshard_num == 1:
if save_optimizer:
actual_save_path = save_path if save_path else os.path.join(
ckpt_dir, "consolidated_optimizer.pth")
if not save_dir or not save_name:
raise ValueError("save_dir and save_name should not be None!")
actual_save_path = os.path.join(save_dir, save_name)

save_checkpoints(full_optim_state_dict,
checkpoints[0]['shard_metadata'], actual_save_path,
'optimizer')
Expand All @@ -555,14 +583,17 @@ def consolidate_and_reshard_optim_dict(ckpt_dir,
full_optim_state_dict, checkpoints[0], layer_info[0], reshard_num)

if save_optimizer:
if save_path == "":
save_path = ckpt_dir

actual_save_path = [
os.path.join(save_path,
f"rank-{rank}-of-{reshard_num}-optimizer.pth")
for rank in range(reshard_num)
]
if not save_dir or not save_name:
raise ValueError("save_dir and save_name should not be None!")

actual_save_path = []
for idx in range(reshard_num):
save_name_ = re.sub(
r'\*',
lambda m: str(idx) if m.group(0) == '*' else str(reshard_num),
save_name,
count=2)
actual_save_path.append(os.path.join(save_dir, save_name_))

save_checkpoints(optim_state_dict_list, shard_metadata_list,
actual_save_path, 'optimizer')
Expand Down
80 changes: 56 additions & 24 deletions torchacc/utils/consolidate_and_reshard_ckpts.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from argparse import ArgumentParser

from torchacc.dist.state_dict_utils import (consolidate_and_reshard_model_dict,
consolidate_and_reshard_optim_dict)
from torchacc.dist.state_dict_utils import (
consolidate_and_reshard_fsdp_model_dict,
consolidate_and_reshard_fsdp_optim_dict)

MODEL_NAME_PATTERN = "rank*-of-*-model.pth"
OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth"
DEFAULT_MODEL_NAME_PATTERN = "rank*-of-*-model.pth"
DEFAULT_OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth"


def main():
Expand All @@ -14,24 +15,25 @@ def main():
type=str,
required=True,
help=(
f"The name pattern of the XLA FSDP checkpoint files to be consolidated. "
f"The dir of the XLA FSDP checkpoint files to be consolidated and resharded. "
f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded."
f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer,"
f"the default pattern is {OPTIM_NAME_PATTERN}"),
f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer,"
f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"),
)
parser.add_argument(
"--ckpt_name",
type=str,
default="",
help=(
f"The name pattern of the XLA FSDP checkpoint files to be consolidated. "
f"The name pattern of the XLA FSDP checkpoint files to be consolidated and resharded. "
f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded."
f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer,"
f"the default pattern is {OPTIM_NAME_PATTERN}"),
f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer,"
f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"),
)
parser.add_argument(
"--ckpt_type",
type=str,
choices=["model", "optimizer"],
default="model",
help=(
"The type of checkpoint to consolidate, you can choose model or optimizer. Please consolidate model first, and then consolidate optimizer."
Expand All @@ -43,32 +45,62 @@ def main():
default=1,
help=(
"We now support the reshard of XLA FSDP checkpoint according to the reshard_num."
))
),
)
parser.add_argument(
"--save_path",
"--save_dir",
type=str,
default="",
help=(
f"The save path of the output state dict "
f"(default consolidate_path is ``ckpt_dir + model/optimizer_consolidated.pth``)"
f"If you need to reshard the checkpoint, please only pass the save_dir(default is ckpt_dir),"
f"we will save the file in path ``save_path + {MODEL_NAME_PATTERN}/{OPTIM_NAME_PATTERN}``"
f"The save dir of the output checkpoint files, the default value will be set to arg: ckpt_dir."
f"Files will be saved in path: ``save_dir + save_name``."
f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``."
f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``."
),
)
parser.add_argument(
"--save_name",
type=str,
default="",
help=(
f"The save name pattern of the output checkpoint files, the default value is {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}."
f"Files will be saved in path: ``save_dir + save_name``."
f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``"
f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``."
f"For reshard checkpoints, please use the same name pattern as {DEFAULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}."
),
)

args = parser.parse_args()
assert args.ckpt_type in ['model', 'optimizer'
], ('the ckpt_type should be model or optimizer')

if args.ckpt_type == "model":
if args.ckpt_name == "":
args.ckpt_name = MODEL_NAME_PATTERN
consolidate_and_reshard_model_dict(args.ckpt_dir, args.ckpt_name,
args.reshard_num, args.save_path)
args.ckpt_name = DEFAULT_MODEL_NAME_PATTERN
if args.save_dir == "":
args.save_dir = args.ckpt_dir
if args.save_name == "":
if args.reshard_num == 1:
args.save_name = "model_consolidated.pth"
else:
args.save_name = DEFAULT_MODEL_NAME_PATTERN

consolidate_and_reshard_fsdp_model_dict(args.ckpt_dir, args.ckpt_name,
args.save_dir, args.save_name,
args.reshard_num)
else:
if args.ckpt_name == "":
args.ckpt_name = OPTIM_NAME_PATTERN
consolidate_and_reshard_optim_dict(args.ckpt_dir, args.ckpt_name,
args.reshard_num, args.save_path)
args.ckpt_name = DEFAULT_OPTIM_NAME_PATTERN
if args.save_dir == "":
args.save_dir = args.ckpt_dir
if args.save_name == "":
if args.reshard_num == 1:
args.save_name = "optimizer_consolidated.pth"
else:
args.save_name = DEFAULT_OPTIM_NAME_PATTERN

consolidate_and_reshard_fsdp_optim_dict(args.ckpt_dir, args.ckpt_name,
args.save_dir, args.save_name,
args.reshard_num)


if __name__ == "__main__":
Expand Down

0 comments on commit 920adbc

Please sign in to comment.