Skip to content

Commit

Permalink
Merge pull request #30 from grok-ai/develop
Browse files Browse the repository at this point in the history
Hotfix load model method
  • Loading branch information
lucmos authored Sep 10, 2023
2 parents fa4b2f1 + cabf23e commit 585e7a6
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/nn_core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import zipfile
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, Union
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

import pytorch_lightning as pl
import torch
Expand Down Expand Up @@ -151,18 +151,19 @@ def _substistute(dictionary, substitute_values: Dict[str, str], substitute_keys:
def load_model(
module_class: Type[pl.LightningModule],
checkpoint_path: Path,
strict: bool = True,
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
substitute_keys: Optional[Dict[str, str]] = None,
substitute_values: Optional[Dict[str, str]] = None,
):
) -> Tuple[pl.LightningModule, Dict[str, Any]]:
# Lightning checkpoints end with .ckpt, ours with .ckpt.zip
if checkpoint_path.name.endswith(".ckpt.zip"):
checkpoint = NNCheckpointIO.load(path=checkpoint_path, map_location=map_location)

if substitute_values is not None:
checkpoint = _substistute(checkpoint, substitute_values=substitute_values, substitute_keys=substitute_keys)

return _load_state(cls=module_class, checkpoint=checkpoint, metadata=checkpoint.get("metadata", None))
return _load_state(cls=module_class, checkpoint=checkpoint, strict=strict, metadata=checkpoint.get("metadata", None)), checkpoint
else:
pylogger.warning(f"Loading a legacy checkpoint (from vanilla PyTorch Lightning): '{checkpoint_path}'")
module_class.load_from_checkpoint(checkpoint_path=str(checkpoint_path), map_location=map_location)
return module_class.load_from_checkpoint(checkpoint_path=str(checkpoint_path), map_location=map_location), None

0 comments on commit 585e7a6

Please sign in to comment.