diff --git a/mace/__version__.py b/mace/__version__.py
index 17eec33d..5bb0cfce 100644
--- a/mace/__version__.py
+++ b/mace/__version__.py
@@ -1,3 +1,3 @@
-__version__ = "0.3.9"
+__version__ = "0.3.10"
 __all__ = ["__version__"]
diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py
index f479d9c4..74b117cf 100644
--- a/mace/calculators/foundations_models.py
+++ b/mace/calculators/foundations_models.py
@@ -11,7 +11,7 @@
 module_dir = os.path.dirname(__file__)
 local_model_path = os.path.join(
-    module_dir, "foundations_models/2023-12-03-mace-mp.model"
+    module_dir, "foundations_models/mace-mpa-0-medium.model"
 )
@@ -26,7 +26,7 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
     Returns:
         str: Path to the downloaded (or cached, if previously loaded) checkpoint file.
     """
-    if model in (None, "medium") and os.path.isfile(local_model_path):
+    if model in (None, "medium-mpa-0") and os.path.isfile(local_model_path):
         return local_model_path
 
     urls = {
@@ -38,10 +38,12 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
         "small-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model",
         "medium-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
         "large-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
+        "medium-0b3": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model",
+        "medium-mpa-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
     }
     checkpoint_url = (
-        urls.get(model, urls["medium"])
+        urls.get(model, urls["medium-mpa-0"])
         if model
         in (
             None,
@@ -53,13 +55,18 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
             "small-0b2",
             "medium-0b2",
             "large-0b2",
+            "medium-0b3",
+            "medium-mpa-0",
         )
         else model
     )
-    cache_dir = (
-        Path(os.environ.get("XDG_CACHE_HOME", "~/")).expanduser() / ".cache/mace"
-    )
+    if checkpoint_url == urls["medium-mpa-0"]:
+        print(
+            "Using the MPA-0 medium model as the default MACE-MP model. To use the previous (pre-0.3.10) default model, set model='medium'."
+        )
+
+    cache_dir = os.path.expanduser("~/.cache/mace")
     checkpoint_url_name = "".join(
         c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_"
     )
@@ -124,10 +131,12 @@ def mace_mp(
         "small",
         "medium",
         "large",
+        "medium-mpa-0",
         "small-0b",
         "medium-0b",
         "small-0b2",
         "medium-0b2",
+        "medium-0b3",
         "large-0b2",
     ) or str(model).startswith("https:"):
         model_path = download_mace_mp_checkpoint(model)
@@ -219,10 +228,7 @@ def mace_off(
         if model in (None, "small", "medium", "large")
         else model
     )
-    cache_dir = (
-        Path(os.environ.get("XDG_CACHE_HOME", "~/")).expanduser()
-        / ".cache/mace"
-    )
+    cache_dir = os.path.expanduser("~/.cache/mace")
     checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0]
     cached_model_path = f"{cache_dir}/{checkpoint_url_name}"
     if not os.path.isfile(cached_model_path):
diff --git a/mace/calculators/foundations_models/mace-mpa-0-medium.model b/mace/calculators/foundations_models/mace-mpa-0-medium.model
new file mode 100644
index 00000000..962c5918
Binary files /dev/null and b/mace/calculators/foundations_models/mace-mpa-0-medium.model differ
diff --git a/mace/cli/select_head.py b/mace/cli/select_head.py
index a1e27229..6141d50e 100644
--- a/mace/cli/select_head.py
+++ b/mace/cli/select_head.py
@@ -7,26 +7,47 @@ def main():
     parser = ArgumentParser()
-    parser.add_argument(
+    grp = parser.add_mutually_exclusive_group()
+    grp.add_argument(
         "--head_name",
         "-n",
         help="name of the head to extract",
         default=None,
     )
+    grp.add_argument(
+        "--list_heads",
+        "-l",
+        action="store_true",
+        help="list names of the heads",
+    )
+    parser.add_argument(
+        "--target_device",
+        "-d",
+        help="target device, defaults to model's current device",
+    )
     parser.add_argument(
         "--output_file",
         "-o",
-        help="name for output model, defaults to model_file.target_device",
+        help="name for output model, defaults to model_file.head_name, with .target_device appended if specified",
     )
     parser.add_argument("model_file", help="input model file path")
     args = parser.parse_args()
-    if args.output_file is None:
-        args.output_file = args.model_file + "." + args.target_device
     model = torch.load(args.model_file)
-    model_single = remove_pt_head(model, args.head_name)
-    torch.save(model_single, args.output_file)
+
+    if args.list_heads:
+        print("Available heads:")
+        print("\n".join([" " + h for h in model.heads]))
+    else:
+        if args.output_file is None:
+            args.output_file = (
+                args.model_file
+                + "."
+                + args.head_name
+                + ("." + args.target_device if args.target_device is not None else "")
+            )
+
+        model_single = remove_pt_head(model, args.head_name)
+        if args.target_device is not None:
+            model_single.to(args.target_device)
+        torch.save(model_single, args.output_file)
 
 
 if __name__ == "__main__":
diff --git a/mace/modules/models.py b/mace/modules/models.py
index 0e03317e..ebee0b7a 100644
--- a/mace/modules/models.py
+++ b/mace/modules/models.py
@@ -96,7 +96,7 @@ def __init__(
         )
         edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e")
         if pair_repulsion:
-            self.pair_repulsion_fn = ZBLBasis(r_max=r_max, p=num_polynomial_cutoff)
+            self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff)
             self.pair_repulsion = True
 
         sh_irreps = o3.Irreps.spherical_harmonics(max_ell)
diff --git a/mace/modules/radial.py b/mace/modules/radial.py
index a928c184..cae2aa71 100644
--- a/mace/modules/radial.py
+++ b/mace/modules/radial.py
@@ -4,6 +4,8 @@
 # This program is distributed under the MIT License (see MIT.md)
 ###########################################################################################
 
+import logging
+
 import ase
 import numpy as np
 import torch
@@ -110,8 +112,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:  # [..., 1]
 
 @compile_mode("script")
 class PolynomialCutoff(torch.nn.Module):
-    """
-    Equation (8)
+    """Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max.
+    Equation (8) of the DimeNet paper, https://arxiv.org/abs/2003.03123.
""" p: torch.Tensor @@ -119,23 +121,26 @@ class PolynomialCutoff(torch.nn.Module): def __init__(self, r_max: float, p=6): super().__init__() - self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer("p", torch.tensor(p, dtype=torch.int)) self.register_buffer( "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) ) def forward(self, x: torch.Tensor) -> torch.Tensor: - # yapf: disable + return self.calculate_envelope(x, self.r_max, self.p.to(torch.int)) + + @staticmethod + def calculate_envelope( + x: torch.Tensor, r_max: torch.Tensor, p: int + ) -> torch.Tensor: + r_over_r_max = x / r_max envelope = ( - 1.0 - - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.r_max, self.p) - + self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1) - - (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2) + 1.0 + - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p) + + p * (p + 2.0) * torch.pow(r_over_r_max, p + 1) + - (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2) ) - # yapf: enable - - # noinspection PyUnresolvedReferences - return envelope * (x < self.r_max) + return envelope * (x < r_max) def __repr__(self): return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})" @@ -143,18 +148,19 @@ def __repr__(self): @compile_mode("script") class ZBLBasis(torch.nn.Module): - """ - Implementation of the Ziegler-Biersack-Littmark (ZBL) potential + """Implementation of the Ziegler-Biersack-Littmark (ZBL) potential + with a polynomial cutoff envelope. """ p: torch.Tensor - r_max: torch.Tensor - def __init__(self, r_max: float, p=6, trainable=False): + def __init__(self, p=6, trainable=False, **kwargs): super().__init__() - self.register_buffer( - "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) - ) + if "r_max" in kwargs: + logging.warning( + "r_max is deprecated. r_max is determined from the covalent radii." + ) + # Pre-calculate the p coefficients for the ZBL potential self.register_buffer( "c", @@ -162,7 +168,7 @@ def __init__(self, r_max: float, p=6, trainable=False): [0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype() ), ) - self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer("p", torch.tensor(p, dtype=torch.int)) self.register_buffer( "covalent_radii", torch.tensor( @@ -170,7 +176,6 @@ def __init__(self, r_max: float, p=6, trainable=False): dtype=torch.get_default_dtype(), ), ) - self.cutoff = PolynomialCutoff(r_max, p) if trainable: self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True)) self.a_prefactor = torch.nn.Parameter( @@ -208,12 +213,7 @@ def forward( ) v_edges = (14.3996 * Z_u * Z_v) / x * phi r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v] - envelope = ( - 1.0 - - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / r_max, self.p) - + self.p * (self.p + 2.0) * torch.pow(x / r_max, self.p + 1) - - (self.p * (self.p + 1.0) / 2) * torch.pow(x / r_max, self.p + 2) - ) * (x < r_max) + envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p) v_edges = 0.5 * v_edges * envelope V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0)) return V_ZBL.squeeze(-1) @@ -224,8 +224,8 @@ def __repr__(self): @compile_mode("script") class AgnesiTransform(torch.nn.Module): - """ - Agnesi transform see ACEpotentials.jl, JCP 2023, p. 160 + """Agnesi transform - see section on Radial transformations in + ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783). 
""" def __init__( @@ -265,21 +265,27 @@ def forward( ) Z_u = node_atomic_numbers[sender] Z_v = node_atomic_numbers[receiver] - r_0 = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) + r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) + r_over_r_0 = x / r_0 return ( - 1 + (self.a * ((x / r_0) ** self.q) / (1 + (x / r_0) ** (self.q - self.p))) - ) ** (-1) + 1 + + ( + self.a + * torch.pow(r_over_r_0, self.q) + / (1 + torch.pow(r_over_r_0, self.q - self.p)) + ) + ).reciprocal_() def __repr__(self): - return f"{self.__class__.__name__}(a={self.a}, q={self.q}, p={self.p})" + return ( + f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})" + ) @simplify_if_compile @compile_mode("script") class SoftTransform(torch.nn.Module): - """ - Soft Transform - """ + """Soft Transform.""" def __init__(self, a: float = 0.2, b: float = 3.0, trainable=False): super().__init__() @@ -312,9 +318,10 @@ def forward( Z_u = node_atomic_numbers[sender] Z_v = node_atomic_numbers[receiver] r_0 = (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) / 4 + r_over_r_0 = x / r_0 y = ( x - + (1 / 2) * torch.tanh(-(x / r_0) - self.a * ((x / r_0) ** self.b)) + + (1 / 2) * torch.tanh(-r_over_r_0 - self.a * torch.pow(r_over_r_0, self.b)) + 1 / 2 ) return y diff --git a/tests/modules/test_radial.py b/tests/modules/test_radial.py new file mode 100644 index 00000000..1d8a0c6d --- /dev/null +++ b/tests/modules/test_radial.py @@ -0,0 +1,83 @@ +import pytest +import torch +from mace.modules.radial import ZBLBasis, AgnesiTransform + +@pytest.fixture +def zbl_basis(): + return ZBLBasis(p=6, trainable=False) + +def test_zbl_basis_initialization(zbl_basis): + assert zbl_basis.p == torch.tensor(6.0) + assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817])) + + assert zbl_basis.a_exp == torch.tensor(0.300) + assert zbl_basis.a_prefactor == torch.tensor(0.4543) + assert not zbl_basis.a_exp.requires_grad + assert not zbl_basis.a_prefactor.requires_grad + +def test_trainable_zbl_basis_initialization(zbl_basis): + zbl_basis = ZBLBasis(p=6, trainable=True) + assert zbl_basis.p == torch.tensor(6.0) + assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817])) + + assert zbl_basis.a_exp == torch.tensor(0.300) + assert zbl_basis.a_prefactor == torch.tensor(0.4543) + assert zbl_basis.a_exp.requires_grad + assert zbl_basis.a_prefactor.requires_grad + +def test_forward(zbl_basis): + x = torch.tensor([1.0, 1.0, 2.0]).unsqueeze(-1) # [n_edges] + node_attrs = torch.tensor([[1, 0], [0, 1]]) # [n_nodes, n_node_features] - one_hot encoding of atomic numbers + edge_index = torch.tensor([[0, 1, 1], [1, 0, 1]]) # [2, n_edges] + atomic_numbers = torch.tensor([1, 6]) # [n_nodes] + output = zbl_basis(x, node_attrs, edge_index, atomic_numbers) + + assert output.shape == torch.Size([node_attrs.shape[0]]) + assert torch.is_tensor(output) + assert torch.allclose( + output, + torch.tensor([0.0031, 0.0031], dtype=torch.get_default_dtype()), + rtol=1e-2 + ) + +@pytest.fixture +def agnesi(): + return AgnesiTransform(trainable=False) + +def test_agnesi_transform_initialization(agnesi: AgnesiTransform): + assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4) + assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4) + assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4) + assert not agnesi.a.requires_grad + assert not agnesi.q.requires_grad + assert not agnesi.p.requires_grad + +def test_trainable_agnesi_transform_initialization(): + agnesi = 
AgnesiTransform(trainable=True) + + assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4) + assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4) + assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4) + assert agnesi.a.requires_grad + assert agnesi.q.requires_grad + assert agnesi.p.requires_grad + +def test_agnesi_transform_forward(): + agnesi = AgnesiTransform() + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.get_default_dtype()).unsqueeze(-1) + node_attrs = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.get_default_dtype()) + edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]]) + atomic_numbers = torch.tensor([1, 6, 8]) + output = agnesi(x, node_attrs, edge_index, atomic_numbers) + assert output.shape == x.shape + assert torch.is_tensor(output) + assert torch.allclose( + output, + torch.tensor( + [0.3646, 0.2175, 0.2089], dtype=torch.get_default_dtype() + ).unsqueeze(-1), + rtol=1e-2 + ) + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/test_foundations.py b/tests/test_foundations.py index 44879395..a50716db 100644 --- a/tests/test_foundations.py +++ b/tests/test_foundations.py @@ -146,7 +146,7 @@ def test_multi_reference(): heads=["MP2", "DFT"], ) model = modules.ScaleShiftMACE(**model_config) - calc_foundation = mace_mp(device="cpu", default_dtype="float64") + calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") model_loaded = load_foundations_elements( model, calc_foundation.models[0], @@ -166,7 +166,7 @@ def test_multi_reference(): ) batch = next(iter(data_loader)) forces_loaded = model_loaded(batch)["forces"] - calc_foundation = mace_mp(device="cpu", default_dtype="float64") + calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") atoms = molecule("H2COH") atoms.info["head"] = "MP2" atoms.calc = calc_foundation
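
For reference, the envelope factored out into PolynomialCutoff.calculate_envelope above is the smooth polynomial cutoff of Equation (8) in the DimeNet paper. A minimal sketch, assuming only torch and the static method introduced in this diff, that checks the boundary behaviour (1 at x = 0, exactly 0 at and beyond r_max):

import torch

from mace.modules.radial import PolynomialCutoff

r_max = torch.tensor(5.0)
x = torch.tensor([0.0, 2.5, 5.0, 6.0])

# calculate_envelope is the staticmethod added in this diff.
u = PolynomialCutoff.calculate_envelope(x, r_max, p=6)

assert u[0] == 1.0  # envelope starts at 1 at x = 0
assert torch.all(u[2:] == 0)  # and vanishes at and beyond r_max
print(u)  # approximately [1.0000, 0.8555, 0.0000, 0.0000]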
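
The same envelope is what ZBLBasis now applies per edge, which is why its r_max argument is gone: the cutoff is derived from the covalent radii of each atom pair. A sketch of that per-pair cutoff, assuming ase is installed (the module buffers ase.data.covalent_radii, as the context lines above show):

import torch
from ase.data import covalent_radii  # same table the real ZBLBasis buffers

from mace.modules.radial import PolynomialCutoff

Z_u, Z_v = 1, 6  # a hydrogen-carbon edge
# Pair-specific cutoff: the sum of the two covalent radii (Angstrom).
r_max = torch.tensor(covalent_radii[Z_u] + covalent_radii[Z_v])
x = torch.tensor([0.8])  # interatomic distance for this edge

# The ZBL repulsion for this edge is damped by this envelope.
envelope = PolynomialCutoff.calculate_envelope(x, r_max, p=6)
print(float(r_max), float(envelope))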
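
Since the default foundation model changes in this release, a short usage sketch (mirroring the updated tests above) showing how to obtain the new default and how to pin the pre-0.3.10 behaviour:

from ase.build import molecule

from mace.calculators import mace_mp

# New default from 0.3.10 onwards: the MPA-0 medium model
# (equivalent to passing model="medium-mpa-0").
calc_new = mace_mp(device="cpu", default_dtype="float64")

# Pin the previous default explicitly to reproduce pre-0.3.10 results.
calc_old = mace_mp(model="medium", device="cpu", default_dtype="float64")

atoms = molecule("H2O")
atoms.calc = calc_new
print(atoms.get_potential_energy())  # energy in eV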