diff --git a/model_checkpointing/checkpoint_handler.py b/model_checkpointing/checkpoint_handler.py
index b097df97d..e6749d397 100644
--- a/model_checkpointing/checkpoint_handler.py
+++ b/model_checkpointing/checkpoint_handler.py
@@ -5,6 +5,8 @@
 from datetime import datetime
 import torch
 import time
+import torch.optim as optim
+from configs import fsdp_config, train_config
 
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -63,9 +65,18 @@ def load_model_sharded(model, rank, cfg):
     if rank == 0:
         print(f"loading model from model path: {load_dir} ")
     reader = FileSystemReader(load_dir)
-    
+
     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
-        checkpoint = {"model": model.state_dict()}
+        # if the user specified save_optimizer, initialize the "optim" key in the dict
+        if train_config.save_optimizer:
+            optimizer = optim.AdamW(
+                model.parameters(),
+                lr=train_config.lr,
+                weight_decay=0.0,
+            )
+            checkpoint = {"model": model.state_dict(), "optim": FSDP.optim_state_dict(model, optimizer)}
+        else:
+            checkpoint = {"model": model.state_dict()}
         if rank == 0:
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
@@ -79,6 +90,8 @@ def load_model_sharded(model, rank, cfg):
             ck = checkpoint.keys()
             print(f" checkpoint key len = {len(ck)} and \n keys = {ck}")
         model.load_state_dict(checkpoint["model"])
+        if train_config.save_optimizer:
+            optimizer.load_state_dict(checkpoint["optim"])
         if rank == 0:
             print(f"Sharded state checkpoint loaded from {load_dir}")
 
@@ -264,4 +277,4 @@ def load_sharded_model_single_gpu(model,model_path):
 
     model.load_state_dict(state_dict["model"])
     print(f"Sharded state checkpoint loaded from {model_path}")
-    return model
\ No newline at end of file
+    return model
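Reviewer note (not part of the diff): two things are worth flagging in the new load path. First, the optimizer is constructed inside load_model_sharded; if it is not returned or attached to the model elsewhere, the restored "optim" state never reaches the caller's training loop. Second, with SHARDED_STATE_DICT checkpoints, recent PyTorch versions expect the optimizer state to be read back via load_sharded_optimizer_state_dict and converted with FSDP.optim_state_dict_to_load before optimizer.load_state_dict, rather than loaded directly. Below is a minimal sketch of that documented pattern under those assumptions; the function name, the lr parameter, and the torch.distributed.checkpoint import path are illustrative (the handler may import the equivalent names from a different module), and it assumes the checkpoint was saved with an "optim" key.

# Illustrative sketch only -- not this PR's code. Shows the documented
# pattern for restoring sharded optimizer state alongside the model.
import torch.optim as optim
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

def load_sharded_with_optimizer(model, load_dir, lr):
    # `model` is assumed to be already FSDP-wrapped on every rank;
    # `load_dir` holds a sharded checkpoint with "model" and "optim" keys.
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0)
    reader = dist_cp.FileSystemReader(load_dir)
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        # Seed the dict with correctly-shaped tensors, then load in place.
        checkpoint = {"model": model.state_dict()}
        dist_cp.load_state_dict(state_dict=checkpoint, storage_reader=reader)
        model.load_state_dict(checkpoint["model"])
        # Read the sharded optimizer state keyed by "optim" from the same
        # checkpoint, using the model state dict to reshard it correctly.
        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=checkpoint["model"],
            optimizer_key="optim",
            storage_reader=reader,
        )
        # Convert to the flattened form the local optimizer instance expects
        # (argument order follows PyTorch >= 2.1), then restore it.
        flattened_osd = FSDP.optim_state_dict_to_load(
            model, optimizer, optim_state["optim"]
        )
        optimizer.load_state_dict(flattened_osd)
    # Return the optimizer so the restored state is visible to the caller.
    return optimizer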