Merge branch 'ACEsuit:main' into lbfgs-multi-gpu
ttompa authored Jan 12, 2025
2 parents 4532b12 + 49293b8 commit 1bf4bb7
Showing 8 changed files with 175 additions and 58 deletions.
2 changes: 1 addition & 1 deletion mace/__version__.py
@@ -1,3 +1,3 @@
__version__ = "0.3.9"
__version__ = "0.3.10"

__all__ = ["__version__"]
26 changes: 16 additions & 10 deletions mace/calculators/foundations_models.py
@@ -11,7 +11,7 @@

module_dir = os.path.dirname(__file__)
local_model_path = os.path.join(
module_dir, "foundations_models/2023-12-03-mace-mp.model"
module_dir, "foundations_models/mace-mpa-0-medium.model"
)


@@ -26,7 +26,7 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
Returns:
str: Path to the downloaded (or cached, if previously loaded) checkpoint file.
"""
if model in (None, "medium") and os.path.isfile(local_model_path):
if model in (None, "medium-mpa-0") and os.path.isfile(local_model_path):
return local_model_path

urls = {
@@ -38,10 +38,12 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
"small-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model",
"medium-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
"large-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
"medium-0b3": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model",
"medium-mpa-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
}

checkpoint_url = (
urls.get(model, urls["medium"])
urls.get(model, urls["medium-mpa-0"])
if model
in (
None,
@@ -53,13 +55,18 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
"small-0b2",
"medium-0b2",
"large-0b2",
"medium-0b3",
"medium-mpa-0",
)
else model
)

cache_dir = (
Path(os.environ.get("XDG_CACHE_HOME", "~/")).expanduser() / ".cache/mace"
)
if checkpoint_url == urls["medium-mpa-0"]:
print(
"Using medium MPA-0 model as default MACE-MP model, to use previous (before 3.10) default model please specify 'medium' as model argument"
)

cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = "".join(
c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
)
@@ -124,10 +131,12 @@ def mace_mp(
"small",
"medium",
"large",
"medium-mpa-0",
"small-0b",
"medium-0b",
"small-0b2",
"medium-0b2",
"medium-0b3",
"large-0b2",
) or str(model).startswith("https:"):
model_path = download_mace_mp_checkpoint(model)
@@ -219,10 +228,7 @@ def mace_off(
if model in (None, "small", "medium", "large")
else model
)
cache_dir = (
Path(os.environ.get("XDG_CACHE_HOME", "~/")).expanduser()
/ ".cache/mace"
)
cache_dir = os.path.expanduser("~/.cache/mace")
checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
if not os.path.isfile(cached_model_path):
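For orientation (not part of the diff), a minimal usage sketch of the new default: mace_mp now resolves to the MPA-0 medium checkpoint unless a model is named explicitly. The keyword arguments are as exposed by mace_mp in this file; the device choice here is illustrative.

    from mace.calculators import mace_mp

    calc_new = mace_mp(model="medium-mpa-0", device="cpu")  # new default as of 0.3.10
    calc_old = mace_mp(model="medium", device="cpu")        # pre-0.3.10 default checkpoint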
Binary file not shown.
35 changes: 28 additions & 7 deletions mace/cli/select_head.py
@@ -7,26 +7,47 @@

def main():
parser = ArgumentParser()
parser.add_argument(
grp = parser.add_mutually_exclusive_group()
grp.add_argument(
"--head_name",
"-n",
help="name of the head to extract",
default=None,
)
grp.add_argument(
"--list_heads",
"-l",
action="store_true",
help="list names of the heads",
)
parser.add_argument(
"--target_device",
"-d",
help="target device, defaults to model's current device",
)
parser.add_argument(
"--output_file",
"-o",
help="name for output model, defaults to model_file.target_device",
help="name for output model, defaults to model.head_name, followed by .target_device if specified",
)
parser.add_argument("model_file", help="input model file path")
args = parser.parse_args()

if args.output_file is None:
args.output_file = args.model_file + "." + args.target_device

model = torch.load(args.model_file)
model_single = remove_pt_head(model, args.head_name)
torch.save(model_single, args.output_file)

if args.list_heads:
print("Available heads:")
print("\n".join([" " + h for h in model.heads]))
else:

if args.output_file is None:
    args.output_file = (
        args.model_file
        + "."
        + args.head_name
        + ("." + args.target_device if args.target_device is not None else "")
    )

model_single = remove_pt_head(model, args.head_name)
target_device = args.target_device
if target_device is None:
    target_device = str(next(model.parameters()).device)
model_single.to(target_device)
torch.save(model_single, args.output_file)


if __name__ == "__main__":
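For context (not part of the diff), a rough Python equivalent of the two new CLI paths. The checkpoint path and head name are hypothetical, and the import location of remove_pt_head is an assumption:

    import torch
    from mace.tools.scripts_utils import remove_pt_head  # assumed import path

    model = torch.load("multihead.model")       # hypothetical multi-head checkpoint
    print("\n".join(model.heads))               # what --list_heads prints
    single = remove_pt_head(model, "my_head")   # what --head_name extracts
    torch.save(single, "multihead.model.my_head")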
2 changes: 1 addition & 1 deletion mace/modules/models.py
@@ -96,7 +96,7 @@ def __init__(
)
edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
if pair_repulsion:
self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff)
self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff)
self.pair_repulsion = True

sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
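This call-site change pairs with the new ZBLBasis constructor in radial.py below; a short sketch of the resulting behavior:

    from mace.modules.radial import ZBLBasis

    zbl = ZBLBasis(p=6)                # new signature: the cutoff comes from covalent radii
    legacy = ZBLBasis(r_max=5.0, p=6)  # still accepted, but r_max is ignored with a warning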
81 changes: 44 additions & 37 deletions mace/modules/radial.py
@@ -4,6 +4,8 @@
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

import logging

import ase
import numpy as np
import torch
@@ -110,67 +112,70 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1]

@compile_mode("script")
class PolynomialCutoff(torch.nn.Module):
"""
Equation (8)
"""Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max.
Equation (8) of the DimeNet paper, https://arxiv.org/abs/2003.03123
"""

p: torch.Tensor
r_max: torch.Tensor

def __init__(self, r_max: float, p=6):
super().__init__()
self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
self.register_buffer("p", torch.tensor(p, dtype=torch.int))
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# yapf: disable
return self.calculate_envelope(x, self.r_max, self.p.to(torch.int))

@staticmethod
def calculate_envelope(
x: torch.Tensor, r_max: torch.Tensor, p: int
) -> torch.Tensor:
r_over_r_max = x / r_max
envelope = (
1.0
- ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.r_max, self.p)
+ self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1)
- (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2)
1.0
- ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p)
+ p * (p + 2.0) * torch.pow(r_over_r_max, p + 1)
- (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2)
)
# yapf: enable

# noinspection PyUnresolvedReferences
return envelope * (x < self.r_max)
return envelope * (x < r_max)

def __repr__(self):
return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})"


@compile_mode("script")
class ZBLBasis(torch.nn.Module):
"""
Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
"""Implementation of the Ziegler-Biersack-Littmark (ZBL) potential
with a polynomial cutoff envelope.
"""

p: torch.Tensor
r_max: torch.Tensor

def __init__(self, r_max: float, p=6, trainable=False):
def __init__(self, p=6, trainable=False, **kwargs):
super().__init__()
self.register_buffer(
"r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())
)
if "r_max" in kwargs:
logging.warning(
"r_max is deprecated. r_max is determined from the covalent radii."
)

# Pre-calculate the fixed coefficients of the ZBL screening function
self.register_buffer(
"c",
torch.tensor(
[0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype()
),
)
self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
self.register_buffer("p", torch.tensor(p, dtype=torch.int))
self.register_buffer(
"covalent_radii",
torch.tensor(
ase.data.covalent_radii,
dtype=torch.get_default_dtype(),
),
)
self.cutoff = PolynomialCutoff(r_max, p)
if trainable:
self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True))
self.a_prefactor = torch.nn.Parameter(
@@ -208,12 +213,7 @@ def forward(
)
v_edges = (14.3996 * Z_u * Z_v) / x * phi
r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v]
envelope = (
1.0
- ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / r_max, self.p)
+ self.p * (self.p + 2.0) * torch.pow(x / r_max, self.p + 1)
- (self.p * (self.p + 1.0) / 2) * torch.pow(x / r_max, self.p + 2)
) * (x < r_max)
envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p)
v_edges = 0.5 * v_edges * envelope
V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0))
return V_ZBL.squeeze(-1)
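
The envelope cutoff is now pair-specific (the sum of the two covalent radii) instead of a single global r_max; for example (a sketch, radii in Angstrom):

    import ase.data

    # ZBL cutoff for a C-H pair: covalent radius of C (Z=6) plus that of H (Z=1)
    r_max_CH = ase.data.covalent_radii[6] + ase.data.covalent_radii[1]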
@@ -224,8 +224,8 @@ def __repr__(self):

@compile_mode("script")
class AgnesiTransform(torch.nn.Module):
"""
Agnesi transform see ACEpotentials.jl, JCP 2023, p. 160
"""Agnesi transform - see section on Radial transformations in
ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783).
"""

def __init__(
@@ -265,21 +265,27 @@ def forward(
)
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0 = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v])
r_over_r_0 = x / r_0
return (
1 + (self.a * ((x / r_0) ** self.q) / (1 + (x / r_0) ** (self.q - self.p)))
) ** (-1)
1
+ (
self.a
* torch.pow(r_over_r_0, self.q)
/ (1 + torch.pow(r_over_r_0, self.q - self.p))
)
).reciprocal_()

def __repr__(self):
return f"{self.__class__.__name__}(a={self.a}, q={self.q}, p={self.p})"
return (
f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})"
)
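
In equation form, the rewritten forward computes

    t(x) = \left(1 + \frac{a (x/r_0)^q}{1 + (x/r_0)^{q-p}}\right)^{-1}, \qquad r_0 = \tfrac{1}{2}(r^{cov}_{Z_u} + r^{cov}_{Z_v}),

identical to the old expression but using torch.pow and an in-place reciprocal.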


@simplify_if_compile
@compile_mode("script")
class SoftTransform(torch.nn.Module):
"""
Soft Transform
"""
"""Soft Transform."""

def __init__(self, a: float = 0.2, b: float = 3.0, trainable=False):
super().__init__()
@@ -312,9 +318,10 @@ def forward(
Z_u = node_atomic_numbers[sender]
Z_v = node_atomic_numbers[receiver]
r_0 = (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) / 4
r_over_r_0 = x / r_0
y = (
x
+ (1 / 2) * torch.tanh(-(x / r_0) - self.a * ((x / r_0) ** self.b))
+ (1 / 2) * torch.tanh(-r_over_r_0 - self.a * torch.pow(r_over_r_0, self.b))
+ 1 / 2
)
return y
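
Equivalently, with r_0 = (r^{cov}_{Z_u} + r^{cov}_{Z_v}) / 4, the transform is

    y(x) = x + \tfrac{1}{2} \tanh\!\left(-\frac{x}{r_0} - a \left(\frac{x}{r_0}\right)^{b}\right) + \tfrac{1}{2},

which approaches x for large separations while smoothly flattening at short range.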
(Diffs for the remaining 2 changed files are not shown.)
