1
+ """
2
+ Usage:
3
+ python tools/convert_ckpt_parallel.py \
4
+ <origin_ckpt_path> <target_ckpt_path> \
5
+ (optional) [--origin_meta_path <origin_meta_path>] [--target_meta_path <target_meta_path>] \
6
+ (optional) [--copy_file <True/False>] [--convert_optimizer <True/False>]
7
+
8
+ When a meta path is not specified, the meta file is automatically searched for and loaded from the corresponding ckpt path.
9
+ By default, the optimizer state is converted and files are copied.
10
+ Example:
11
+ srun -p llm_s python tools/convert_ckpt_parallel.py \
12
+ /llm_ckpt/100 /target_ckpt/converted
13
+ """
1
14
import argparse
2
15
import os
3
16
import shutil
@@ -530,7 +543,6 @@ def convert_optimizer_ckpt(
530
543
base_state ["base_optim_states" ]["state" ][group_id ] = state
531
544
base_state ["flat_fp32_weights" ][group_id ] = flat_fp32_weights
532
545
533
- # print(f"optimizer tp{new_tp_rank}_pp{new_pp_rank}_zo{new_zero1_rank}: {base_state}")
534
546
torch .save (base_state , os .path .join (saved_folder , file_name ))
535
547
536
548
print ("Finish optimizer convert" , flush = True )
@@ -559,6 +571,7 @@ def convert_optimizer_ckpt(
559
571
new_meta_path
560
572
), "new meta file does not exist, plese generate it before converting checkpoint."
561
573
574
+ # Read and process the metaData of the original ckpt.
562
575
old_meta = torch .load (old_meta_path , map_location = "cpu" )
563
576
old_pp_size = old_meta ["parallel_setting" ]["pp_size" ]
564
577
old_zero1_size = old_meta ["parallel_setting" ]["zero1_size" ]
@@ -570,16 +583,19 @@ def convert_optimizer_ckpt(
570
583
assert False , "tp or wp should be in parallel setting."
571
584
old_tp_size = old_meta ["parallel_setting" ][f"{ old_tp_mode } _size" ]
572
585
586
+ # Merge the metaData of every pp/zero1 rank into one flat dict to simplify key lookups.
573
587
old_meta_data = {}
574
588
for pp_rank in range (old_pp_size ):
575
589
for zero_rank in range (old_zero1_size ):
576
590
for states in old_meta ["metaData" ][0 ][pp_rank ][zero_rank ].values ():
577
591
old_meta_data .update (states )
578
592
593
+ # Map each local fqn to its global fqn, with one mapping per pp rank.
579
594
old_map_local_to_global = [{} for _ in range (old_pp_size )]
580
595
for global_fqn , states in old_meta_data .items ():
581
596
old_map_local_to_global [states ["pp" ]][states ["fqn" ]] = global_fqn
582
597
598
+ # Read and process the metaData of the target ckpt.
583
599
new_meta = torch .load (new_meta_path , map_location = "cpu" )
584
600
new_pp_size = new_meta ["parallel_setting" ]["pp_size" ]
585
601
new_zero1_size = new_meta ["parallel_setting" ]["zero1_size" ]
@@ -597,6 +613,7 @@ def convert_optimizer_ckpt(
597
613
), "Error: old meta and new meta have diffent group_id lists."
598
614
group_id_list = list (new_meta ["metaData" ][0 ][0 ][0 ].keys ())
599
615
616
+ # Collect the target meta_data into a single dict to simplify key lookups.
600
617
new_meta_data = {}
601
618
for pp_rank in range (new_pp_size ):
602
619
for zero_rank in range (new_zero1_size ):
0 commit comments