Skip to content

Commit

Permalink
model registry and compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
zubatyuk committed Oct 29, 2024
1 parent 28a1536 commit e4c3a1c
Show file tree
Hide file tree
Showing 12 changed files with 178 additions and 52 deletions.
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
docs/source
aimnet/calculators/assets/

scripts/compile/*.pt
scripts/compile/*.jpt
scripts/compile/*.sae
scripts/compile/aimnet2_b973c_d3.yaml
scripts/compile/aimnet2_wb97m_d3.yaml


# From https://raw.githubusercontent.com/github/gitignore/main/Python.gitignore

Expand Down
2 changes: 1 addition & 1 deletion aimnet/calculators/aimnet2ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class AIMNet2ASE(Calculator):
from typing import ClassVar

implemented_properties: ClassVar[list] = ["energy", "forces", "free_energy", "charges", "stress"]
implemented_properties: ClassVar[list[str]] = ["energy", "forces", "free_energy", "charges", "stress"]

def __init__(self, base_calc: AIMNet2Calculator | str = "aimnet2", charge=0, mult=1):
super().__init__()
Expand Down
Binary file not shown.
2 changes: 1 addition & 1 deletion aimnet/calculators/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch import Tensor, nn

from .model_zoo import get_model_path
from .model_registry import get_model_path
from .nbmat import calc_nbmat


Expand Down
60 changes: 60 additions & 0 deletions aimnet/calculators/model_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import logging
import os
from typing import Dict, Optional

import click
import requests
import yaml

# NOTE(review): this configures the root logger as a side effect of importing
# the module (basicConfig is a no-op if the root logger already has handlers).
# Confirm that setting INFO level from library code is intended.
logging.basicConfig(level=logging.INFO)


def load_model_registry(registry_file: Optional[str] = None) -> Dict[str, str]:
    """Load the model registry from a YAML file.

    Args:
        registry_file: Path to a registry YAML file. When omitted (or empty),
            the ``model_registry.yaml`` file located next to this module is
            used.

    Returns:
        The parsed YAML document. ``get_registry_model_path`` reads its
        ``models`` and ``aliases`` entries.
    """
    registry_file = registry_file or os.path.join(os.path.dirname(__file__), "model_registry.yaml")
    # Read the resolved path so that a caller-supplied registry is honoured
    # instead of always falling back to the bundled file.
    with open(registry_file) as f:
        return yaml.load(f, Loader=yaml.SafeLoader)


def create_assets_dir():
    """Make sure the ``assets`` directory next to this module exists."""
    module_dir = os.path.dirname(__file__)
    assets_dir = os.path.join(module_dir, "assets")
    # exist_ok=True keeps repeated calls harmless.
    os.makedirs(assets_dir, exist_ok=True)


def get_registry_model_path(model_name: str) -> str:
    """Resolve a registry model name (or alias) to a local asset path.

    Args:
        model_name: A key of the registry's ``models`` section, or a key of
            its ``aliases`` section that points at one.

    Returns:
        Path returned by ``_maybe_download_asset`` for the resolved model.

    Raises:
        ValueError: If the (alias-resolved) name is not listed under ``models``.
    """
    registry = load_model_registry()
    create_assets_dir()

    # Aliases are resolved once; the result must be a real model entry.
    aliases = registry["aliases"]
    if model_name in aliases:
        model_name = aliases[model_name]  # type: ignore

    models = registry["models"]
    if model_name not in models:
        raise ValueError(f"Model {model_name} not found in the registry.")

    # Each model entry supplies the keyword arguments of _maybe_download_asset.
    entry = models[model_name]  # type: ignore
    return _maybe_download_asset(**entry)  # type: ignore


def _maybe_download_asset(file: str, url: str) -> str:
    """Return the local path of an asset, downloading it first if it is missing.

    Args:
        file: File name of the asset inside the ``assets`` directory next to
            this module.
        url: URL the asset is fetched from when it is not present locally.

    Returns:
        Path to the asset file inside the ``assets`` directory.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    filename = os.path.join(os.path.dirname(__file__), "assets", file)
    if not os.path.exists(filename):
        print(f"Downloading {url} -> {filename}")
        # Fetch and validate before touching the destination, so a failed
        # request never leaves an empty or error-page file that would later be
        # mistaken for a cached asset.
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        partial = filename + ".part"
        with open(partial, "wb") as f:
            f.write(response.content)
        # Move into place only once the payload is completely written.
        os.replace(partial, filename)
    return filename


def get_model_path(s: str) -> str:
    """Turn a model specifier into a path to a model file.

    Args:
        s: Either a path to an existing file, or a name that is looked up
            through ``get_registry_model_path``.

    Returns:
        ``s`` itself when it is an existing file, otherwise the path obtained
        from the registry lookup.
    """
    if not os.path.isfile(s):
        # Not a file on disk: defer to the registry lookup.
        return get_registry_model_path(s)
    print("Found model file:", s)
    return s


@click.command(short_help="Clear assets directory.")
def clear_assets():
    """Remove all regular files from the assets directory next to this module.

    Entries that are not regular files (for example subdirectories) are left
    in place.
    """
    from glob import glob

    for fil in glob(os.path.join(os.path.dirname(__file__), "assets", "*")):
        if os.path.isfile(fil):
            # logging.warn is a deprecated alias; logging.warning is the
            # supported spelling and emits the same record.
            logging.warning(f"Removing {fil}")
            os.remove(fil)
33 changes: 33 additions & 0 deletions aimnet/calculators/model_registry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# map model name to its asset file name and download url
models:
aimnet2_wb97m_d3_0:
file: aimnet2_wb97m_d3_0.jpt
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_0.jpt
aimnet2_wb97m_d3_1:
file: aimnet2_wb97m_d3_1.jpt
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_1.jpt
aimnet2_wb97m_d3_2:
file: aimnet2_wb97m_d3_2.jpt
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_2.jpt
aimnet2_wb97m_d3_3:
file: aimnet2_wb97m_d3_3.jpt
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_3.jpt
aimnet2_b973c_d3_0:
file: aimnet2_b973c_d3_0.jpt
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_0.jpt
aimnet2_b973c_d3_1:
file: aimnet2_b973c_d3_1.jpt
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_1.jpt
aimnet2_b973c_d3_2:
file: aimnet2_b973c_d3_2.jpt
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_2.jpt
aimnet2_b973c_d3_3:
file: aimnet2_b973c_d3_3.jpt
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_3.jpt

# map model alias to model name (a key under `models`)
aliases:
aimnet2: aimnet2_wb97m_d3_0
aimnet2_wb97m: aimnet2_wb97m_d3_0
aimnet2_b973c: aimnet2_b973c_d3_0
aimnet2_qr: aimnet2_qr_v0
37 changes: 0 additions & 37 deletions aimnet/calculators/model_zoo.py

This file was deleted.

15 changes: 3 additions & 12 deletions aimnet/train/pt2jpt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
from typing import Dict, List, Optional
from typing import List, Optional

import click
import torch
from torch import Tensor, nn
from torch import nn

from aimnet.config import build_module, load_yaml

Expand Down Expand Up @@ -35,12 +35,6 @@ def add_sae_to_shifts(model: nn.Module, sae_file: str) -> nn.Module:
return model


def fix_agh(state_dict: Dict[str, Tensor]) -> dict:
state_dict["conv_q.agh"] = state_dict["conv_q.agh"].permute(2, 1, 0).contiguous()
state_dict["conv_a.agh"] = state_dict["conv_a.agh"].permute(2, 1, 0).contiguous()
return state_dict


def mask_not_implemented_species(model: nn.Module, species: List[int]) -> nn.Module:
weight = model.afv.weight
for i in range(1, weight.shape[0]):
Expand All @@ -58,9 +52,8 @@ def mask_not_implemented_species(model: nn.Module, species: List[int]) -> nn.Mod
@click.option("--model", type=str, default=_default_aimnet2_config, help="Path to model definition YAML file")
@click.option("--sae", type=str, default=None, help="Path to the energy shift YAML file.")
@click.option("--species", type=str, default=None, help="Comma-separated list of parametrized atomic numbers.")
@click.option("--fix-agh", is_flag=True, help="Fix the agh weights in the PyTorch model.")
@click.option("--no-lr", is_flag=True, help="Do not add LR cutoff for model")
def jitcompile(model, pt, jpt, sae, species, fix_agh, no_lr): # type: ignore[no-untyped-def]
def jitcompile(model: str, pt: str, jpt: str, sae=None, species=None, no_lr=False): # type: ignore
"""Build model from YAML config, load weight from PT file and write JIT-compiled JPT file.
Plus some modifications to work with aimnet2calc.
"""
Expand All @@ -69,8 +62,6 @@ def jitcompile(model, pt, jpt, sae, species, fix_agh, no_lr): # type: ignore[no
cutoff_lr = None if no_lr else float("inf")
model = add_cutoff(model, cutoff_lr=cutoff_lr)
sd = torch.load(pt, map_location="cpu", weights_only=True)
if fix_agh:
sd = fix_agh(sd)
print(model.load_state_dict(sd, strict=False))
if sae:
model = add_sae_to_shifts(model, sae)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ packages = [
[tool.poetry.scripts]
aimnet2pysis = "aimnet.calculators.aimnet2pysis:main"
aimnet = "aimnet.cli:cli"
aimnet-clear-models = "aimnet.calculators.model_registry:clear_assets"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
Expand Down
45 changes: 45 additions & 0 deletions scripts/compile/compile_off.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os.path

import requests
import yaml

from aimnet.train.pt2jpt import jitcompile


def compile_from_config(config):
    """Compile every model described in a batch-compile configuration.

    Args:
        config: Mapping of job name to a job mapping. Each job mapping holds a
            ``models`` list plus options shared by all of its models; each
            ``models`` entry holds per-model options and must provide at
            least ``pt`` and ``jpt``. Per-model options override shared ones.
            Values that are ``https:`` URLs are replaced by local file names
            via ``_maybe_download``. The mappings passed in are not modified.
    """
    for job_name, job_config in config.items():
        print(f"Compiling {job_name}.")
        # Work on copies so the caller's configuration is left untouched.
        shared = dict(job_config)
        models = shared.pop("models")
        shared = _maybe_download(shared)
        for task_config in models:
            # A distinct name avoids rebinding the ``config`` argument while
            # it is still being iterated over.
            task_kwargs = {**shared, **_maybe_download(dict(task_config))}
            print(f"{task_kwargs['pt']} -> {task_kwargs['jpt']}")
            # Call the function behind the click command directly.
            jitcompile.callback(**task_kwargs)  # type: ignore


def _maybe_download(d: dict) -> dict:
    """Replace ``https:`` URL values of ``d`` with local file names.

    For every string value starting with ``https:`` the last path component
    of the URL is used as a file name in the current working directory. The
    file is downloaded if it does not exist yet, and the value is replaced by
    that file name. All other values (including non-strings such as flags)
    are left unchanged.

    Args:
        d: Mapping of option name to value; it is updated in place.

    Returns:
        The same mapping ``d``.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    for key, value in d.items():
        # Only string values can be URLs; booleans, numbers or None must not
        # be probed with str methods.
        if not isinstance(value, str) or not value.startswith("https:"):
            continue
        filename = value.split("/")[-1]
        if not os.path.exists(filename):
            print(f"Downloading {filename}.")
            # Fetch and validate first so a failed request never leaves an
            # empty or error-page file that would later be treated as cached.
            response = requests.get(value, timeout=20)
            response.raise_for_status()
            partial = filename + ".part"
            with open(partial, "wb") as file:
                file.write(response.content)
            os.replace(partial, filename)
        # Re-assigning an existing key does not resize the dict, so this is
        # safe during iteration.
        d[key] = filename
    return d


if __name__ == "__main__":
    import argparse

    # Command-line entry point: takes one positional argument, the path to a
    # YAML batch configuration, and compiles everything listed in it.
    cli = argparse.ArgumentParser(description="Batch compile PyTorch models to TorchScript.")
    cli.add_argument("config", type=str, help="Path to the input YAML config file.")
    config_path = cli.parse_args().config

    with open(config_path) as stream:
        raw_text = stream.read()
    config = yaml.load(raw_text, Loader=yaml.SafeLoader)

    compile_from_config(config)
25 changes: 25 additions & 0 deletions scripts/compile/compile_off.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
aimnet2_b973c:
model: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3.yaml
sae: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c.sae
models:
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_0.pt
jpt: aimnet2_b973c_d3_0.jpt
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_1.pt
jpt: aimnet2_b973c_d3_1.jpt
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_2.pt
jpt: aimnet2_b973c_d3_2.jpt
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_3.pt
jpt: aimnet2_b973c_d3_3.jpt

aimnet2_wb97m:
model: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3.yaml
sae: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m.sae
models:
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_0.pt
jpt: aimnet2_wb97m_d3_0.jpt
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_1.pt
jpt: aimnet2_wb97m_d3_1.jpt
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_2.pt
jpt: aimnet2_wb97m_d3_2.jpt
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_3.pt
jpt: aimnet2_wb97m_d3_3.jpt
2 changes: 1 addition & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from torch import nn

from aimnet.calculators.model_zoo import get_model_path
from aimnet.calculators.model_registry import get_model_path
from aimnet.config import build_module
from aimnet.modules.core import Forces

Expand Down

0 comments on commit e4c3a1c

Please sign in to comment.