)
from .load_funcs import LOAD_FUNC_DICT
from .utils import process_load_info
+from internlm.checkpoint.vescale.api import save as vescale_save

logger = get_logger(__file__)
internlm_accelerator = get_accelerator()
@@ -60,7 +61,7 @@ class CheckpointLoadContent:
    SCHEDULAER = "scheduler"


-def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None):
+def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None, universal_ckpt=False):
    """Tries to load a checkpoint from the given folder.

    Args:
@@ -82,8 +83,13 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None):
            and the checkpoint manager ckpt_mm and train state objects
    """
    load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)
+
+    if universal_ckpt:
+        from internlm.checkpoint.vescale.api import load as vescale_load
+        checkpoint_state = {"model": ckpt_mm.model, "optimizer": ckpt_mm.optimizer}
+        vescale_load(load_ckpt_folder, checkpoint_state, broadcast_checkpoint=False)

-    if load_content.need_load(CheckpointLoadContent.MODEL):
+    if not universal_ckpt and load_content.need_load(CheckpointLoadContent.MODEL):
        load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model)
        load_content_str += f"{CheckpointLoadContent.MODEL}, "
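When universal_ckpt is set, model and optimizer state are restored together by a single vescale_load call, and the per-component branches below are skipped. A minimal standalone sketch of that call pattern, assuming only the internlm.checkpoint.vescale.api.load signature used in this hunk:

from internlm.checkpoint.vescale.api import load as vescale_load

def load_universal(folder, model, optimizer):
    # Restore both states in place into the passed objects; no cross-rank broadcast.
    checkpoint_state = {"model": model, "optimizer": optimizer}
    vescale_load(folder, checkpoint_state, broadcast_checkpoint=False)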
@@ -93,7 +99,7 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None):
        load_context(load_ckpt_folder, train_state)

    # load optimizer states.
-    if load_content.need_load(CheckpointLoadContent.OPIMIZER):
+    if not universal_ckpt and load_content.need_load(CheckpointLoadContent.OPIMIZER):
        load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
        load_content_str += f"{CheckpointLoadContent.OPIMIZER}, "
    else:
@@ -110,6 +116,7 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None):
            logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!")

    if not load_content.need_load(CheckpointLoadContent.OPIMIZER):
+        assert False
        if ckpt_mm.lr_scheduler and train_state:
            gpc.config.only_load_lr = True
            load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
@@ -419,10 +426,11 @@ def is_now_to_save_ckpt(self, train_state, force=False) -> (bool, CheckpointSave
    def try_save_checkpoint(self, train_state, force=False):
        if not self.enable_save_ckpt:
            return False
-
+
        save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state, force=force)

        if save_ckpts:
+            begin = time.time()
            # Wait for the previous round of asynchronous upload storage to complete.
            self.storage_manager.wait()
            if save_type == CheckpointSaveType.SNAPSHOT_CHECKPOINT:
@@ -440,6 +448,7 @@ def try_save_checkpoint(self, train_state, force=False):
                    train_state=train_state,
                    model_config=self.model_config,
                    model_config_file=self.model_config_file,
+                    universal_ckpt=gpc.config.ckpt.universal_ckpt,
                )

            if (
@@ -460,6 +469,8 @@ def try_save_checkpoint(self, train_state, force=False):
                        f"Finish to convert internevo2hf checkpoint from {save_ckpt_folder} to {save_hf_ckpt_folder}."
                    )
                torch.distributed.barrier()
+            end = time.time() - begin
+            print(f"finish save time {gpc.get_global_rank()}: {end}", flush=True)

        return now_break
@@ -576,12 +587,19 @@ def try_resume_training(self, train_state: TrainState, current_time=""):
                f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
            )
        else:
+            begin = time.time()
            load_path = self.load_ckpt_info["path"]
            load_content = self.load_ckpt_info["content"]
            load_type = self.load_ckpt_info["ckpt_type"]
+            universal_ckpt = gpc.config.ckpt.universal_ckpt
+            kwargs = {}
+
+            if universal_ckpt:
+                assert load_type == "internevo", "Only internevo ckpt supports universal ckpt."
+                kwargs = {"universal_ckpt": universal_ckpt}

            load_func = CheckpointLoadMethod.get_ckpt_load_type_func(load_type)
-            load_content_str = load_func(self, self.load_ckpt_info, train_state)
+            load_content_str = load_func(self, self.load_ckpt_info, train_state, **kwargs)

            # If we only load model weight, we need rewrite zero optim's fp32 buffer.
            if (
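The resume path reads the switch from gpc.config.ckpt.universal_ckpt and only forwards it when the checkpoint type is internevo. A sketch of how the flag might be set in the training config, assuming the usual InternEvo ckpt = dict(...) layout (every field except universal_ckpt is illustrative):

ckpt = dict(
    enable_save_ckpt=True,
    save_ckpt_folder="local:llm_ckpts/",  # illustrative path
    load_ckpt_info=dict(path="local:llm_ckpts/40", content=("all",), ckpt_type="internevo"),
    universal_ckpt=True,  # route save/load through the vescale universal path
)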
@@ -598,6 +616,8 @@ def try_resume_training(self, train_state: TrainState, current_time=""):
            )
            if load_content_str:
                logger.info(f"===========Load contents are: {load_content_str}")
+            end = time.time() - begin
+            print(f"finish load time {gpc.get_global_rank()}: {end}", flush=True)

    @llm_timeout(func_name="save_checkpoint")
    def save_checkpoint(
@@ -609,6 +629,7 @@ def save_checkpoint(
        train_state: TrainState,
        model_config: Dict = None,
        model_config_file: str = None,
+        universal_ckpt=False,
    ):
        """
        Save checkpoint to the given folder path.
@@ -621,13 +642,23 @@ def save_checkpoint(
        if gpc.is_rank_for_log():
            logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...")

-        timer("save-model").start()
-        save_model_checkpoint(folder=folder, model=model)
-        timer("save-model").stop()
+        if not universal_ckpt:
+            print(f"save ckpt: base", flush=True)
+            timer("save-model").start()
+            save_model_checkpoint(folder=folder, model=model)
+            timer("save-model").stop()

-        timer("save-optimizer").start()
-        save_optimizer_checkpoint(optim=optimizer, state_path=folder)
-        timer("save-optimizer").stop()
+            timer("save-optimizer").start()
+            save_optimizer_checkpoint(optim=optimizer, state_path=folder)
+            timer("save-optimizer").stop()
+        else:
+            print(f"save ckpt: universal", flush=True)
+            vescale_save(
+                path=folder,
+                checkpoint_state={"model": model, "optimizer": optimizer},
+                async_checkpoint=False,
+            )
+

        if (
            hasattr(train_state, "data_state_dict")
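On the save side, the universal branch hands both model and optimizer to one synchronous vescale_save call in place of the separate save-model / save-optimizer blocks. A minimal sketch pairing it with the load pattern shown earlier, again assuming only the signature used in this diff:

from internlm.checkpoint.vescale.api import save as vescale_save

def save_universal(folder, model, optimizer):
    # One call covers both model and optimizer state; async upload is disabled here.
    vescale_save(
        path=folder,
        checkpoint_state={"model": model, "optimizer": optimizer},
        async_checkpoint=False,
    )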