Skip to content

Commit

Permalink
Merge pull request #25 from umbertov/nonstrict_state_dict
Browse files Browse the repository at this point in the history
restore/finetune: allow non-strict loading of state dict
  • Loading branch information
lucmos authored Aug 6, 2023
2 parents a96c150 + 28d1878 commit fa4b2f1
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/nn_core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class NNTemplateCore(Callback):
def __init__(self, restore_cfg: Optional[DictConfig]):
self.resume_ckpt_path, self.resume_run_version = parse_restore(restore_cfg)
self.restore_mode: Optional[str] = restore_cfg.get("mode", None) if restore_cfg is not None else None
self.restore_strict: bool = restore_cfg.get("strict", True) if restore_cfg is not None else True

@property
def resume_id(self) -> Optional[str]:
Expand All @@ -41,7 +42,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
if self.restore_mode == "finetune":
checkpoint = NNCheckpointIO.load(path=Path(self.resume_ckpt_path))

pl_module.load_state_dict(checkpoint["state_dict"])
pl_module.load_state_dict(checkpoint["state_dict"], strict=self.restore_strict)

def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
if self._is_nnlogger(trainer):
Expand Down

0 comments on commit fa4b2f1

Please sign in to comment.