Commit 24c6901

fix ci
1 parent 88cf02d commit 24c6901

4 files changed: +26 -5 lines changed

ci_scripts/train/load_ckpt.sh (+1 -1)

@@ -8,7 +8,7 @@ source ./ci_scripts/common/variables.sh
 readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
 readonly CKPTS40_PATH="$GITHUB_WORKSPACE/llm_ckpts/40"
 readonly CKPTS40_OUTPUT="${CKPTS40_PATH}/*.pt"
-expected_num=22
+expected_num=23
 exit_code=0

 source ./ci_scripts/common/basic_func.sh

internlm/checkpoint/checkpoint_manager.py (+1 -1)

@@ -631,7 +631,7 @@ def save_checkpoint(
     save_optimizer_checkpoint(optim=optimizer, state_path=folder)
     timer("save-optimizer").stop()

-    if gpc.get_global_rank() == 0:
+    if gpc.get_global_rank() == 0 and gpc.config.ckpt.need_metadata:
         assert self.meta_data is not None
         llm_save(os.path.join(folder, "metadata.pt"), saved_obj=self.meta_data)
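The new need_metadata flag is read from gpc.config.ckpt, so it is controlled from the checkpoint section of the training config. A minimal sketch of how it might be toggled there, assuming the usual dict-style ckpt section; every field name other than need_metadata is an illustrative assumption, not taken from this commit:

    # Illustrative config excerpt; only "need_metadata" comes from this commit.
    ckpt = dict(
        enable_save_ckpt=True,                # assumed field name
        save_ckpt_folder="local:llm_ckpts",   # assumed field name
        checkpoint_every=20,                  # assumed field name
        need_metadata=True,                   # new: rank 0 also writes metadata.pt when True
    )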

tests/test_training/train_CI.py (+6 -2)

@@ -60,7 +60,7 @@


 def fuse_wqkv(key, state_dict) -> None: # pylint: disable=W0613
-    prefix = key.rstrip("wqkv.weight")
+    prefix = key.rstrip("Wqkv.weight")
     wq_name, wk_name, wv_name = (
         f"{prefix}wq.weight",
         f"{prefix}wk.weight",
@@ -78,8 +78,12 @@ def check_model_weights(model, ckpt_path, total_equal=False):
     copy_of_ordered_dict = model2_dict.copy()

     for key in copy_of_ordered_dict.keys():
+        if "wqkv" in key:
+            model2_dict[key.replace("wqkv", "Wqkv")] = model2_dict.pop(key)
+            key = key.replace("wqkv", "Wqkv")
+
         if key not in model1_dict:
-            if "wqkv" in key:
+            if "Wqkv" in key:
                 fuse_wqkv(key, model1_dict)
             else:
                 assert False, f"Error: The key {key} for current model dose not exist in standard ckpt!"
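
For context, the test falls back to fuse_wqkv when the reference checkpoint stores separate wq/wk/wv tensors while the current model exposes a single fused Wqkv weight. A minimal sketch of that fusion idea, assuming a plain row-wise concatenation; the real fuse_wqkv in train_CI.py may lay the heads out differently, and the key name below is only illustrative:

    import torch

    def fuse_wqkv_sketch(key: str, state_dict: dict) -> None:
        # key is e.g. "layers.0.attention.Wqkv.weight" (illustrative name);
        # strip the suffix to recover the module prefix "layers.0.attention.".
        prefix = key[: -len("Wqkv.weight")]
        wq = state_dict.pop(f"{prefix}wq.weight")
        wk = state_dict.pop(f"{prefix}wk.weight")
        wv = state_dict.pop(f"{prefix}wv.weight")
        # Assumption: simple concatenation along the output dimension.
        state_dict[key] = torch.cat([wq, wk, wv], dim=0)

After fusing, both state dicts expose the same Wqkv.weight key, so the later equality check in check_model_weights can compare them directly.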

tools/convert_ckpt_parallel.py (+18 -1)

@@ -1,3 +1,16 @@
+"""
+Usage:
+python tools/convert_ckpt_parallel.py \
+    <origin_ckpt_path> <target_ckpt_path> \
+    (optional) [--origin_meta_path <origin_meta_path>] [--target_meta_path <target_meta_path>] \
+    (optional) [--copy_file <True/False>] [--convert_optimizer <True/False>]
+
+When meta_path is not specified, it will automatically search and load meta in the ckpt path.
+Default to convert optimizer state and copy files.
+Example:
+srun -p llm_s python tools/convert_ckpt_parallel.py \
+    /llm_ckpt/100 /target_ckpt/converted
+"""
 import argparse
 import os
 import shutil
@@ -530,7 +543,6 @@ def convert_optimizer_ckpt(
         base_state["base_optim_states"]["state"][group_id] = state
         base_state["flat_fp32_weights"][group_id] = flat_fp32_weights

-        # print(f"optimizer tp{new_tp_rank}_pp{new_pp_rank}_zo{new_zero1_rank}: {base_state}")
         torch.save(base_state, os.path.join(saved_folder, file_name))

     print("Finish optimizer convert", flush=True)
@@ -559,6 +571,7 @@ def convert_optimizer_ckpt(
     new_meta_path
 ), "new meta file does not exist, plese generate it before converting checkpoint."

+# read and process metaData for original ckpt
 old_meta = torch.load(old_meta_path, map_location="cpu")
 old_pp_size = old_meta["parallel_setting"]["pp_size"]
 old_zero1_size = old_meta["parallel_setting"]["zero1_size"]
@@ -570,16 +583,19 @@ def convert_optimizer_ckpt(
     assert False, "tp or wp should be in parallel setting."
 old_tp_size = old_meta["parallel_setting"][f"{old_tp_mode}_size"]

+# To facilitate key query, summarize meta_data.
 old_meta_data = {}
 for pp_rank in range(old_pp_size):
     for zero_rank in range(old_zero1_size):
         for states in old_meta["metaData"][0][pp_rank][zero_rank].values():
             old_meta_data.update(states)

+# map local fqn to global fqn
 old_map_local_to_global = [{} for _ in range(old_pp_size)]
 for global_fqn, states in old_meta_data.items():
     old_map_local_to_global[states["pp"]][states["fqn"]] = global_fqn

+# read and process metaData for target ckpt
 new_meta = torch.load(new_meta_path, map_location="cpu")
 new_pp_size = new_meta["parallel_setting"]["pp_size"]
 new_zero1_size = new_meta["parallel_setting"]["zero1_size"]
@@ -597,6 +613,7 @@ def convert_optimizer_ckpt(
 ), "Error: old meta and new meta have diffent group_id lists."
 group_id_list = list(new_meta["metaData"][0][0][0].keys())

+# To facilitate key query, summarize meta_data.
 new_meta_data = {}
 for pp_rank in range(new_pp_size):
     for zero_rank in range(new_zero1_size):
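
As a reading aid for the metaData handling above: once old_meta_data and old_map_local_to_global are built, locating a parameter's shard reduces to dictionary lookups on the "pp" and "fqn" fields used in those loops. A self-contained toy sketch of that query; the key names and the idea that local fqns are re-numbered per pipeline stage are assumptions for illustration only:

    # Toy stand-in for the structures built above: one parameter that lives on
    # pipeline stage 1 but is stored locally under a re-numbered layer index.
    old_meta_data = {
        "layers.8.attention.Wqkv.weight": {"pp": 1, "fqn": "layers.0.attention.Wqkv.weight"},
    }
    old_map_local_to_global = [{} for _ in range(2)]
    for global_fqn, states in old_meta_data.items():
        old_map_local_to_global[states["pp"]][states["fqn"]] = global_fqn

    # Query: which stage owns the parameter, and what is it called there?
    states = old_meta_data["layers.8.attention.Wqkv.weight"]
    print(states["pp"], states["fqn"])  # -> 1 layers.0.attention.Wqkv.weight
    # And back from (stage, local name) to the global name:
    print(old_map_local_to_global[1]["layers.0.attention.Wqkv.weight"])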
