-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
178 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import logging | ||
import os | ||
from typing import Dict, Optional | ||
|
||
import click | ||
import requests | ||
import yaml | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
|
||
def load_model_registry(registry_file: Optional[str] = None) -> Dict[str, str]: | ||
registry_file = registry_file or os.path.join(os.path.dirname(__file__), "model_registry.yaml") | ||
with open(os.path.join(os.path.dirname(__file__), "model_registry.yaml")) as f: | ||
return yaml.load(f, Loader=yaml.SafeLoader) | ||
|
||
|
||
def create_assets_dir(): | ||
os.makedirs(os.path.join(os.path.dirname(__file__), "assets"), exist_ok=True) | ||
|
||
|
||
def get_registry_model_path(model_name: str) -> str: | ||
model_registry = load_model_registry() | ||
create_assets_dir() | ||
if model_name in model_registry["aliases"]: | ||
model_name = model_registry["aliases"][model_name] # type: ignore | ||
if model_name not in model_registry["models"]: | ||
raise ValueError(f"Model {model_name} not found in the registry.") | ||
cfg = model_registry["models"][model_name] # type: ignore | ||
model_path = _maybe_download_asset(**cfg) # type: ignore | ||
return model_path | ||
|
||
|
||
def _maybe_download_asset(file: str, url: str) -> str: | ||
filename = os.path.join(os.path.dirname(__file__), "assets", file) | ||
if not os.path.exists(filename): | ||
print(f"Downloading {url} -> {filename}") | ||
with open(filename, "wb") as f: | ||
response = requests.get(url, timeout=60) | ||
f.write(response.content) | ||
return filename | ||
|
||
|
||
def get_model_path(s: str) -> str: | ||
# direct file path | ||
if os.path.isfile(s): | ||
print("Found model file:", s) | ||
else: | ||
s = get_registry_model_path(s) | ||
return s | ||
|
||
|
||
@click.command(short_help="Clear assets directory.") | ||
def clear_assets(): | ||
from glob import glob | ||
|
||
for fil in glob(os.path.join(os.path.dirname(__file__), "assets", "*")): | ||
if os.path.isfile(fil): | ||
logging.warn(f"Removing {fil}") | ||
os.remove(fil) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# map file name to url | ||
models: | ||
aimnet2_wb97m_d3_0: | ||
file: aimnet2_wb97m_d3_0.jpt | ||
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_0.jpt | ||
aimnet2_wb97m_d3_1: | ||
file: aimnet2_wb97m_d3_1.jpt | ||
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_1.jpt | ||
aimnet2_wb97m_d3_2: | ||
file: aimnet2_wb97m_d3_2.jpt | ||
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_2.jpt | ||
aimnet2_wb97m_d3_3: | ||
file: aimnet2_wb97m_d3_3.jpt | ||
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3_3.jpt | ||
aimnet2_b973c_d3_0: | ||
file: aimnet2_b973c_d3_0.jpt | ||
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_0.jpt | ||
aimnet2_b973c_d3_1: | ||
file: aimnet2_b973c_d3_1.jpt | ||
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_1.jpt | ||
aimnet2_b973c_d3_2: | ||
file: aimnet2_b973c_d3_2.jpt | ||
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_2.jpt | ||
aimnet2_b973c_d3_3: | ||
file: aimnet2_b973c_d3_3.jpt | ||
url: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3_3.jpt | ||
|
||
# map model alias to file name | ||
aliases: | ||
aimnet2: aimnet2_wb97m_d3_0 | ||
aimnet2_wb97m: aimnet2_wb97m_d3_0 | ||
aimnet2_b973c: aimnet2_b973c_d3_0 | ||
aimnet2_qr: aimnet2_qr_v0 |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import os.path | ||
|
||
import requests | ||
import yaml | ||
|
||
from aimnet.train.pt2jpt import jitcompile | ||
|
||
|
||
def compile_from_config(config): | ||
for job_name, job_config in config.items(): | ||
print(f"Compiling {job_name}.") | ||
models = job_config.pop("models") | ||
job_config = _maybe_download(job_config) | ||
for task_config in models: | ||
task_config = _maybe_download(task_config) | ||
config = {**job_config, **task_config} | ||
print(f"{config['pt']} -> {config['jpt']}") | ||
jitcompile.callback(**config) # type: ignore | ||
|
||
|
||
def _maybe_download(d: dict[str, str]) -> dict[str, str]: | ||
for key, value in d.items(): | ||
if value.startswith("https:"): | ||
filename = value.split("/")[-1] | ||
if not os.path.exists(filename): | ||
print(f"Downloading {filename}.") | ||
with open(filename, "wb") as file: | ||
response = requests.get(value, timeout=20) | ||
file.write(response.content) | ||
value = filename | ||
d[key] = value | ||
return d | ||
|
||
|
||
if __name__ == "__main__": | ||
import argparse | ||
|
||
parser = argparse.ArgumentParser(description="Batch compile PyTorch models to TorchScript.") | ||
parser.add_argument("config", type=str, help="Path to the input YAML config file.") | ||
args = parser.parse_args() | ||
|
||
with open(args.config) as file: | ||
config = yaml.load(file.read(), Loader=yaml.SafeLoader) | ||
|
||
compile_from_config(config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
aimnet2_b973c: | ||
model: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_d3.yaml | ||
sae: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c.sae | ||
models: | ||
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_0.pt | ||
jpt: aimnet2_b973c_d3_0.jpt | ||
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_1.pt | ||
jpt: aimnet2_b973c_d3_1.jpt | ||
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_2.pt | ||
jpt: aimnet2_b973c_d3_2.jpt | ||
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_b973c_3.pt | ||
jpt: aimnet2_b973c_d3_3.jpt | ||
|
||
aimnet2_wb97m: | ||
model: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_d3.yaml | ||
sae: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m.sae | ||
models: | ||
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_0.pt | ||
jpt: aimnet2_wb97m_d3_0.jpt | ||
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_1.pt | ||
jpt: aimnet2_wb97m_d3_1.jpt | ||
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_2.pt | ||
jpt: aimnet2_wb97m_d3_2.jpt | ||
- pt: https://storage.googleapis.com/aimnetcentral/AIMNet2/aimnet2_wb97m_3.pt | ||
jpt: aimnet2_wb97m_d3_3.jpt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters