Skip to content

Commit

Permalink
Merge branch 'zenodo-weights' into 'main'
Browse files Browse the repository at this point in the history
Fix loading of zenodo model weights

Closes #19

See merge request hi-dkfz/iml/failure-detection-benchmark!15
  • Loading branch information
jeremiastraub committed Jun 5, 2023
2 parents ec2ee92 + 9dd34cf commit 7545b32
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
3 changes: 3 additions & 0 deletions fd_shifts/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def overrides(self):
"exp.name": str(self.to_path().name),
}

if dataset in ("cifar10", "cifar100", "supercifar") and self.dropout:
overrides["model.avg_pool"] = False

if self.learning_rate is not None:
overrides["trainer.optimizer.lr"] = self.learning_rate

Expand Down
10 changes: 10 additions & 0 deletions fd_shifts/models/confidnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,5 +309,15 @@ def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:

def load_only_state_dict(self, path: str | Path) -> None:
ckpt = torch.load(path)

# For backwards-compatibility with before commit 1bdc717
for param in list(ckpt["state_dict"].keys()):
if ".encoder." in param or ".classifier." in param:
correct_param = param.replace(".encoder.", "._encoder.").replace(
".classifier.", "._classifier."
)
ckpt["state_dict"][correct_param] = ckpt["state_dict"][param]
del ckpt["state_dict"][param]

logger.info("loading checkpoint from epoch {}".format(ckpt["epoch"]))
self.load_state_dict(ckpt["state_dict"], strict=True)
10 changes: 10 additions & 0 deletions fd_shifts/models/devries_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,5 +295,15 @@ def on_load_checkpoint(self, checkpoint):

def load_only_state_dict(self, path):
ckpt = torch.load(path)

# For backwards-compatibility with before commit 1bdc717
for param in list(ckpt["state_dict"].keys()):
if ".encoder." in param or ".classifier." in param:
correct_param = param.replace(".encoder.", "._encoder.").replace(
".classifier.", "._classifier."
)
ckpt["state_dict"][correct_param] = ckpt["state_dict"][param]
del ckpt["state_dict"][param]

logger.info("loading checkpoint from epoch {}".format(ckpt["epoch"]))
self.load_state_dict(ckpt["state_dict"], strict=True)

0 comments on commit 7545b32

Please sign in to comment.