diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 539c97d..5974a29 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,13 +45,13 @@ repos: - id: isort - repo: https://github.com/psf/black.git - rev: '21.12b0' + rev: '22.1.0' hooks: - id: black - id: black-jupyter - repo: https://github.com/asottile/blacken-docs.git - rev: 'v1.12.0' + rev: 'v1.12.1' hooks: - id: blacken-docs @@ -59,8 +59,6 @@ repos: rev: '4.0.1' hooks: - id: flake8 - additional_dependencies: - - flake8-docstrings==1.6.0 - repo: https://github.com/pycqa/pydocstyle.git rev: '6.1.1' diff --git a/src/nn_core/serialization.py b/src/nn_core/serialization.py index 1d8d366..893d665 100644 --- a/src/nn_core/serialization.py +++ b/src/nn_core/serialization.py @@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, Optional, Type, Union import pytorch_lightning as pl +import torch from pytorch_lightning.plugins import TorchCheckpointIO METADATA_KEY: str = "metadata" @@ -123,8 +124,15 @@ def extract_checkpoint(ckpt_file: Path) -> Path: yield Path(tmp_dir) -def load_model(module_class: Type[pl.LightningModule], checkpoint_path: Path): - checkpoint = NNCheckpointIO.load(path=checkpoint_path) - - model = module_class._load_model_state(checkpoint=checkpoint, metadata=checkpoint["metadata"]) - return model +def load_model( + module_class: Type[pl.LightningModule], + checkpoint_path: Path, + map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, +): + # Lightning checkpoints end with .ckpt, ours with .ckpt.zip + if checkpoint_path.name.endswith(".ckpt.zip"): + checkpoint = NNCheckpointIO.load(path=checkpoint_path, map_location=map_location) + return module_class._load_model_state(checkpoint=checkpoint, metadata=checkpoint.get("metadata", None)) + else: + pylogger.warning(f"Loading a legacy checkpoint (from vanilla PyTorch Lightning): '{checkpoint_path}'") + module_class.load_from_checkpoint(checkpoint_path=str(checkpoint_path), map_location=map_location)