Skip to content

Commit

Permalink
improve stability visualization; update leaderborad
Browse files Browse the repository at this point in the history
  • Loading branch information
chiang-yuan committed Jul 16, 2024
1 parent 606b930 commit 3b3aaa9
Show file tree
Hide file tree
Showing 14 changed files with 331 additions and 262 deletions.
5 changes: 3 additions & 2 deletions mlip_arena/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from __future__ import annotations

from pathlib import Path

import torch
Expand All @@ -10,7 +11,7 @@

# from torch_geometric.data import Data

with open(Path(__file__).parent / "registry.yaml") as f:
with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
REGISTRY = yaml.safe_load(f)


Expand Down
28 changes: 28 additions & 0 deletions mlip_arena/models/externals.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,34 @@ def __init__(self, device=None, default_dtype="float32", **kwargs):
model_paths=model, device=device, default_dtype=default_dtype, **kwargs
)

class MACE_OFF_Medium(MACECalculator):
def __init__(self, device=None, default_dtype="float32", **kwargs):
checkpoint_url = "https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true"
cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = "".join(
c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
)
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
os.makedirs(cache_dir, exist_ok=True)
# download and save to disk
print(f"Downloading MACE model from {checkpoint_url!r}")
_, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path)
if "Content-Type: text/html" in http_msg:
raise RuntimeError(
f"Model download failed, please check the URL {checkpoint_url}"
)
print(f"Cached MACE model to {cached_model_path}")
model = cached_model_path
msg = f"Using Materials Project MACE for MACECalculator with {model}"
print(msg)

device = device or str(get_freer_device())

super().__init__(
model_paths=model, device=device, default_dtype=default_dtype, **kwargs
)


class CHGNet(CHGNetCalculator):
def __init__(
Expand Down
74 changes: 58 additions & 16 deletions mlip_arena/models/registry.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@


MACE-MP(M):
module: externals
class: MACE_MP_Medium
family: mace-mp
username: cyrusyc # HF username
username: cyrusyc
last-update: 2024-03-25T14:30:00
datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
datasets: # list of training datasets
datasets:
- atomind/mptrj # TODO: fake HF dataset repo
cpu-tasks:
- alexandria
- qmof
gpu-tasks:
- diatomics
- homonuclear-diatomics
- stability
github: https://github.com/ACEsuit/mace
doi: https://arxiv.org/abs/2401.00096
date: 2023-12-29
prediction: EFS
nvt: true
npt: true

CHGNet:
module: externals
Expand All @@ -28,7 +30,11 @@ CHGNet:
datasets:
- atomind/mptrj
gpu-tasks:
- diatomics
- homonuclear-diatomics
- stability
prediction: EFSM
nvt: true
npt: true

EquiformerV2(OC22):
module: externals
Expand All @@ -40,7 +46,12 @@ EquiformerV2(OC22):
datasets:
- ocp
gpu-tasks:
- diatomics
- homonuclear-diatomics
github: https://github.com/FAIR-Chem/fairchem
doi: https://arxiv.org/abs/2306.12059
prediction: EF
nvt: true
npt: false

eSCN(OC20):
module: externals
Expand All @@ -52,14 +63,45 @@ eSCN(OC20):
datasets:
- ocp
gpu-tasks:
- diatomics
- homonuclear-diatomics
github: https://github.com/FAIR-Chem/fairchem
doi: https://arxiv.org/abs/2302.03655
prediction: EF
nvt: true
npt: false

MACE-OFF(M):
module: externals
class: MACE_OFF_Medium
family: mace-off
username: cyrusyc
last-update: 2024-03-25T14:30:00
datetime: 2024-03-25T14:30:00 # TODO: Fake datetime
datasets:
- atomind/mptrj # TODO: fake HF dataset repo
cpu-tasks:
- alexandria
- qmof
gpu-tasks:
- homonuclear-diatomics
github: https://github.com/ACEsuit/mace
doi: https://arxiv.org/abs/2312.15211
date: 2023-12-29
prediction: EFS
nvt: true
npt: true

# CHGNet:
# module: chgnet
# username: cyrusyc
# datetime: 2024-03-25T14:30:00
# datasets:
# - atomind/mptrj
# cpu-tasks:
# - diatomics
ALIGNN:
module: externals
class: ALIGNN
family: alignn
username: cyrusyc
last-update: 2024-07-08T00:00:00
datetime: 2024-07-08T00:00:00
datasets:
- atomind/mptrj
gpu-tasks:
- homonuclear-diatomics
prediction: EFS
nvt: true
npt: true
5 changes: 4 additions & 1 deletion mlip_arena/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from pathlib import Path

import yaml
Expand All @@ -7,6 +6,10 @@
from mlip_arena.models import MLIP
from mlip_arena.models import REGISTRY as MODEL_REGISTRY

from .run import md as MD

__all__ = ["MD"]

with open(Path(__file__).parent / "registry.yaml") as f:
REGISTRY = yaml.safe_load(f)

Expand Down
165 changes: 0 additions & 165 deletions mlip_arena/tasks/diatomics/mlip.py

This file was deleted.

2 changes: 1 addition & 1 deletion mlip_arena/tasks/stability/chgnet/chloride-salts.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mlip_arena/tasks/stability/mace-mp/chloride-salts.json

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions mlip_arena/tasks/stability/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def md(

if restart and traj_file.exists():
traj = read(traj_file, index=":")
last_step = len(traj)
n_steps -= len(traj)
last_step = traj[-1].info.get("step", len(traj) * traj_interval)
n_steps -= last_step
last_atoms = traj[-1]
traj = Trajectory(traj_file, "a", atoms)
atoms.set_positions(last_atoms.get_positions())
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ torch==2.2.1
pymatgen==2024.4.13
bokeh==2.4.3
statsmodels==0.14.2
# py3Dmol==2.0.0.post2
# stmol==0.0.9
Loading

0 comments on commit 3b3aaa9

Please sign in to comment.