Skip to content

Commit

Permalink
add name pattern matching
Browse files (browse the repository at this point in the history)
  • Loading branch information
hanwen-sun committed Oct 16, 2024
1 parent 920adbc commit 1758008
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 49 deletions.
54 changes: 33 additions & 21 deletions tests/standalone/consolidate_and_reshard_ckpts.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,9 +4,9 @@
import torch
import torch_xla.core.xla_model as xm
import torchacc as ta
from torchacc.dist.state_dict_utils import (consolidate_and_reshard_model_dict,
consolidate_and_reshard_optim_dict,
load_checkpoints)
from torchacc.dist.state_dict_utils import (
consolidate_and_reshard_fsdp_model_dict,
consolidate_and_reshard_fsdp_optim_dict, load_checkpoints)
from utils import EchoDataset, set_seed


Expand Down Expand Up @@ -198,15 +198,15 @@ def main(args):
# rank 0 do consolidate and reshard:
if ta.dist.local_rank() == 0:
# consolidate and reshard model and optimizer
model_reshard_dicts, _ = consolidate_and_reshard_model_dict(
model_reshard_dicts, _ = consolidate_and_reshard_fsdp_model_dict(
ckpt_dir=ckpt_dir,
ckpt_name=f"rank*-of-*-model.pth",
reshard_num=reshard_num,
save_model=False,
)
print(f"model consolidate and reshard done.")

optim_reshard_dicts, _ = consolidate_and_reshard_optim_dict(
optim_reshard_dicts, _ = consolidate_and_reshard_fsdp_optim_dict(
ckpt_dir=ckpt_dir,
ckpt_name=f"rank*-of-*-optim.pth",
reshard_num=reshard_num,
Expand All @@ -217,7 +217,7 @@ def main(args):
# compare shard model and optimizer
if reshard_num == fsdp_num:
model_shard_dicts = load_checkpoints(
kpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-model.pth")
ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-model.pth")
optim_shard_dicts = load_checkpoints(
ckpt_dir=ckpt_dir, ckpt_name=f"rank*-of-*-optim.pth")

Expand All @@ -242,28 +242,28 @@ def main(args):
parser.add_argument("--bf16", action="store_true", default=False)
parser.add_argument("--backend", type=str, default="lazy")

MODEL_NAME_PATTERN = "rank*-of-*-model.pth"
OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth"
DEFAULT_MODEL_NAME_PATTERN = "rank*-of-*-model.pth"
DEFAULT_OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth"
# ckpt arguments
parser.add_argument(
"--ckpt_dir",
type=str,
required=True,
help=(
f"The name pattern of the XLA FSDP checkpoint files to be consolidated. "
f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded."
f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer,"
f"the default pattern is {OPTIM_NAME_PATTERN}"),
f"The name dir of the XLA FSDP checkpoint files to be consolidated and reshard. "
f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load."
f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer,"
f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"),
)
parser.add_argument(
"--ckpt_name",
type=str,
default="",
help=(
f"The name pattern of the XLA FSDP checkpoint files to be consolidated. "
f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded."
f"For model, the default pattern is {MODEL_NAME_PATTERN}. For optimizer,"
f"the default pattern is {OPTIM_NAME_PATTERN}"),
f"The name pattern of the XLA FSDP checkpoint files to be consolidated and reshard. "
f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load."
f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer,"
f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"),
)
parser.add_argument(
"--reshard_num",
Expand All @@ -273,14 +273,26 @@ def main(args):
"We now support the reshard of XLA FSDP checkpoint according to the reshard_num."
))
parser.add_argument(
"--save_path",
"--save_dir",
type=str,
default="",
help=(
f"The save path of the output state dict "
f"(default consolidate_path is ``ckpt_dir + model/optimizer_consolidated.pth``)"
f"If you need to reshard the checkpoint, please only pass the save_dir(default is ckpt_dir),"
f"we will save the file in path ``save_path + {MODEL_NAME_PATTERN}/{OPTIM_NAME_PATTERN}``"
f"The save dir of the output checkpoint files, the default value will be set to arg: ckpt_dir."
f"Files will be saved in path: ``save_dir + save_name``."
f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``."
f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``."
),
)
parser.add_argument(
"--save_name",
type=str,
default="",
help=(
f"The save name pattern of the output checkpoint files, the default value is {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}."
f"Files will be saved in path: ``save_dir + save_name`.`"
f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``"
f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``."
f"For reshard checkpoints, please use the same name patthern as {DEFAULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}."
),
)

Expand Down
21 changes: 9 additions & 12 deletions torchacc/dist/state_dict_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import pickle
import threading
import re
from collections import OrderedDict
from glob import glob
from typing import Dict
Expand Down Expand Up @@ -505,12 +506,10 @@ def consolidate_and_reshard_fsdp_model_dict(ckpt_dir,

actual_save_path = []
for idx in range(reshard_num):
save_name_ = re.sub(
r'\*',
lambda m: str(idx) if m.group(0) == '*' else str(reshard_num),
save_name,
count=2)
actual_save_path.append(os.path.join(save_dir, save_name_))
# replace the two '*'
save_name_temp = save_name.replace('*', str(idx), 1)
save_name_temp = save_name_temp.replace('*', str(reshard_num), 1)
actual_save_path.append(os.path.join(save_dir, save_name_temp))

save_checkpoints(model_state_dict_list, shard_metadata_list,
actual_save_path, 'model')
Expand Down Expand Up @@ -588,12 +587,10 @@ def consolidate_and_reshard_fsdp_optim_dict(ckpt_dir,

actual_save_path = []
for idx in range(reshard_num):
save_name_ = re.sub(
r'\*',
lambda m: str(idx) if m.group(0) == '*' else str(reshard_num),
save_name,
count=2)
actual_save_path.append(os.path.join(save_dir, save_name_))
# replace the two '*'
save_name_temp = save_name.replace('*', str(idx), 1)
save_name_temp = save_name_temp.replace('*', str(reshard_num), 1)
actual_save_path.append(os.path.join(save_dir, save_name_temp))

save_checkpoints(optim_state_dict_list, shard_metadata_list,
actual_save_path, 'optimizer')
Expand Down
34 changes: 18 additions & 16 deletions torchacc/utils/consolidate_and_reshard_ckpts.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,8 +4,8 @@
consolidate_and_reshard_fsdp_model_dict,
consolidate_and_reshard_fsdp_optim_dict)

DEFAULT_MODEL_NAME_PATTERN = "rank*-of-*-model.pth"
DEFAULT_OPTIM_NAME_PATTERN = "rank*-of-*-optimizer.pth"
DEFAULT_MODEL_NAME_PATTERN = "rank-*-of-*-model.pth"
DEFAULT_OPTIM_NAME_PATTERN = "rank-*-of-*-optimizer.pth"


def main():
Expand All @@ -16,7 +16,7 @@ def main():
required=True,
help=(
f"The name dir of the XLA FSDP checkpoint files to be consolidated and reshard. "
f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded."
f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load."
f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer,"
f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"),
)
Expand All @@ -26,7 +26,7 @@ def main():
default="",
help=(
f"The name pattern of the XLA FSDP checkpoint files to be consolidated and reshard. "
f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be loaded."
f"Files matching the pattern ``ckpt_dir + ckpt_name`` will be load."
f"For model, the default pattern is {DEFAULT_MODEL_NAME_PATTERN}. For optimizer,"
f"the default pattern is {DEFAULT_OPTIM_NAME_PATTERN}"),
)
Expand Down Expand Up @@ -55,19 +55,19 @@ def main():
f"The save dir of the output checkpoint files, the default value will be set to arg: ckpt_dir."
f"Files will be saved in path: ``save_dir + save_name``."
f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``."
f"For reshard checkpoints, the default path is: ``save_dir + {DEFULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``."
f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``."
),
)
parser.add_argument(
"--save_name",
type=str,
default="",
help=(
f"The save name pattern of the output checkpoint files, the default value is {DEFULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}."
f"The save name pattern of the output checkpoint files, the default value is {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}."
f"Files will be saved in path: ``save_dir + save_name`.`"
f"For consolidated checkpoint, the default path is: ``save_dir + model/optimizer_consolidated.pth``"
f"For reshard checkpoints, the default path is: ``save_dir + {DEFULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``."
f"For reshard checkpoints, please use the same name patthern as {DEFULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}."
f"For reshard checkpoints, the default path is: ``save_dir + {DEFAULT_MODEL_NAME_PATTERN}/{DEFAULT_OPTIM_NAME_PATTERN}``."
f"For reshard checkpoints, please use the same name patthern as {DEFAULT_MODEL_NAME_PATTERN} and {DEFAULT_OPTIM_NAME_PATTERN}."
),
)

Expand All @@ -76,10 +76,10 @@ def main():
if args.ckpt_type == "model":
if args.ckpt_name == "":
args.ckpt_name = DEFULT_MODEL_NAME_PATTERN
if args.save_path == "":
args.save_path = args.ckpt_dir
if args.save_dir == "":
args.save_dir = args.ckpt_dir
if args.save_name == "":
if args.reshard_name == 1:
if args.reshard_num == 1:
args.save_name = "model_consolidated.pth"
else:
args.save_name = DEFAULT_MODEL_NAME_PATTERN
Expand All @@ -89,15 +89,17 @@ def main():
args.reshard_num)
else:
if args.ckpt_name == "":
args.ckpt_name = DEFULT_MODEL_NAME_PATTERN
if args.save_path == "":
args.save_path = args.ckpt_dir
args.ckpt_name = DEFULT_OPTIM_NAME_PATTERN
if args.save_dir == "":
args.save_dir = args.ckpt_dir
if args.save_name == "":
if args.reshard_name == 1:
if args.reshard_num == 1:
args.save_name = "optimizer_consolidated.pth"
else:
args.save_name = DEFAULT_OPTIM_NAME_PATTERN

print(args.ckpt_dir)
print(args.save_dir)
print(args.save_name)
consolidate_and_reshard_fsdp_optim_dict(args.ckpt_dir, args.ckpt_name,
args.save_dir, args.save_name,
args.reshard_num)
Expand Down

0 comments on commit 1758008

Please sign in to comment.