From 9dd34cfb408490e1be5ad8de1e30f9fbe826cc8e Mon Sep 17 00:00:00 2001 From: Jeremias Traub Date: Mon, 5 Jun 2023 14:49:54 +0200 Subject: [PATCH] Fix loading of zenodo model weights --- fd_shifts/experiments/__init__.py | 3 +++ fd_shifts/models/confidnet_model.py | 10 ++++++++++ fd_shifts/models/devries_model.py | 10 ++++++++++ 3 files changed, 23 insertions(+) diff --git a/fd_shifts/experiments/__init__.py b/fd_shifts/experiments/__init__.py index 4d73be2..a998bad 100644 --- a/fd_shifts/experiments/__init__.py +++ b/fd_shifts/experiments/__init__.py @@ -49,6 +49,9 @@ def overrides(self): "exp.name": str(self.to_path().name), } + if dataset in ("cifar10", "cifar100", "supercifar") and self.dropout: + overrides["model.avg_pool"] = False + if self.learning_rate is not None: overrides["trainer.optimizer.lr"] = self.learning_rate diff --git a/fd_shifts/models/confidnet_model.py b/fd_shifts/models/confidnet_model.py index c2307fb..1f5efe9 100644 --- a/fd_shifts/models/confidnet_model.py +++ b/fd_shifts/models/confidnet_model.py @@ -309,5 +309,15 @@ def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: def load_only_state_dict(self, path: str | Path) -> None: ckpt = torch.load(path) + + # For backwards-compatibility with before commit 1bdc717 + for param in list(ckpt["state_dict"].keys()): + if ".encoder." in param or ".classifier." in param: + correct_param = param.replace(".encoder.", "._encoder.").replace( + ".classifier.", "._classifier." + ) + ckpt["state_dict"][correct_param] = ckpt["state_dict"][param] + del ckpt["state_dict"][param] + logger.info("loading checkpoint from epoch {}".format(ckpt["epoch"])) self.load_state_dict(ckpt["state_dict"], strict=True) diff --git a/fd_shifts/models/devries_model.py b/fd_shifts/models/devries_model.py index 52db46d..ae497bc 100644 --- a/fd_shifts/models/devries_model.py +++ b/fd_shifts/models/devries_model.py @@ -295,5 +295,15 @@ def on_load_checkpoint(self, checkpoint): def load_only_state_dict(self, path): ckpt = torch.load(path) + + # For backwards-compatibility with before commit 1bdc717 + for param in list(ckpt["state_dict"].keys()): + if ".encoder." in param or ".classifier." in param: + correct_param = param.replace(".encoder.", "._encoder.").replace( + ".classifier.", "._classifier." + ) + ckpt["state_dict"][correct_param] = ckpt["state_dict"][param] + del ckpt["state_dict"][param] + logger.info("loading checkpoint from epoch {}".format(ckpt["epoch"])) self.load_state_dict(ckpt["state_dict"], strict=True)