diff --git a/src/nn_core/callbacks.py b/src/nn_core/callbacks.py index 06cfbf2..b1e0687 100644 --- a/src/nn_core/callbacks.py +++ b/src/nn_core/callbacks.py @@ -17,6 +17,7 @@ class NNTemplateCore(Callback): def __init__(self, restore_cfg: Optional[DictConfig]): self.resume_ckpt_path, self.resume_run_version = parse_restore(restore_cfg) self.restore_mode: Optional[str] = restore_cfg.get("mode", None) if restore_cfg is not None else None + self.restore_strict: bool = restore_cfg.get("strict", True) if restore_cfg is not None else True @property def resume_id(self) -> Optional[str]: @@ -41,7 +42,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O if self.restore_mode == "finetune": checkpoint = NNCheckpointIO.load(path=Path(self.resume_ckpt_path)) - pl_module.load_state_dict(checkpoint["state_dict"]) + pl_module.load_state_dict(checkpoint["state_dict"], strict=self.restore_strict) def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if self._is_nnlogger(trainer):