Skip to content

Commit 774d32f

Browse files
committed
support optim
1 parent b40b766 commit 774d32f

12 files changed

+765
-186
lines changed

internlm/checkpoint/checkpoint_manager.py

+42-11
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
)
4444
from .load_funcs import LOAD_FUNC_DICT
4545
from .utils import process_load_info
46+
from internlm.checkpoint.vescale.api import save as vescale_save
4647

4748
logger = get_logger(__file__)
4849
internlm_accelerator = get_accelerator()
@@ -60,7 +61,7 @@ class CheckpointLoadContent:
6061
SCHEDULAER = "scheduler"
6162

6263

63-
def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None):
64+
def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None, universal_ckpt=False):
6465
"""Tries to load a checkpoint from the given folder.
6566
6667
Args:
@@ -82,8 +83,13 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None):
8283
and the checkpoint manager ckpt_mm and train state objects
8384
"""
8485
load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)
86+
87+
if universal_ckpt:
88+
from internlm.checkpoint.vescale.api import load as vescale_load
89+
checkpoint_state = {"model": ckpt_mm.model, "optimizer": ckpt_mm.optimizer}
90+
vescale_load(load_ckpt_folder, checkpoint_state, broadcast_checkpoint=False)
8591

86-
if load_content.need_load(CheckpointLoadContent.MODEL):
92+
if not universal_ckpt and load_content.need_load(CheckpointLoadContent.MODEL):
8793
load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model)
8894
load_content_str += f"{CheckpointLoadContent.MODEL}, "
8995

@@ -93,7 +99,7 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None):
9399
load_context(load_ckpt_folder, train_state)
94100

95101
# load optimizer states.
96-
if load_content.need_load(CheckpointLoadContent.OPIMIZER):
102+
if not universal_ckpt and load_content.need_load(CheckpointLoadContent.OPIMIZER):
97103
load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
98104
load_content_str += f"{CheckpointLoadContent.OPIMIZER}, "
99105
else:
@@ -110,6 +116,7 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None):
110116
logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!")
111117

112118
if not load_content.need_load(CheckpointLoadContent.OPIMIZER):
119+
assert False
113120
if ckpt_mm.lr_scheduler and train_state:
114121
gpc.config.only_load_lr = True
115122
load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
@@ -419,10 +426,11 @@ def is_now_to_save_ckpt(self, train_state, force=False) -> (bool, CheckpointSave
419426
def try_save_checkpoint(self, train_state, force=False):
420427
if not self.enable_save_ckpt:
421428
return False
422-
429+
423430
save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state, force=force)
424431

425432
if save_ckpts:
433+
begin = time.time()
426434
# Wait for the previous round of asynchronous upload storage to complete.
427435
self.storage_manager.wait()
428436
if save_type == CheckpointSaveType.SNAPSHOT_CHECKPOINT:
@@ -440,6 +448,7 @@ def try_save_checkpoint(self, train_state, force=False):
440448
train_state=train_state,
441449
model_config=self.model_config,
442450
model_config_file=self.model_config_file,
451+
universal_ckpt=gpc.config.ckpt.universal_ckpt,
443452
)
444453

445454
if (
@@ -460,6 +469,8 @@ def try_save_checkpoint(self, train_state, force=False):
460469
f"Finish to convert internevo2hf checkpoint from {save_ckpt_folder} to {save_hf_ckpt_folder}."
461470
)
462471
torch.distributed.barrier()
472+
end = time.time() - begin
473+
print(f"finsh save time {gpc.get_global_rank()}: {end}", flush=True)
463474

464475
return now_break
465476

@@ -576,12 +587,19 @@ def try_resume_training(self, train_state: TrainState, current_time=""):
576587
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
577588
)
578589
else:
590+
begin = time.time()
579591
load_path = self.load_ckpt_info["path"]
580592
load_content = self.load_ckpt_info["content"]
581593
load_type = self.load_ckpt_info["ckpt_type"]
594+
universal_ckpt = gpc.config.ckpt.universal_ckpt
595+
kwargs = {}
596+
597+
if universal_ckpt:
598+
assert load_type == "internevo", "Only internevo ckpt support universal ckpt."
599+
kwargs = {"universal_ckpt": universal_ckpt}
582600

583601
load_func = CheckpointLoadMethod.get_ckpt_load_type_func(load_type)
584-
load_content_str = load_func(self, self.load_ckpt_info, train_state)
602+
load_content_str = load_func(self, self.load_ckpt_info, train_state, **kwargs)
585603

586604
# If we only load model weight, we need rewrite zero optim's fp32 buffer.
587605
if (
@@ -598,6 +616,8 @@ def try_resume_training(self, train_state: TrainState, current_time=""):
598616
)
599617
if load_content_str:
600618
logger.info(f"===========Load contents are: {load_content_str}")
619+
end = time.time() - begin
620+
print(f"finsh load time {gpc.get_global_rank()}: {end}", flush=True)
601621

602622
@llm_timeout(func_name="save_checkpoint")
603623
def save_checkpoint(
@@ -609,6 +629,7 @@ def save_checkpoint(
609629
train_state: TrainState,
610630
model_config: Dict = None,
611631
model_config_file: str = None,
632+
universal_ckpt=False,
612633
):
613634
"""
614635
Save checkpoint to the given folder path.
@@ -621,13 +642,23 @@ def save_checkpoint(
621642
if gpc.is_rank_for_log():
622643
logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...")
623644

624-
timer("save-model").start()
625-
save_model_checkpoint(folder=folder, model=model)
626-
timer("save-model").stop()
645+
if not universal_ckpt:
646+
print(f"save ckpt: base", flush=True)
647+
timer("save-model").start()
648+
save_model_checkpoint(folder=folder, model=model)
649+
timer("save-model").stop()
627650

628-
timer("save-optimizer").start()
629-
save_optimizer_checkpoint(optim=optimizer, state_path=folder)
630-
timer("save-optimizer").stop()
651+
timer("save-optimizer").start()
652+
save_optimizer_checkpoint(optim=optimizer, state_path=folder)
653+
timer("save-optimizer").stop()
654+
else:
655+
print(f"save ckpt: universal", flush=True)
656+
vescale_save(
657+
path=folder,
658+
checkpoint_state={"model": model, "optimizer": optimizer},
659+
async_checkpoint=False,
660+
)
661+
631662

632663
if (
633664
hasattr(train_state, "data_state_dict")

0 commit comments

Comments (0)