Skip to content

Commit

Permalink
Merge pull request #20 from grok-ai/develop
Browse files Browse the repository at this point in the history
Version 0.1.0
  • Loading branch information
Flegyas authored Mar 1, 2022
2 parents e4e0296 + efb7ce8 commit f97aa83
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
6 changes: 2 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,20 @@ repos:
- id: isort

- repo: https://github.com/psf/black.git
rev: '21.12b0'
rev: '22.1.0'
hooks:
- id: black
- id: black-jupyter

- repo: https://github.com/asottile/blacken-docs.git
rev: 'v1.12.0'
rev: 'v1.12.1'
hooks:
- id: blacken-docs

- repo: https://github.com/PyCQA/flake8.git
rev: '4.0.1'
hooks:
- id: flake8
additional_dependencies:
- flake8-docstrings==1.6.0

- repo: https://github.com/pycqa/pydocstyle.git
rev: '6.1.1'
Expand Down
18 changes: 13 additions & 5 deletions src/nn_core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Any, Callable, Dict, Optional, Type, Union

import pytorch_lightning as pl
import torch
from pytorch_lightning.plugins import TorchCheckpointIO

METADATA_KEY: str = "metadata"
Expand Down Expand Up @@ -123,8 +124,15 @@ def extract_checkpoint(ckpt_file: Path) -> Path:
yield Path(tmp_dir)


def load_model(module_class: Type[pl.LightningModule], checkpoint_path: Path):
    """Restore a LightningModule from an nn-core checkpoint file.

    :param module_class: the LightningModule subclass to restore
    :param checkpoint_path: path to the checkpoint on disk
    :return: the model instance rebuilt from the checkpoint state and its metadata
    """
    loaded = NNCheckpointIO.load(path=checkpoint_path)
    # Rebuild the module from the serialized state, forwarding the stored metadata.
    return module_class._load_model_state(checkpoint=loaded, metadata=loaded["metadata"])
def load_model(
    module_class: Type[pl.LightningModule],
    checkpoint_path: Path,
    map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
):
    """Load a model checkpoint, supporting both nn-core and vanilla Lightning formats.

    :param module_class: the LightningModule subclass to restore
    :param checkpoint_path: path to the checkpoint file; nn-core checkpoints end in
        ``.ckpt.zip``, plain PyTorch Lightning ones in ``.ckpt``
    :param map_location: forwarded to torch to remap storage locations on load
        (e.g. load a GPU-saved checkpoint onto CPU)
    :return: the restored model instance
    """
    # Lightning checkpoints end with .ckpt, ours with .ckpt.zip
    if checkpoint_path.name.endswith(".ckpt.zip"):
        checkpoint = NNCheckpointIO.load(path=checkpoint_path, map_location=map_location)
        # Metadata may be absent in older archives; .get avoids a KeyError.
        return module_class._load_model_state(checkpoint=checkpoint, metadata=checkpoint.get("metadata", None))
    else:
        pylogger.warning(f"Loading a legacy checkpoint (from vanilla PyTorch Lightning): '{checkpoint_path}'")
        # BUG FIX: the original fell through without returning, so legacy checkpoints
        # yielded None. Return the loaded model so both branches behave consistently.
        return module_class.load_from_checkpoint(checkpoint_path=str(checkpoint_path), map_location=map_location)

0 comments on commit f97aa83

Please sign in to comment.