Skip to content

Commit

Permalink
Merge branch 'ACEsuit:main' into update_workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
alinelena authored Dec 17, 2024
2 parents 72a8b6d + c8f2d61 commit cf3e202
Show file tree
Hide file tree
Showing 28 changed files with 1,779 additions and 134 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ dist/
*.xyz
/checkpoints
*.model

.benchmarks
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,6 @@ repos:
'--disable=cell-var-from-loop',
'--disable=duplicate-code',
'--disable=use-dict-literal',
'--max-module-lines=1500',
]
exclude: *exclude_files
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
- [Training](#training)
- [Evaluation](#evaluation)
- [Tutorials](#tutorials)
- [CUDA acceleration with cuEquivariance](#cuda-acceleration-with-cuequivariance)
- [Weights and Biases for experiment tracking](#weights-and-biases-for-experiment-tracking)
- [Pretrained Foundation Models](#pretrained-foundation-models)
- [MACE-MP: Materials Project Force Fields](#mace-mp-materials-project-force-fields)
Expand Down Expand Up @@ -171,6 +172,9 @@ We also have a more detailed Colab tutorials on:
- [Introduction to MACE active learning and fine-tuning](https://colab.research.google.com/drive/1oCSVfMhWrqHTeHbKgUSQN9hTKxLzoNyb)
- [MACE theory and code (advanced)](https://colab.research.google.com/drive/1AlfjQETV_jZ0JQnV5M3FGwAM2SGCl2aU)

## CUDA acceleration with cuEquivariance

MACE supports CUDA acceleration with the cuEquivariance library. To install the library and use the acceleration, see our documentation at https://mace-docs.readthedocs.io/en/latest/guide/cuda_acceleration.html.

## On-line data loading for large datasets

Expand Down
2 changes: 1 addition & 1 deletion mace/__version__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.3.8"
__version__ = "0.3.9"

__all__ = ["__version__"]
9 changes: 7 additions & 2 deletions mace/calculators/foundations_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,16 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
"small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
"medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
"large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model",
"small-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model",
"medium-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_medium.model",
"small-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model",
"medium-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
"large-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
}

checkpoint_url = (
urls.get(model, urls["medium"])
if model in (None, "small", "medium", "large")
if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2")
else model
)

Expand Down Expand Up @@ -101,7 +106,7 @@ def mace_mp(
MACECalculator: trained on the MPtrj dataset (unless model otherwise specified).
"""
try:
if model in (None, "small", "medium", "large") or str(model).startswith(
if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") or str(model).startswith(
"https:"
):
model_path = download_mace_mp_checkpoint(model)
Expand Down
35 changes: 28 additions & 7 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
import torch
from ase.calculators.calculator import Calculator, all_changes
from ase.stress import full_3x3_to_voigt_6_stress
from e3nn import o3

from mace import data
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
from mace.modules.utils import extract_invariant
from mace.tools import torch_geometric, torch_tools, utils
from mace.tools.compile import prepare
Expand Down Expand Up @@ -60,10 +62,13 @@ def __init__(
model_type="MACE",
compile_mode=None,
fullgraph=True,
enable_cueq=False,
**kwargs,
):
Calculator.__init__(self, **kwargs)

if enable_cueq:
assert model_type == "MACE", "CuEq only supports MACE models"
compile_mode = None
if "model_path" in kwargs:
deprecation_message = (
"'model_path' argument is deprecated, please use 'model_paths'"
Expand Down Expand Up @@ -130,6 +135,12 @@ def __init__(
torch.load(f=model_path, map_location=device)
for model_path in model_paths
]
if enable_cueq:
print("Converting models to CuEq for acceleration")
self.models = [
run_e3nn_to_cueq(model, device=device).to(device)
for model in self.models
]

elif models is not None:
if not isinstance(models, list):
Expand Down Expand Up @@ -390,24 +401,34 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
atoms = self.atoms
if self.model_type != "MACE":
raise NotImplementedError("Only implemented for MACE models")
num_interactions = int(self.models[0].num_interactions)
if num_layers == -1:
num_layers = int(self.models[0].num_interactions)
num_layers = num_interactions
batch = self._atoms_to_batch(atoms)
descriptors = [model(batch.to_dict())["node_feats"] for model in self.models]

irreps_out = o3.Irreps(str(self.models[0].products[0].linear.irreps_out))
l_max = irreps_out.lmax
num_invariant_features = irreps_out.dim // (l_max + 1) ** 2
per_layer_features = [irreps_out.dim for _ in range(num_interactions)]
per_layer_features[-1] = (
num_invariant_features # Equivariant features not created for the last layer
)

if invariants_only:
irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"]
l_max = irreps_out.lmax
num_features = irreps_out.dim // (l_max + 1) ** 2
descriptors = [
extract_invariant(
descriptor,
num_layers=num_layers,
num_features=num_features,
num_features=num_invariant_features,
l_max=l_max,
)
for descriptor in descriptors
]
descriptors = [descriptor.detach().cpu().numpy() for descriptor in descriptors]
to_keep = np.sum(per_layer_features[:num_layers])
descriptors = [
descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors
]

if self.num_models == 1:
return descriptors[0]
Expand Down
193 changes: 193 additions & 0 deletions mace/cli/convert_cueq_e3nn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import argparse
import logging
import os
from typing import Dict, List, Tuple

import torch

from mace.tools.scripts_utils import extract_config_mace_model


def get_transfer_keys() -> List[str]:
    """Return the state-dict keys that are copied verbatim between models.

    The list holds the embedding, atomic-energy, readout and scale/shift
    entries first, followed by the per-layer interaction and product keys
    for layers 0 and 1.
    """
    keys = [
        "node_embedding.linear.weight",
        "radial_embedding.bessel_fn.bessel_weights",
        "atomic_energies_fn.atomic_energies",
        "readouts.0.linear.weight",
        "scale_shift.scale",
        "scale_shift.shift",
    ]
    # Second readout exposes linear_1 and linear_2.
    keys.extend(f"readouts.1.linear_{i}.weight" for i in range(1, 3))

    # Per-layer entries; the two-layer layout is hard-coded here.
    for j in range(2):
        keys.append(f"interactions.{j}.linear_up.weight")
        keys.extend(
            f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)
        )
        keys.append(f"interactions.{j}.linear.weight")
        keys.append(f"interactions.{j}.skip_tp.weight")
        keys.append(f"products.{j}.linear.weight")
    return keys


def get_kmax_pairs(max_L: int, correlation: int) -> List[Tuple[int, int]]:
    """Return ``[product index, kmax]`` pairs for the given correlation order.

    Only ``correlation == 3`` is handled: product 0 gets ``kmax = max_L`` and
    product 1 gets ``kmax = 0``.

    Raises:
        NotImplementedError: for any correlation other than 3.
    """
    if correlation == 3:
        first_layer = [0, max_L]
        last_layer = [1, 0]
        return [first_layer, last_layer]
    if correlation == 2:
        raise NotImplementedError("Correlation 2 not supported yet")
    raise NotImplementedError(f"Correlation {correlation} not supported")


def transfer_symmetric_contractions(
    source_dict: Dict[str, torch.Tensor],
    target_dict: Dict[str, torch.Tensor],
    max_L: int,
    correlation: int,
):
    """Split the combined CuEq contraction weights into the E3nn entries.

    For every ``(product index, kmax)`` pair from ``get_kmax_pairs`` the single
    ``products.{i}.symmetric_contractions.weight`` tensor of ``source_dict`` is
    split along dim 1 and the pieces are written into ``target_dict`` (mutated
    in place) under the ``weights_max``, ``weights.0`` and ``weights.1`` keys
    of each contraction ``0..kmax``.

    Args:
        source_dict: state dict holding the combined weight tensors.
        target_dict: state dict whose existing contraction entries define the
            split sizes (their ``shape[1]``) and receive the pieces.
        max_L: forwarded to ``get_kmax_pairs``.
        correlation: forwarded to ``get_kmax_pairs``.
    """
    for i, kmax in get_kmax_pairs(max_L, correlation):
        # Combined weight tensor for this product block.
        combined = source_dict[f"products.{i}.symmetric_contractions.weight"]

        # Destination keys in the order the pieces are laid out along dim 1.
        destination_keys = [
            f"products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}"
            for k in range(kmax + 1)
            for suffix in ["_max", ".0", ".1"]
        ]

        # Each piece is as wide as the tensor currently stored in the target.
        widths = [target_dict[key].shape[1] for key in destination_keys]

        pieces = torch.split(combined, widths, dim=1)
        for key, piece in zip(destination_keys, pieces):
            target_dict[key] = piece


def transfer_weights(
    source_model: torch.nn.Module,
    target_model: torch.nn.Module,
    max_L: int,
    correlation: int,
):
    """Copy the weights of a CuEq model into an E3nn model, in place.

    Steps, in order: copy the keys listed by ``get_transfer_keys``, remap the
    symmetric-contraction weights, copy any other keys shared by both state
    dicts whose shapes agree, copy ``avg_num_neighbors`` for interactions 0
    and 1, and finally load the assembled state dict into ``target_model``.

    Args:
        source_model: model to read weights from.
        target_model: model that receives the weights (modified in place).
        max_L: forwarded to ``transfer_symmetric_contractions``.
        correlation: forwarded to ``transfer_symmetric_contractions``.
    """
    source_dict = source_model.state_dict()
    target_dict = target_model.state_dict()

    # Directly transferable weights; a missing key is reported, not fatal.
    transfer_keys = get_transfer_keys()
    for key in transfer_keys:
        if key not in source_dict:
            logging.warning(f"Key {key} not found in source model")
            continue
        target_dict[key] = source_dict[key]

    # Contraction weights need to be split before they fit the target layout.
    transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation)

    # Everything else both models share, minus what was handled above.
    already_copied = set(transfer_keys)
    shared_keys = set(source_dict.keys()) & set(target_dict.keys())
    remaining_keys = {
        key
        for key in shared_keys
        if key not in already_copied and "symmetric_contraction" not in key
    }

    for key in remaining_keys:
        if source_dict[key].shape != target_dict[key].shape:
            logging.warning(
                f"Shape mismatch for key {key}: "
                f"source {source_dict[key].shape} vs target {target_dict[key].shape}"
            )
            continue
        logging.debug(f"Transferring additional key: {key}")
        target_dict[key] = source_dict[key]

    # avg_num_neighbors is a plain attribute, so it is copied separately
    # for the two interaction layers.
    for layer in (0, 1):
        neighbors = source_model.interactions[layer].avg_num_neighbors
        target_model.interactions[layer].avg_num_neighbors = neighbors

    target_model.load_state_dict(target_dict)


def run(input_model, output_model="_e3nn.model", device="cpu", return_model=True):
    """Convert a CuEq MACE model into an equivalent model without CuEq.

    Args:
        input_model: path to a saved model (``str``) or an already loaded
            model object.
        output_model: file name used when saving. When ``input_model`` is a
            path, the file is written to ``"<input without extension>.<output_model>"``;
            otherwise ``output_model`` is used as given.
        device: ``map_location`` used when loading from a path.
        return_model: when true, return the converted model and write nothing.

    Returns:
        The converted model if ``return_model`` is true, otherwise ``None``
        after saving the model to disk.

    Note:
        The global torch default dtype is set to the dtype of the source
        model's first parameter.
    """
    from_path = isinstance(input_model, str)

    # NOTE(review): torch.load deserializes a pickled object; only load model
    # files from trusted sources.
    source_model = (
        torch.load(input_model, map_location=device) if from_path else input_model
    )

    # Build the new model in the same precision as the source.
    torch.set_default_dtype(next(source_model.parameters()).dtype)

    config = extract_config_mace_model(source_model)
    max_L = config["hidden_irreps"].lmax
    correlation = config["correlation"]

    # Dropping the CuEq settings makes the constructor build a plain model.
    config.pop("cueq_config", None)

    logging.info("Creating new model without CuEq settings")
    target_model = source_model.__class__(**config)

    transfer_weights(source_model, target_model, max_L, correlation)

    if not return_model:
        if from_path:
            stem = os.path.splitext(input_model)[0]
            output_model = f"{stem}.{output_model}"
        logging.warning(f"Saving E3nn model to {output_model}")
        torch.save(target_model, output_model)
        return None

    return target_model


def main():
    """Command-line entry point: parse arguments and call ``run``.

    The flag semantics are unchanged; only the help texts were corrected so
    that they describe what the options actually do.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input_model", help="Path to input CuEq model")
    # ``run`` appends this value to the input path (extension removed), so it
    # acts as a suffix rather than a standalone path.
    parser.add_argument(
        "--output_model",
        help=(
            "Suffix of the output E3nn model file; the model is saved as "
            "'<input path without extension>.<output_model>'"
        ),
        default="e3nn_model.pt",
    )
    parser.add_argument("--device", default="cpu", help="Device to use")
    # ``store_false``: the value defaults to True (return only, nothing is
    # written) and passing the flag switches it to False, which makes ``run``
    # save the model. The previous help text stated the opposite.
    parser.add_argument(
        "--return_model",
        action="store_false",
        help=(
            "Save the converted model to file; without this flag the model "
            "is only returned and nothing is written"
        ),
    )
    args = parser.parse_args()

    run(
        input_model=args.input_model,
        output_model=args.output_model,
        device=args.device,
        return_model=args.return_model,
    )


# Run the command-line converter only when this file is executed as a script,
# not when it is imported (e.g. for its ``run`` function).
if __name__ == "__main__":
    main()
Loading

0 comments on commit cf3e202

Please sign in to comment.