Skip to content

Commit

Permalink
Merge pull request #25 from mir-group/develop
Browse files Browse the repository at this point in the history
Small fixes to deploy - v0.2.1
  • Loading branch information
Linux-cpp-lisp authored May 3, 2021
2 parents 4586c78 + 0d70ae2 commit 349b269
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

Most recent change on the bottom.

## [Unreleased]
## [0.2.1] - 2021-05-03
### Fixed
- `load_deployed_model` now correctly loads all metadata

## [0.2.0] - 2021-04-30
2 changes: 1 addition & 1 deletion nequip/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# See Python packaging guide
# https://packaging.python.org/guides/single-sourcing-package-version/

__version__ = "0.2.0"
__version__ = "0.2.1"
24 changes: 21 additions & 3 deletions nequip/scripts/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
import pathlib
import logging
import warnings
import yaml

# This is a weird hack to avoid Intel MKL issues on the cluster when this is called as a subprocess of a process that has itself initialized PyTorch.
Expand All @@ -20,10 +21,12 @@
# Metadata key for the model's cutoff radius (stored as a stringified float).
R_MAX_KEY: Final[str] = "r_max"
# Metadata key for the number of atomic species the model supports.
N_SPECIES_KEY: Final[str] = "n_species"

# Every metadata key that load_deployed_model requests from a deployed
# model archive. CONFIG_KEY and NEQUIP_VERSION_KEY are defined above
# (outside this view).
_ALL_METADATA_KEYS = [CONFIG_KEY, NEQUIP_VERSION_KEY, R_MAX_KEY, N_SPECIES_KEY]


def load_deployed_model(
    model_path: Union[pathlib.Path, str]
) -> Tuple[torch.jit.ScriptModule, Dict[str, str]]:
    r"""Load a deployed NequIP model and its metadata.

    Args:
        model_path: path to a model file produced by ``nequip-deploy``.

    Returns:
        model: the loaded TorchScript module, set to eval mode.
        metadata: dict mapping each key in ``_ALL_METADATA_KEYS`` to its
            (ASCII-decoded) stored value, or ``""`` if absent.

    Raises:
        ValueError: if ``model_path`` is not a deployed NequIP model.
    """
    # Request every known metadata key; torch.jit.load fills in the values
    # found in the archive's _extra_files as bytes and leaves missing
    # entries untouched (empty str).
    metadata = {k: "" for k in _ALL_METADATA_KEYS}
    try:
        model = torch.jit.load(model_path, _extra_files=metadata)
    except RuntimeError as e:
        # Chain the original error so the TorchScript failure is visible.
        raise ValueError(
            f"{model_path} does not seem to be a deployed NequIP model file. Did you forget to deploy it using `nequip-deploy`? \n\n(Underlying error: {e})"
        ) from e
    # Confirm nequip made it: the version key is always written at deploy time.
    if metadata[NEQUIP_VERSION_KEY] == "":
        raise ValueError(
            f"{model_path} does not seem to be a deployed NequIP model file"
        )
    # Warn (don't fail) on any other metadata key the archive didn't carry —
    # likely a model deployed by an older version.
    for k in metadata:
        # TODO: some better semver based checking of versions here, or something
        if metadata[k] == "":
            warnings.warn(
                f"Metadata key `{k}` wasn't present in the saved model; this may indicate compatibility issues."
            )
    # Confirm its TorchScript
    assert isinstance(model, torch.jit.ScriptModule)
    # Make sure we're in eval mode
    model.eval()
    # Everything we store right now is ASCII, so decode for printing.
    # Keys that were never found stay as the plain-str "" placeholder and
    # must not be decoded.
    metadata = {
        k: (v.decode("ascii") if isinstance(v, bytes) else v)
        for k, v in metadata.items()
    }
    return model, metadata
Expand Down Expand Up @@ -116,6 +131,8 @@ def main(args=None):
model = script(model)
logging.info("Compiled model to TorchScript")

model.eval() # just to be sure

model = torch.jit.freeze(model)
logging.info("Froze TorchScript model")

Expand All @@ -129,6 +146,7 @@ def main(args=None):
metadata[R_MAX_KEY] = str(float(config["r_max"]))
metadata[N_SPECIES_KEY] = str(len(config["allowed_species"]))
metadata[CONFIG_KEY] = config_str
metadata = {k: v.encode("ascii") for k, v in metadata.items()}
torch.jit.save(model, args.out_file, _extra_files=metadata)
else:
raise ValueError
Expand Down

0 comments on commit 349b269

Please sign in to comment.