Merge pull request #26 from grok-ai/feature/bump-dependencies
Update to Lightning 1.7
lucmos authored Sep 9, 2022
2 parents 6ecc71f + 5fe4662 commit f55551b
Showing 3 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -17,7 +17,7 @@ packages=find:
 install_requires =
     # Add project specific dependencies
     # Stuff easy to break with updates
-    pytorch-lightning>=1.5.8,<1.6
+    pytorch-lightning==1.7.*
     hydra-core
     wandb

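The new specifier ==1.7.* is a PEP 440 wildcard pin: it accepts any 1.7.x patch release, equivalent to >=1.7.0,<1.8, replacing the old range that held the template on the 1.5 series. A minimal sketch, not part of this PR and assuming the packaging library is installed, to check the resolved version against the new pin at runtime:

# Not from the PR: quick runtime check that the installed Lightning
# satisfies the new wildcard pin (assumes the `packaging` package).
import pytorch_lightning as pl
from packaging.specifiers import SpecifierSet

assert pl.__version__ in SpecifierSet("==1.7.*"), f"unexpected version: {pl.__version__}"
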
6 changes: 5 additions & 1 deletion src/nn_core/model_logging.py
@@ -39,7 +39,11 @@ def __init__(self, logging_cfg: DictConfig, cfg: DictConfig, resume_id: Optional
             self.logging_cfg.logger.mode = "offline"
 
         pylogger.info(f"Instantiating <{self.logging_cfg.logger['_target_'].split('.')[-1]}>")
-        self.wrapped: LightningLoggerBase = hydra.utils.instantiate(self.logging_cfg.logger, version=self.resume_id)
+        self.wrapped: LightningLoggerBase = hydra.utils.instantiate(
+            self.logging_cfg.logger,
+            version=self.resume_id,
+            dir=os.getenv("WANDB_DIR", "."),
+        )
 
         # force experiment lazy initialization
         _ = self.wrapped.experiment
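
Passing dir=os.getenv("WANDB_DIR", ".") works because hydra.utils.instantiate forwards extra keyword arguments to the target's __init__, and WandbLogger hands unknown keywords such as dir on to wandb.init, so runs land in $WANDB_DIR when it is set and in the current directory otherwise. A minimal standalone sketch under those assumptions (hypothetical config, not the template's own):

import os

import hydra
from omegaconf import OmegaConf

# Hypothetical logger config; the template assembles this from its Hydra config tree.
logger_cfg = OmegaConf.create({"_target_": "pytorch_lightning.loggers.WandbLogger", "project": "demo"})

# Keyword overrides reach WandbLogger.__init__ alongside the config values;
# instantiation is lazy, so no wandb run starts until .experiment is accessed.
logger = hydra.utils.instantiate(logger_cfg, version=None, dir=os.getenv("WANDB_DIR", "."))
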
5 changes: 3 additions & 2 deletions src/nn_core/serialization.py
@@ -12,6 +12,7 @@
 
 import pytorch_lightning as pl
 import torch
+from pytorch_lightning.core.saving import _load_state
 from pytorch_lightning.plugins import TorchCheckpointIO
 
 METADATA_KEY: str = "metadata"
@@ -94,7 +95,7 @@ def remove_checkpoint(self, path) -> None:
 
 
 def compress_checkpoint(src_dir: Path, dst_file: Path, delete_dir: bool = True):
-
+    dst_file.parent.mkdir(parents=True, exist_ok=True)
     with zipfile.ZipFile(_normalize_path(dst_file), "w") as zip_file:
         for folder, subfolders, files in os.walk(src_dir):
             folder: Path = Path(folder)
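
The added mkdir call guards against a missing destination directory: zipfile.ZipFile(path, "w") raises FileNotFoundError when the parent folders do not exist. A standalone sketch of the same pattern (hypothetical paths, not from this repository):

from pathlib import Path
import zipfile

dst_file = Path("storage/run-1/checkpoint.zip")  # hypothetical destination
dst_file.parent.mkdir(parents=True, exist_ok=True)  # create missing parents; no-op if present
with zipfile.ZipFile(dst_file, "w") as zf:
    zf.writestr("hello.txt", "payload")  # would raise FileNotFoundError without the mkdir
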
@@ -161,7 +162,7 @@ def load_model(
         if substitute_values is not None:
             checkpoint = _substistute(checkpoint, substitute_values=substitute_values, substitute_keys=substitute_keys)
 
-        return module_class._load_model_state(checkpoint=checkpoint, metadata=checkpoint.get("metadata", None))
+        return _load_state(cls=module_class, checkpoint=checkpoint, metadata=checkpoint.get("metadata", None))
     else:
         pylogger.warning(f"Loading a legacy checkpoint (from vanilla PyTorch Lightning): '{checkpoint_path}'")
         module_class.load_from_checkpoint(checkpoint_path=str(checkpoint_path), map_location=map_location)
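
Lightning 1.7 drops the private classmethod _load_model_state that the old call relied on; the equivalent logic now lives in the module-level helper _load_state in pytorch_lightning.core.saving, which takes the target class as an explicit cls argument and, judging by the call above, routes the extra metadata keyword to the module constructor as the old classmethod did. A hedged compatibility sketch, not part of this PR, for code that must run on both sides of the change (assuming only these two spellings exist):

try:
    # Newer Lightning (1.7 here): module-level helper taking the class explicitly.
    from pytorch_lightning.core.saving import _load_state

    def load_state(module_class, checkpoint, **kwargs):
        return _load_state(cls=module_class, checkpoint=checkpoint, **kwargs)

except ImportError:
    # Older Lightning: the same logic lived on the class as a private classmethod.
    def load_state(module_class, checkpoint, **kwargs):
        return module_class._load_model_state(checkpoint=checkpoint, **kwargs)
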
