diff --git a/.gitignore b/.gitignore index 3817d9f30..296776e42 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,5 @@ dist/ *.xyz /checkpoints *.model + +.benchmarks diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d78624bbf..6f8c2daa6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,5 +55,6 @@ repos: '--disable=cell-var-from-loop', '--disable=duplicate-code', '--disable=use-dict-literal', + '--max-module-lines=1500', ] exclude: *exclude_files \ No newline at end of file diff --git a/README.md b/README.md index 8481760d1..93db4c87c 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ - [Training](#training) - [Evaluation](#evaluation) - [Tutorials](#tutorials) + - [CUDA acceleration with cuEquivariance](#cuda-acceleration-with-cuequivariance) - [Weights and Biases for experiment tracking](#weights-and-biases-for-experiment-tracking) - [Pretrained Foundation Models](#pretrained-foundation-models) - [MACE-MP: Materials Project Force Fields](#mace-mp-materials-project-force-fields) @@ -171,6 +172,9 @@ We also have a more detailed Colab tutorials on: - [Introduction to MACE active learning and fine-tuning](https://colab.research.google.com/drive/1oCSVfMhWrqHTeHbKgUSQN9hTKxLzoNyb) - [MACE theory and code (advanced)](https://colab.research.google.com/drive/1AlfjQETV_jZ0JQnV5M3FGwAM2SGCl2aU) +## CUDA acceleration with cuEquivariance + +MACE supports CUDA acceleration with the cuEquivariance library. To install the library and use the acceleration, see our documentation at https://mace-docs.readthedocs.io/en/latest/guide/cuda_acceleration.html. ## On-line data loading for large datasets diff --git a/mace/__version__.py b/mace/__version__.py index 2eb279ae4..17eec33de 100644 --- a/mace/__version__.py +++ b/mace/__version__.py @@ -1,3 +1,3 @@ -__version__ = "0.3.8" +__version__ = "0.3.9" __all__ = ["__version__"] diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index ed814f1ab..3ebddbedc 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -33,11 +33,16 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str: "small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model", "medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model", "large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model", + "small-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model", + "medium-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_medium.model", + "small-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model", + "medium-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model", + "large-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model", } checkpoint_url = ( urls.get(model, urls["medium"]) - if model in (None, "small", "medium", "large") + if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") else model ) @@ -101,7 +106,7 @@ def mace_mp( MACECalculator: trained on the MPtrj dataset (unless model otherwise specified). 
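+        Example (illustrative; "medium-0b2" is one of the checkpoint names added in this change):
+            >>> from mace.calculators import mace_mp
+            >>> calc = mace_mp(model="medium-0b2", device="cuda", default_dtype="float64")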
""" try: - if model in (None, "small", "medium", "large") or str(model).startswith( + if model in (None, "small", "medium", "large", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "large-0b2") or str(model).startswith( "https:" ): model_path = download_mace_mp_checkpoint(model) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 0b801bafb..9d3f07ca0 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -14,8 +14,10 @@ import torch from ase.calculators.calculator import Calculator, all_changes from ase.stress import full_3x3_to_voigt_6_stress +from e3nn import o3 from mace import data +from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq from mace.modules.utils import extract_invariant from mace.tools import torch_geometric, torch_tools, utils from mace.tools.compile import prepare @@ -60,10 +62,13 @@ def __init__( model_type="MACE", compile_mode=None, fullgraph=True, + enable_cueq=False, **kwargs, ): Calculator.__init__(self, **kwargs) - + if enable_cueq: + assert model_type == "MACE", "CuEq only supports MACE models" + compile_mode = None if "model_path" in kwargs: deprecation_message = ( "'model_path' argument is deprecated, please use 'model_paths'" @@ -130,6 +135,12 @@ def __init__( torch.load(f=model_path, map_location=device) for model_path in model_paths ] + if enable_cueq: + print("Converting models to CuEq for acceleration") + self.models = [ + run_e3nn_to_cueq(model, device=device).to(device) + for model in self.models + ] elif models is not None: if not isinstance(models, list): @@ -390,24 +401,34 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): atoms = self.atoms if self.model_type != "MACE": raise NotImplementedError("Only implemented for MACE models") + num_interactions = int(self.models[0].num_interactions) if num_layers == -1: - num_layers = int(self.models[0].num_interactions) + num_layers = num_interactions batch = self._atoms_to_batch(atoms) descriptors = [model(batch.to_dict())["node_feats"] for model in self.models] + + irreps_out = o3.Irreps(str(self.models[0].products[0].linear.irreps_out)) + l_max = irreps_out.lmax + num_invariant_features = irreps_out.dim // (l_max + 1) ** 2 + per_layer_features = [irreps_out.dim for _ in range(num_interactions)] + per_layer_features[-1] = ( + num_invariant_features # Equivariant features not created for the last layer + ) + if invariants_only: - irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"] - l_max = irreps_out.lmax - num_features = irreps_out.dim // (l_max + 1) ** 2 descriptors = [ extract_invariant( descriptor, num_layers=num_layers, - num_features=num_features, + num_features=num_invariant_features, l_max=l_max, ) for descriptor in descriptors ] - descriptors = [descriptor.detach().cpu().numpy() for descriptor in descriptors] + to_keep = np.sum(per_layer_features[:num_layers]) + descriptors = [ + descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors + ] if self.num_models == 1: return descriptors[0] diff --git a/mace/cli/convert_cueq_e3nn.py b/mace/cli/convert_cueq_e3nn.py new file mode 100644 index 000000000..dd7eb9f8b --- /dev/null +++ b/mace/cli/convert_cueq_e3nn.py @@ -0,0 +1,193 @@ +import argparse +import logging +import os +from typing import Dict, List, Tuple + +import torch + +from mace.tools.scripts_utils import extract_config_mace_model + + +def get_transfer_keys() -> List[str]: + """Get list of keys that need to be transferred""" + return [ + "node_embedding.linear.weight", + 
"radial_embedding.bessel_fn.bessel_weights", + "atomic_energies_fn.atomic_energies", + "readouts.0.linear.weight", + "scale_shift.scale", + "scale_shift.shift", + *[f"readouts.1.linear_{i}.weight" for i in range(1, 3)], + ] + [ + s + for j in range(2) + for s in [ + f"interactions.{j}.linear_up.weight", + *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)], + f"interactions.{j}.linear.weight", + f"interactions.{j}.skip_tp.weight", + f"products.{j}.linear.weight", + ] + ] + + +def get_kmax_pairs(max_L: int, correlation: int) -> List[Tuple[int, int]]: + """Determine kmax pairs based on max_L and correlation""" + if correlation == 2: + raise NotImplementedError("Correlation 2 not supported yet") + if correlation == 3: + return [[0, max_L], [1, 0]] + raise NotImplementedError(f"Correlation {correlation} not supported") + + +def transfer_symmetric_contractions( + source_dict: Dict[str, torch.Tensor], + target_dict: Dict[str, torch.Tensor], + max_L: int, + correlation: int, +): + """Transfer symmetric contraction weights from CuEq to E3nn format""" + kmax_pairs = get_kmax_pairs(max_L, correlation) + + for i, kmax in kmax_pairs: + # Get the combined weight tensor from source + wm = source_dict[f"products.{i}.symmetric_contractions.weight"] + + # Get split sizes based on target dimensions + splits = [] + for k in range(kmax + 1): + for suffix in ["_max", ".0", ".1"]: + key = f"products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}" + target_shape = target_dict[key].shape + splits.append(target_shape[1]) + + # Split the weights using the calculated sizes + weights_split = torch.split(wm, splits, dim=1) + + # Assign back to target dictionary + idx = 0 + for k in range(kmax + 1): + target_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights_max" + ] = weights_split[idx] + target_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights.0" + ] = weights_split[idx + 1] + target_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights.1" + ] = weights_split[idx + 2] + idx += 3 + + +def transfer_weights( + source_model: torch.nn.Module, + target_model: torch.nn.Module, + max_L: int, + correlation: int, +): + """Transfer weights from CuEq to E3nn format""" + # Get state dicts + source_dict = source_model.state_dict() + target_dict = target_model.state_dict() + + # Transfer main weights + transfer_keys = get_transfer_keys() + for key in transfer_keys: + if key in source_dict: # Check if key exists + target_dict[key] = source_dict[key] + else: + logging.warning(f"Key {key} not found in source model") + + # Transfer symmetric contractions + transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation) + + # Transfer remaining matching keys + transferred_keys = set(transfer_keys) + remaining_keys = ( + set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys + ) + remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} + + if remaining_keys: + for key in remaining_keys: + if source_dict[key].shape == target_dict[key].shape: + logging.debug(f"Transferring additional key: {key}") + target_dict[key] = source_dict[key] + else: + logging.warning( + f"Shape mismatch for key {key}: " + f"source {source_dict[key].shape} vs target {target_dict[key].shape}" + ) + + # Transfer avg_num_neighbors + for i in range(2): + target_model.interactions[i].avg_num_neighbors = source_model.interactions[ + i + ].avg_num_neighbors + + # Load state dict into target model + 
target_model.load_state_dict(target_dict) + + +def run(input_model, output_model="_e3nn.model", device="cpu", return_model=True): + + # Load CuEq model + if isinstance(input_model, str): + source_model = torch.load(input_model, map_location=device) + else: + source_model = input_model + default_dtype = next(source_model.parameters()).dtype + torch.set_default_dtype(default_dtype) + # Extract configuration + config = extract_config_mace_model(source_model) + + # Get max_L and correlation from config + max_L = config["hidden_irreps"].lmax + correlation = config["correlation"] + + # Remove CuEq config + config.pop("cueq_config", None) + + # Create new model without CuEq config + logging.info("Creating new model without CuEq settings") + target_model = source_model.__class__(**config) + + # Transfer weights with proper remapping + transfer_weights(source_model, target_model, max_L, correlation) + + if return_model: + return target_model + + # Save model + if isinstance(input_model, str): + base = os.path.splitext(input_model)[0] + output_model = f"{base}.{output_model}" + logging.warning(f"Saving E3nn model to {output_model}") + torch.save(target_model, output_model) + return None + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_model", help="Path to input CuEq model") + parser.add_argument( + "--output_model", help="Path to output E3nn model", default="e3nn_model.pt" + ) + parser.add_argument("--device", default="cpu", help="Device to use") + parser.add_argument( + "--return_model", + action="store_false", + help="Return model instead of saving to file", + ) + args = parser.parse_args() + + run( + input_model=args.input_model, + output_model=args.output_model, + device=args.device, + return_model=args.return_model, + ) + + +if __name__ == "__main__": + main() diff --git a/mace/cli/convert_e3nn_cueq.py b/mace/cli/convert_e3nn_cueq.py new file mode 100644 index 000000000..29966838c --- /dev/null +++ b/mace/cli/convert_e3nn_cueq.py @@ -0,0 +1,189 @@ +import argparse +import logging +import os +from typing import Dict, List, Tuple + +import torch + +from mace.modules.wrapper_ops import CuEquivarianceConfig +from mace.tools.scripts_utils import extract_config_mace_model + + +def get_transfer_keys() -> List[str]: + """Get list of keys that need to be transferred""" + return [ + "node_embedding.linear.weight", + "radial_embedding.bessel_fn.bessel_weights", + "atomic_energies_fn.atomic_energies", + "readouts.0.linear.weight", + "scale_shift.scale", + "scale_shift.shift", + *[f"readouts.1.linear_{i}.weight" for i in range(1, 3)], + ] + [ + s + for j in range(2) + for s in [ + f"interactions.{j}.linear_up.weight", + *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)], + f"interactions.{j}.linear.weight", + f"interactions.{j}.skip_tp.weight", + f"products.{j}.linear.weight", + ] + ] + + +def get_kmax_pairs(max_L: int, correlation: int) -> List[Tuple[int, int]]: + """Determine kmax pairs based on max_L and correlation""" + if correlation == 2: + raise NotImplementedError("Correlation 2 not supported yet") + if correlation == 3: + return [[0, max_L], [1, 0]] + raise NotImplementedError(f"Correlation {correlation} not supported") + + +def transfer_symmetric_contractions( + source_dict: Dict[str, torch.Tensor], + target_dict: Dict[str, torch.Tensor], + max_L: int, + correlation: int, +): + """Transfer symmetric contraction weights""" + kmax_pairs = get_kmax_pairs(max_L, correlation) + + for i, kmax in kmax_pairs: + wm = torch.concatenate( + [ + 
source_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights{j}" + ] + for k in range(kmax + 1) + for j in ["_max", ".0", ".1"] + ], + dim=1, + ) + target_dict[f"products.{i}.symmetric_contractions.weight"] = wm + + +def transfer_weights( + source_model: torch.nn.Module, + target_model: torch.nn.Module, + max_L: int, + correlation: int, +): + """Transfer weights with proper remapping""" + # Get source state dict + source_dict = source_model.state_dict() + target_dict = target_model.state_dict() + + # Transfer main weights + transfer_keys = get_transfer_keys() + for key in transfer_keys: + if key in source_dict: # Check if key exists + target_dict[key] = source_dict[key] + else: + logging.warning(f"Key {key} not found in source model") + + # Transfer symmetric contractions + transfer_symmetric_contractions(source_dict, target_dict, max_L, correlation) + + transferred_keys = set(transfer_keys) + remaining_keys = ( + set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys + ) + remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} + if remaining_keys: + for key in remaining_keys: + if source_dict[key].shape == target_dict[key].shape: + logging.debug(f"Transferring additional key: {key}") + target_dict[key] = source_dict[key] + else: + logging.warning( + f"Shape mismatch for key {key}: " + f"source {source_dict[key].shape} vs target {target_dict[key].shape}" + ) + # Transfer avg_num_neighbors + for i in range(2): + target_model.interactions[i].avg_num_neighbors = source_model.interactions[ + i + ].avg_num_neighbors + + # Load state dict into target model + target_model.load_state_dict(target_dict) + + +def run( + input_model, + output_model="_cueq.model", + device="cpu", + return_model=True, +): + # Setup logging + + # Load original model + # logging.warning(f"Loading model") + # check if input_model is a path or a model + if isinstance(input_model, str): + source_model = torch.load(input_model, map_location=device) + else: + source_model = input_model + default_dtype = next(source_model.parameters()).dtype + torch.set_default_dtype(default_dtype) + # Extract configuration + config = extract_config_mace_model(source_model) + + # Get max_L and correlation from config + max_L = config["hidden_irreps"].lmax + correlation = config["correlation"] + + # Add cuequivariance config + config["cueq_config"] = CuEquivarianceConfig( + enabled=True, + layout="ir_mul", + group="O3_e3nn", + optimize_all=True, + ) + + # Create new model with cuequivariance config + logging.info("Creating new model with cuequivariance settings") + target_model = source_model.__class__(**config).to(device) + + # Transfer weights with proper remapping + transfer_weights(source_model, target_model, max_L, correlation) + + if return_model: + return target_model + + if isinstance(input_model, str): + base = os.path.splitext(input_model)[0] + output_model = f"{base}.{output_model}" + logging.warning(f"Saving CuEq model to {output_model}") + torch.save(target_model, output_model) + return None + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_model", help="Path to input MACE model") + parser.add_argument( + "--output_model", + help="Path to output cuequivariance model", + default="cueq_model.pt", + ) + parser.add_argument("--device", default="cpu", help="Device to use") + parser.add_argument( + "--return_model", + action="store_false", + help="Return model instead of saving to file", + ) + args = parser.parse_args() + + run( + 
input_model=args.input_model, output_model=args.output_model, device=args.device, return_model=args.return_model, ) + + +if __name__ == "__main__": + main() diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 3813b055c..1c0898b73 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -24,6 +24,8 @@ import mace from mace import data, tools from mace.calculators.foundations_models import mace_mp, mace_off +from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn +from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq from mace.tools import torch_geometric from mace.tools.model_script_utils import configure_model from mace.tools.multihead_tools import ( @@ -158,6 +160,16 @@ def run(args: argparse.Namespace) -> None: args.E0s != "average" ), "average atomic energies cannot be used for multiheads finetuning" # check that the foundation model has a single head, if not, use the first head + if not args.force_mh_ft_lr: + logging.info( + "Multihead finetuning mode: setting learning rate to 0.001, EMA to True and EMA decay to 0.999. To use a different learning rate, set --force_mh_ft_lr=True." + ) + args.lr = 0.001 + args.ema = True + args.ema_decay = 0.999 if hasattr(model_foundation, "heads"): if len(model_foundation.heads) > 1: logging.warning( @@ -263,7 +275,7 @@ def run(args: argparse.Namespace) -> None: args.loss = "universal" if ( args.foundation_model in ["small", "medium", "large"] - or args.pt_train_file is None + or args.pt_train_file == "mp" ): logging.info( "Using foundation model for multiheads finetuning with Materials Project data" ) @@ -323,8 +335,21 @@ def run(args: argparse.Namespace) -> None: ) head_config_pt.collections = collections head_configs.append(head_config_pt) + + ratio_pt_ft = size_collections_train / len(head_config_pt.collections.train) + if ratio_pt_ft < 0.1: + logging.warning( + f"Ratio of the number of configurations in the training set to that in the pt_train_file is {ratio_pt_ft}, " + f"increasing the number of configurations in the pt_train_file by a factor of {int(0.1 / ratio_pt_ft)}" + ) + for head_config in head_configs: + if head_config.head_name == "pt_head": + continue + head_config.collections.train += head_config.collections.train * int( + 0.1 / ratio_pt_ft + ) logging.info( - f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}" + f"Total number of configurations in pretraining: train={len(head_config_pt.collections.train)}, valid={len(head_config_pt.collections.valid)}" ) # Atomic number table @@ -549,6 +574,11 @@ def run(args: argparse.Namespace) -> None: logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") logging.info(loss_fn) + # Cueq + if args.enable_cueq: + logging.info("Converting model to CuEq for accelerated training") + assert model.__class__.__name__ in ["MACE", "ScaleShiftMACE"] + model = run_e3nn_to_cueq(deepcopy(model), device=device) # Optimizer param_options = get_params_options(args, model) optimizer: torch.optim.Optimizer @@ -600,7 +630,6 @@ def run(args: argparse.Namespace) -> None: if args.wandb: setup_wandb(args) - if args.distributed: distributed_model = DDP(model, device_ids=[local_rank]) else: @@ -757,9 +786,14 @@ def run(args: argparse.Namespace) -> None: else: model_path = Path(args.checkpoints_dir) / (tag + ".model") logging.info(f"Saving model to {model_path}") + model_to_save = deepcopy(model) + if args.enable_cueq: +
print("RUNING CUEQ TO E3NN") + print("swa_eval", swa_eval) + model_to_save = run_cueq_to_e3nn(deepcopy(model), device=device) if args.save_cpu: - model = model.to("cpu") - torch.save(model, model_path) + model_to_save = model_to_save.to("cpu") + torch.save(model_to_save, model_path) extra_files = { "commit.txt": commit.encode("utf-8") if commit is not None else b"", "config.yaml": json.dumps( @@ -768,14 +802,14 @@ def run(args: argparse.Namespace) -> None: } if swa_eval: torch.save( - model, Path(args.model_dir) / (args.name + "_stagetwo.model") + model_to_save, Path(args.model_dir) / (args.name + "_stagetwo.model") ) try: path_complied = Path(args.model_dir) / ( args.name + "_stagetwo_compiled.model" ) logging.info(f"Compiling model, saving metadata {path_complied}") - model_compiled = jit.compile(deepcopy(model)) + model_compiled = jit.compile(deepcopy(model_to_save)) torch.jit.save( model_compiled, path_complied, @@ -784,13 +818,13 @@ def run(args: argparse.Namespace) -> None: except Exception as e: # pylint: disable=W0703 pass else: - torch.save(model, Path(args.model_dir) / (args.name + ".model")) + torch.save(model_to_save, Path(args.model_dir) / (args.name + ".model")) try: path_complied = Path(args.model_dir) / ( args.name + "_compiled.model" ) logging.info(f"Compiling model, saving metadata to {path_complied}") - model_compiled = jit.compile(deepcopy(model)) + model_compiled = jit.compile(deepcopy(model_to_save)) torch.jit.save( model_compiled, path_complied, diff --git a/mace/data/utils.py b/mace/data/utils.py index bb8e54484..59b868ed3 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -265,7 +265,6 @@ def load_from_xyz( atoms_without_iso_atoms = [] for idx, atoms in enumerate(atoms_list): - atoms.info[head_key] = head_name isolated_atom_config = ( len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" ) @@ -288,6 +287,9 @@ def load_from_xyz( if not keep_isolated_atoms: atoms_list = atoms_without_iso_atoms + for atoms in atoms_list: + atoms.info[head_key] = head_name + configs = config_from_atoms_list( atoms_list, config_type_weights=config_type_weights, diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py index 0db3b02e1..ea0e228b2 100644 --- a/mace/modules/blocks.py +++ b/mace/modules/blocks.py @@ -12,6 +12,13 @@ from e3nn import nn, o3 from e3nn.util.jit import compile_mode +from mace.modules.wrapper_ops import ( + CuEquivarianceConfig, + FullyConnectedTensorProduct, + Linear, + SymmetricContractionWrapper, + TensorProduct, +) from mace.tools.compile import simplify_if_compile from mace.tools.scatter import scatter_sum @@ -29,14 +36,20 @@ PolynomialCutoff, SoftTransform, ) -from .symmetric_contraction import SymmetricContraction @compile_mode("script") class LinearNodeEmbeddingBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps): + def __init__( + self, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): super().__init__() - self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out) + self.linear = Linear( + irreps_in=irreps_in, irreps_out=irreps_out, cueq_config=cueq_config + ) def forward( self, @@ -47,9 +60,16 @@ def forward( @compile_mode("script") class LinearReadoutBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps, irrep_out: o3.Irreps = o3.Irreps("0e")): + def __init__( + self, + irreps_in: o3.Irreps, + irrep_out: o3.Irreps = o3.Irreps("0e"), + cueq_config: Optional[CuEquivarianceConfig] = None, + ): super().__init__() - 
self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irrep_out) + self.linear = Linear( + irreps_in=irreps_in, irreps_out=irrep_out, cueq_config=cueq_config + ) def forward( self, @@ -69,13 +89,18 @@ def __init__( gate: Optional[Callable], irrep_out: o3.Irreps = o3.Irreps("0e"), num_heads: int = 1, + cueq_config: Optional[CuEquivarianceConfig] = None, ): super().__init__() self.hidden_irreps = MLP_irreps self.num_heads = num_heads - self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps) + self.linear_1 = Linear( + irreps_in=irreps_in, irreps_out=self.hidden_irreps, cueq_config=cueq_config + ) self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) - self.linear_2 = o3.Linear(irreps_in=self.hidden_irreps, irreps_out=irrep_out) + self.linear_2 = Linear( + irreps_in=self.hidden_irreps, irreps_out=irrep_out, cueq_config=cueq_config + ) def forward( self, x: torch.Tensor, heads: Optional[torch.Tensor] = None @@ -89,13 +114,20 @@ def forward( @compile_mode("script") class LinearDipoleReadoutBlock(torch.nn.Module): - def __init__(self, irreps_in: o3.Irreps, dipole_only: bool = False): + def __init__( + self, + irreps_in: o3.Irreps, + dipole_only: bool = False, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): super().__init__() if dipole_only: self.irreps_out = o3.Irreps("1x1o") else: self.irreps_out = o3.Irreps("1x0e + 1x1o") - self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_out) + self.linear = Linear( + irreps_in=irreps_in, irreps_out=self.irreps_out, cueq_config=cueq_config + ) def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] return self.linear(x) # [n_nodes, 1] @@ -109,6 +141,7 @@ def __init__( MLP_irreps: o3.Irreps, gate: Callable, dipole_only: bool = False, + cueq_config: Optional[CuEquivarianceConfig] = None, ): super().__init__() self.hidden_irreps = MLP_irreps @@ -131,9 +164,13 @@ def __init__( irreps_gated=irreps_gated, ) self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify() - self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.irreps_nonlin) - self.linear_2 = o3.Linear( - irreps_in=self.hidden_irreps, irreps_out=self.irreps_out + self.linear_1 = Linear( + irreps_in=irreps_in, irreps_out=self.irreps_nonlin, cueq_config=cueq_config + ) + self.linear_2 = Linear( + irreps_in=self.hidden_irreps, + irreps_out=self.irreps_out, + cueq_config=cueq_config, ) def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] @@ -218,22 +255,25 @@ def __init__( correlation: int, use_sc: bool = True, num_elements: Optional[int] = None, + cueq_config: Optional[CuEquivarianceConfig] = None, ) -> None: super().__init__() self.use_sc = use_sc - self.symmetric_contractions = SymmetricContraction( + self.symmetric_contractions = SymmetricContractionWrapper( irreps_in=node_feats_irreps, irreps_out=target_irreps, correlation=correlation, num_elements=num_elements, + cueq_config=cueq_config, ) # Update linear - self.linear = o3.Linear( + self.linear = Linear( target_irreps, target_irreps, internal_weights=True, shared_weights=True, + cueq_config=cueq_config, ) def forward( @@ -260,6 +300,7 @@ def __init__( hidden_irreps: o3.Irreps, avg_num_neighbors: float, radial_MLP: Optional[List[int]] = None, + cueq_config: Optional[CuEquivarianceConfig] = None, ) -> None: super().__init__() self.node_attrs_irreps = node_attrs_irreps @@ -272,6 +313,7 @@ def __init__( if radial_MLP is None: radial_MLP = [64, 64, 64] self.radial_MLP = radial_MLP + self.cueq_config = 
cueq_config self._setup() @@ -325,23 +367,29 @@ def __repr__(self): @compile_mode("script") class ResidualElementDependentInteractionBlock(InteractionBlock): def _setup(self) -> None: - self.linear_up = o3.Linear( + if not hasattr(self, "cueq_config"): + self.cueq_config = None + + # First linear + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config, ) self.conv_tp_weights = TensorProductWeightsBlock( num_elements=self.node_attrs_irreps.num_irreps, @@ -353,13 +401,20 @@ def _setup(self) -> None: irreps_mid = irreps_mid.simplify() self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) self.irreps_out = self.irreps_out.simplify() - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out + self.skip_tp = FullyConnectedTensorProduct( + self.node_feats_irreps, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, ) def forward( @@ -389,23 +444,27 @@ def forward( @compile_mode("script") class AgnosticNonlinearInteractionBlock(InteractionBlock): def _setup(self) -> None: - self.linear_up = o3.Linear( + if not hasattr(self, "cueq_config"): + self.cueq_config = None + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config, ) # Convolution weights @@ -419,13 +478,20 @@ def _setup(self) -> None: irreps_mid = irreps_mid.simplify() self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) self.irreps_out = self.irreps_out.simplify() - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.irreps_out, self.node_attrs_irreps, self.irreps_out + self.skip_tp = FullyConnectedTensorProduct( + self.irreps_out, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, ) def forward( @@ -455,24 +521,28 @@ def forward( @compile_mode("script") class AgnosticResidualNonlinearInteractionBlock(InteractionBlock): def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None # First linear - self.linear_up = o3.Linear( + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, 
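+            # Linear is the cueq-aware wrapper from wrapper_ops: it builds a
+            # cuet.Linear only when cueq_config is enabled, and a plain o3.Linear
+            # otherwise, so models built without a config behave exactly as before.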
shared_weights=True, + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config, ) # Convolution weights @@ -486,13 +556,20 @@ def _setup(self) -> None: irreps_mid = irreps_mid.simplify() self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) self.irreps_out = self.irreps_out.simplify() - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out + self.skip_tp = FullyConnectedTensorProduct( + self.node_feats_irreps, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, ) def forward( @@ -523,12 +600,15 @@ def forward( @compile_mode("script") class RealAgnosticInteractionBlock(InteractionBlock): def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None # First linear - self.linear_up = o3.Linear( + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( @@ -536,13 +616,14 @@ def _setup(self) -> None: self.edge_attrs_irreps, self.target_irreps, ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config, ) # Convolution weights @@ -553,17 +634,23 @@ def _setup(self) -> None: ) # Linear - irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.irreps_out, self.node_attrs_irreps, self.irreps_out + self.skip_tp = FullyConnectedTensorProduct( + self.irreps_out, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, ) - self.reshape = reshape_irreps(self.irreps_out) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) def forward( self, @@ -595,12 +682,15 @@ def forward( @compile_mode("script") class RealAgnosticResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None # First linear - self.linear_up = o3.Linear( + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( @@ -608,13 +698,14 @@ def _setup(self) -> None: self.edge_attrs_irreps, self.target_irreps, ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, 
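+            # the tensor-product weights are supplied per edge by conv_tp_weights
+            # (the radial MLP) below, hence shared_weights=False and
+            # internal_weights=False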
shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config, ) # Convolution weights @@ -625,17 +716,23 @@ def _setup(self) -> None: ) # Linear - irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + self.skip_tp = FullyConnectedTensorProduct( + self.node_feats_irreps, + self.node_attrs_irreps, + self.hidden_irreps, + cueq_config=self.cueq_config, ) - self.reshape = reshape_irreps(self.irreps_out) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) def forward( self, @@ -667,12 +764,15 @@ def forward( @compile_mode("script") class RealAgnosticDensityInteractionBlock(InteractionBlock): def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None # First linear - self.linear_up = o3.Linear( + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( @@ -680,13 +780,14 @@ def _setup(self) -> None: self.edge_attrs_irreps, self.target_irreps, ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config, ) # Convolution weights @@ -697,17 +798,22 @@ def _setup(self) -> None: ) # Linear - irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.irreps_out, self.node_attrs_irreps, self.irreps_out + self.skip_tp = FullyConnectedTensorProduct( + self.irreps_out, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, ) - self.reshape = reshape_irreps(self.irreps_out) # Density normalization self.density_fn = nn.FullyConnectedNet( @@ -718,7 +824,7 @@ def _setup(self) -> None: torch.nn.functional.silu, ) # Reshape - self.reshape = reshape_irreps(self.irreps_out) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) def forward( self, @@ -754,12 +860,16 @@ def forward( @compile_mode("script") class RealAgnosticDensityResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None + # First linear - self.linear_up = o3.Linear( + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( @@ -767,13 +877,14 @@ def _setup(self) -> None: self.edge_attrs_irreps, self.target_irreps, ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + 
cueq_config=self.cueq_config, ) # Convolution weights @@ -784,17 +895,22 @@ def _setup(self) -> None: ) # Linear - irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps - self.linear = o3.Linear( - irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, ) # Selector TensorProduct - self.skip_tp = o3.FullyConnectedTensorProduct( - self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps + self.skip_tp = FullyConnectedTensorProduct( + self.node_feats_irreps, + self.node_attrs_irreps, + self.hidden_irreps, + cueq_config=self.cueq_config, ) - self.reshape = reshape_irreps(self.irreps_out) # Density normalization self.density_fn = nn.FullyConnectedNet( @@ -806,7 +922,7 @@ def _setup(self) -> None: ) # Reshape - self.reshape = reshape_irreps(self.irreps_out) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) def forward( self, @@ -842,13 +958,16 @@ def forward( @compile_mode("script") class RealAgnosticAttResidualInteractionBlock(InteractionBlock): def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None self.node_feats_down_irreps = o3.Irreps("64x0e") # First linear - self.linear_up = o3.Linear( + self.linear_up = Linear( self.node_feats_irreps, self.node_feats_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) # TensorProduct irreps_mid, instructions = tp_out_irreps_with_instructions( @@ -856,21 +975,23 @@ def _setup(self) -> None: self.edge_attrs_irreps, self.target_irreps, ) - self.conv_tp = o3.TensorProduct( + self.conv_tp = TensorProduct( self.node_feats_irreps, self.edge_attrs_irreps, irreps_mid, instructions=instructions, shared_weights=False, internal_weights=False, + cueq_config=self.cueq_config, ) # Convolution weights - self.linear_down = o3.Linear( + self.linear_down = Linear( self.node_feats_irreps, self.node_feats_down_irreps, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) input_dim = ( self.edge_feats_irreps.num_irreps @@ -882,19 +1003,21 @@ def _setup(self) -> None: ) # Linear - irreps_mid = irreps_mid.simplify() self.irreps_out = self.target_irreps - self.linear = o3.Linear( + self.linear = Linear( irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True, + cueq_config=self.cueq_config, ) - self.reshape = reshape_irreps(self.irreps_out) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) # Skip connection. 
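+        # unlike the other residual blocks, which skip via an element-dependent
+        # FullyConnectedTensorProduct, this attention variant skips through a
+        # plain Linear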
- self.skip_linear = o3.Linear(self.node_feats_irreps, self.hidden_irreps) + self.skip_linear = Linear( + self.node_feats_irreps, self.hidden_irreps, cueq_config=self.cueq_config + ) def forward( self, diff --git a/mace/modules/irreps_tools.py b/mace/modules/irreps_tools.py index b09601938..2e79c0abd 100644 --- a/mace/modules/irreps_tools.py +++ b/mace/modules/irreps_tools.py @@ -4,12 +4,14 @@ # This program is distributed under the MIT License (see MIT.md) ########################################################################################### -from typing import List, Tuple +from typing import List, Optional, Tuple import torch from e3nn import o3 from e3nn.util.jit import compile_mode +from mace.modules.wrapper_ops import CuEquivarianceConfig + # Based on mir-group/nequip def tp_out_irreps_with_instructions( @@ -64,9 +66,12 @@ def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps: @compile_mode("script") class reshape_irreps(torch.nn.Module): - def __init__(self, irreps: o3.Irreps) -> None: + def __init__( + self, irreps: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None + ) -> None: super().__init__() self.irreps = o3.Irreps(irreps) + self.cueq_config = cueq_config self.dims = [] self.muls = [] for mul, ir in self.irreps: @@ -81,8 +86,19 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: for mul, d in zip(self.muls, self.dims): field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] ix += mul * d - field = field.reshape(batch, mul, d) + if hasattr(self, "cueq_config") and self.cueq_config is not None: + if self.cueq_config.layout_str == "mul_ir": + field = field.reshape(batch, mul, d) + else: + field = field.reshape(batch, d, mul) + else: + field = field.reshape(batch, mul, d) out.append(field) + + if hasattr(self, "cueq_config") and self.cueq_config is not None: + if self.cueq_config.layout_str == "mul_ir": + return torch.cat(out, dim=-1) + return torch.cat(out, dim=-2) return torch.cat(out, dim=-1) diff --git a/mace/modules/models.py b/mace/modules/models.py index c0d8ab430..0e03317e9 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -62,6 +62,7 @@ def __init__( radial_MLP: Optional[List[int]] = None, radial_type: Optional[str] = "bessel", heads: Optional[List[str]] = None, + cueq_config: Optional[Dict[str, Any]] = None, ): super().__init__() self.register_buffer( @@ -82,7 +83,9 @@ def __init__( node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + irreps_in=node_attr_irreps, + irreps_out=node_feats_irreps, + cueq_config=cueq_config, ) self.radial_embedding = RadialEmbeddingBlock( r_max=r_max, @@ -116,6 +119,7 @@ def __init__( hidden_irreps=hidden_irreps, avg_num_neighbors=avg_num_neighbors, radial_MLP=radial_MLP, + cueq_config=cueq_config, ) self.interactions = torch.nn.ModuleList([inter]) @@ -131,12 +135,15 @@ def __init__( correlation=correlation[0], num_elements=num_elements, use_sc=use_sc_first, + cueq_config=cueq_config, ) self.products = torch.nn.ModuleList([prod]) self.readouts = torch.nn.ModuleList() self.readouts.append( - LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e")) + LinearReadoutBlock( + hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config + ) ) for i in range(num_interactions - 1): @@ -155,6 +162,7 @@ def __init__( hidden_irreps=hidden_irreps_out, 
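+                # cueq_config is threaded through every interaction, product and
+                # readout block so the e3nn<->CuEq converters can rebuild the model
+                # from a single extracted config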
avg_num_neighbors=avg_num_neighbors, radial_MLP=radial_MLP, + cueq_config=cueq_config, ) self.interactions.append(inter) prod = EquivariantProductBasisBlock( @@ -163,6 +171,7 @@ def __init__( correlation=correlation[i + 1], num_elements=num_elements, use_sc=True, + cueq_config=cueq_config, ) self.products.append(prod) if i == num_interactions - 2: @@ -173,11 +182,14 @@ def __init__( gate, o3.Irreps(f"{len(heads)}x0e"), len(heads), + cueq_config, ) ) else: self.readouts.append( - LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e")) + LinearReadoutBlock( + hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config + ) ) def forward( @@ -471,6 +483,7 @@ def __init__( gate: Optional[Callable], avg_num_neighbors: float, atomic_numbers: List[int], + cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument ): super().__init__() self.r_max = r_max @@ -675,6 +688,7 @@ def __init__( ], # Just here to make it compatible with energy models, MUST be None radial_type: Optional[str] = "bessel", radial_MLP: Optional[List[int]] = None, + cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument ): super().__init__() self.register_buffer( @@ -876,6 +890,7 @@ def __init__( gate: Optional[Callable], atomic_energies: Optional[np.ndarray], radial_MLP: Optional[List[int]] = None, + cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument ): super().__init__() self.register_buffer( diff --git a/mace/modules/utils.py b/mace/modules/utils.py index d0a1e5f67..5731118c7 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -237,7 +237,8 @@ def _check_non_zero(std): def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int): out = [] - for i in range(num_layers - 1): + out.append(x[:, :num_features]) + for i in range(1, num_layers): out.append( x[ :, @@ -247,7 +248,6 @@ def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max * num_features, ] ) - out.append(x[:, -num_features:]) return torch.cat(out, dim=-1) diff --git a/mace/modules/wrapper_ops.py b/mace/modules/wrapper_ops.py new file mode 100644 index 000000000..580b4a0a6 --- /dev/null +++ b/mace/modules/wrapper_ops.py @@ -0,0 +1,269 @@ +""" +Wrapper class for o3.Linear that optionally uses cuet.Linear +""" + +import dataclasses +import itertools +import types +from typing import Iterator, List, Optional + +import numpy as np +import torch +from e3nn import o3 + +from mace.modules.symmetric_contraction import SymmetricContraction + +try: + import cuequivariance as cue + import cuequivariance_torch as cuet + + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + +if CUET_AVAILABLE: + + class O3_e3nn(cue.O3): + def __mul__( # pylint: disable=no-self-argument + rep1: "O3_e3nn", rep2: "O3_e3nn" + ) -> Iterator["O3_e3nn"]: + return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)] + + @classmethod + def clebsch_gordan( + cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn" + ) -> np.ndarray: + rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3) + + if rep1.p * rep2.p == rep3.p: + return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt( + rep3.dim + ) + return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) + + def __lt__( # pylint: disable=no-self-argument + rep1: "O3_e3nn", rep2: "O3_e3nn" + ) -> bool: + rep2 = rep1._from(rep2) + return (rep1.l, rep1.p) < (rep2.l, rep2.p) + + @classmethod + def iterator(cls) -> Iterator["O3_e3nn"]: + for l in itertools.count(0): + yield 
O3_e3nn(l=l, p=1 * (-1) ** l) + yield O3_e3nn(l=l, p=-1 * (-1) ** l) + +else: + print( + "cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled." + ) + + +@dataclasses.dataclass +class CuEquivarianceConfig: + """Configuration for cuequivariance acceleration""" + + enabled: bool = False + layout: str = "mul_ir" # One of: mul_ir, ir_mul + layout_str: str = "mul_ir" + group: str = "O3" + optimize_all: bool = False # Set to True to enable all optimizations + optimize_linear: bool = False + optimize_channelwise: bool = False + optimize_symmetric: bool = False + optimize_fctp: bool = False + + def __post_init__(self): + if self.enabled and CUET_AVAILABLE: + self.layout_str = self.layout + self.layout = getattr(cue, self.layout) + self.group = ( + O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group) + ) + + +class Linear: + """Returns either a cuet.Linear or o3.Linear based on config""" + + def __new__( + cls, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + shared_weights: bool = True, + internal_weights: bool = True, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_linear) + ): + instance = cuet.Linear( + cue.Irreps(cueq_config.group, irreps_in), + cue.Irreps(cueq_config.group, irreps_out), + layout=cueq_config.layout, + shared_weights=shared_weights, + optimize_fallback=True, + ) + instance.original_forward = instance.forward + + def cuet_forward(self, x: torch.Tensor) -> torch.Tensor: + return self.original_forward(x, use_fallback=True) + + instance.forward = types.MethodType(cuet_forward, instance) + return instance + + return o3.Linear( + irreps_in, + irreps_out, + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + + +class TensorProduct: + """Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct""" + + def __new__( + cls, + irreps_in1: o3.Irreps, + irreps_in2: o3.Irreps, + irreps_out: o3.Irreps, + instructions: Optional[List] = None, + shared_weights: bool = False, + internal_weights: bool = False, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_channelwise) + ): + instance = cuet.ChannelWiseTensorProduct( + cue.Irreps(cueq_config.group, irreps_in1), + cue.Irreps(cueq_config.group, irreps_in2), + cue.Irreps(cueq_config.group, irreps_out), + layout=cueq_config.layout, + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + instance.original_forward = instance.forward + + def cuet_forward( + self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor + ) -> torch.Tensor: + return self.original_forward(x, y, z, use_fallback=None) + + instance.forward = types.MethodType(cuet_forward, instance) + return instance + + return o3.TensorProduct( + irreps_in1, + irreps_in2, + irreps_out, + instructions=instructions, + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + + +class FullyConnectedTensorProduct: + """Wrapper around o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct""" + + def __new__( + cls, + irreps_in1: o3.Irreps, + irreps_in2: o3.Irreps, + irreps_out: o3.Irreps, + shared_weights: bool = True, + internal_weights: bool = True, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + if ( + CUET_AVAILABLE + and cueq_config is not None + and 
cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_fctp) + ): + instance = cuet.FullyConnectedTensorProduct( + cue.Irreps(cueq_config.group, irreps_in1), + cue.Irreps(cueq_config.group, irreps_in2), + cue.Irreps(cueq_config.group, irreps_out), + layout=cueq_config.layout, + shared_weights=shared_weights, + internal_weights=internal_weights, + optimize_fallback=True, + ) + instance.original_forward = instance.forward + + def cuet_forward( + self, x: torch.Tensor, attrs: torch.Tensor + ) -> torch.Tensor: + return self.original_forward(x, attrs, use_fallback=True) + + instance.forward = types.MethodType(cuet_forward, instance) + return instance + + return o3.FullyConnectedTensorProduct( + irreps_in1, + irreps_in2, + irreps_out, + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + + +class SymmetricContractionWrapper: + """Wrapper around SymmetricContraction/cuet.SymmetricContraction""" + + def __new__( + cls, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + correlation: int, + num_elements: Optional[int] = None, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_symmetric) + ): + instance = cuet.SymmetricContraction( + cue.Irreps(cueq_config.group, irreps_in), + cue.Irreps(cueq_config.group, irreps_out), + layout_in=cue.ir_mul, + layout_out=cueq_config.layout, + contraction_degree=correlation, + num_elements=num_elements, + original_mace=True, + dtype=torch.get_default_dtype(), + math_dtype=torch.get_default_dtype(), + ) + instance.original_forward = instance.forward + instance.layout = cueq_config.layout + + def cuet_forward( + self, x: torch.Tensor, attrs: torch.Tensor + ) -> torch.Tensor: + if self.layout == cue.mul_ir: + x = torch.transpose(x, 1, 2) + index_attrs = torch.nonzero(attrs)[:, 1].int() + return self.original_forward( + x.flatten(1), + index_attrs, + use_fallback=None, + ) + + instance.forward = types.MethodType(cuet_forward, instance) + return instance + + return SymmetricContraction( + irreps_in=irreps_in, + irreps_out=irreps_out, + correlation=correlation, + num_elements=num_elements, + ) diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index cb4f8ac53..e4e90a104 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -21,7 +21,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "--config", type=str, is_config_file=True, - help="config file to agregate options", + help="config file to aggregate options", ) except ImportError: parser = argparse.ArgumentParser( @@ -379,6 +379,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=int, default=1000, ) + parser.add_argument( + "--force_mh_ft_lr", + help="Force the multiheaded fine-tuning to use arg_parser lr", + type=str2bool, + default=False, + ) parser.add_argument( "--subselect_pt", help="Method to subselect the configurations of the pretraining set", @@ -660,6 +666,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=check_float_or_none, default=10.0, ) + # option for cuequivariance acceleration + parser.add_argument( + "--enable_cueq", + help="Enable cuequivariance acceleration", + type=str2bool, + default=False, + ) # options for using Weights and Biases for experiment tracking # to install see https://wandb.ai parser.add_argument( @@ -714,9 +727,24 @@ def build_default_arg_parser() -> argparse.ArgumentParser: def build_preprocess_arg_parser() -> 
argparse.ArgumentParser: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) + try: + import configargparse + + parser = configargparse.ArgumentParser( + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add( + "--config", + type=str, + is_config_file=True, + help="config file to aggregate options", + ) + except ImportError: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( "--train_file", help="Training set h5 file", diff --git a/mace/tools/arg_parser_tools.py b/mace/tools/arg_parser_tools.py index da64806a3..be714b26e 100644 --- a/mace/tools/arg_parser_tools.py +++ b/mace/tools/arg_parser_tools.py @@ -92,6 +92,15 @@ def check_args(args): # Loss and optimization # Check Stage Two loss start + if args.start_swa is not None: + args.swa = True + log_messages.append( + ( + "Stage Two is activated as start_stage_two was defined", + logging.INFO, + ) + ) + if args.swa: if args.start_swa is None: args.start_swa = max(1, args.max_num_epochs // 4 * 3) diff --git a/mace/tools/model_script_utils.py b/mace/tools/model_script_utils.py index 3f49eb418..d937446cb 100644 --- a/mace/tools/model_script_utils.py +++ b/mace/tools/model_script_utils.py @@ -146,15 +146,18 @@ def _build_model( args, model_config, model_config_foundation, heads ): # pylint: disable=too-many-return-statements if args.model == "MACE": + if args.interaction_first not in [ + "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + ]: + args.interaction_first = "RealAgnosticInteractionBlock" return modules.ScaleShiftMACE( **model_config, pair_repulsion=args.pair_repulsion, distance_transform=args.distance_transform, correlation=args.correlation, gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticInteractionBlock" - ], + interaction_cls_first=modules.interaction_classes[args.interaction_first], MLP_irreps=o3.Irreps(args.MLP_irreps), atomic_inter_scale=args.std, atomic_inter_shift=[0.0] * len(heads), diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index be96558de..9371e600f 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -175,6 +175,19 @@ def radial_to_transform(radial): scale = model.scale_shift.scale shift = model.scale_shift.shift + heads = model.heads if hasattr(model, "heads") else ["default"] + model_mlp_irreps = ( + o3.Irreps(str(model.readouts[-1].hidden_irreps)) + if model.num_interactions.item() > 1 + else 1 + ) + mlp_irreps = o3.Irreps(f"{model_mlp_irreps.count((0, 1)) // len(heads)}x0e") + try: + correlation = ( + len(model.products[0].symmetric_contractions.contractions[0].weights) + 1 + ) + except AttributeError: + correlation = model.products[0].symmetric_contractions.contraction_degree config = { "r_max": model.r_max.item(), "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights), @@ -185,11 +198,7 @@ def radial_to_transform(radial): "num_interactions": model.num_interactions.item(), "num_elements": len(model.atomic_numbers), "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)), - "MLP_irreps": ( - o3.Irreps(str(model.readouts[-1].hidden_irreps)) - if model.num_interactions.item() > 1 - else 1 - ), + "MLP_irreps": (mlp_irreps if model.num_interactions.item() > 1 else 1), "gate": ( model.readouts[-1] # pylint: disable=protected-access .non_linearity._modules["acts"][0] @@ 
diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py
index be96558de..9371e600f 100644
--- a/mace/tools/scripts_utils.py
+++ b/mace/tools/scripts_utils.py
@@ -175,6 +175,19 @@ def radial_to_transform(radial):
     scale = model.scale_shift.scale
     shift = model.scale_shift.shift
+    heads = model.heads if hasattr(model, "heads") else ["default"]
+    model_mlp_irreps = (
+        o3.Irreps(str(model.readouts[-1].hidden_irreps))
+        if model.num_interactions.item() > 1
+        else 1
+    )
+    mlp_irreps = o3.Irreps(f"{model_mlp_irreps.count((0, 1)) // len(heads)}x0e")
+    try:
+        correlation = (
+            len(model.products[0].symmetric_contractions.contractions[0].weights) + 1
+        )
+    except AttributeError:
+        correlation = model.products[0].symmetric_contractions.contraction_degree
     config = {
         "r_max": model.r_max.item(),
         "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights),
@@ -185,11 +198,7 @@ def radial_to_transform(radial):
         "num_interactions": model.num_interactions.item(),
         "num_elements": len(model.atomic_numbers),
         "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)),
-        "MLP_irreps": (
-            o3.Irreps(str(model.readouts[-1].hidden_irreps))
-            if model.num_interactions.item() > 1
-            else 1
-        ),
+        "MLP_irreps": (mlp_irreps if model.num_interactions.item() > 1 else 1),
         "gate": (
             model.readouts[-1]  # pylint: disable=protected-access
             .non_linearity._modules["acts"][0]
@@ -200,10 +209,7 @@ def radial_to_transform(radial):
         "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(),
         "avg_num_neighbors": model.interactions[0].avg_num_neighbors,
         "atomic_numbers": model.atomic_numbers,
-        "correlation": len(
-            model.products[0].symmetric_contractions.contractions[0].weights
-        )
-        + 1,
+        "correlation": correlation,
         "radial_type": radial_to_name(
             model.radial_embedding.bessel_fn.__class__.__name__
         ),
@@ -212,6 +218,7 @@ def radial_to_transform(radial):
         "distance_transform": radial_to_transform(model.radial_embedding),
         "atomic_inter_scale": scale.cpu().numpy(),
         "atomic_inter_shift": shift.cpu().numpy(),
+        "heads": heads,
     }
     return config
@@ -265,8 +272,8 @@ def remove_pt_head(
     )
     model_config["atomic_inter_scale"] = model.scale_shift.scale[head_idx].item()
     model_config["atomic_inter_shift"] = model.scale_shift.shift[head_idx].item()
-    mlp_count_irreps = model_config["MLP_irreps"].count((0, 1)) // len(model.heads)
-    model_config["MLP_irreps"] = o3.Irreps(f"{mlp_count_irreps}x0e")
+    mlp_count_irreps = model_config["MLP_irreps"].count((0, 1))
+    # model_config["MLP_irreps"] = o3.Irreps(f"{mlp_count_irreps}x0e")
     new_model = model.__class__(**model_config)
     state_dict = model.state_dict()
diff --git a/mace/tools/torch_tools.py b/mace/tools/torch_tools.py
index e42a74f8e..31e837df1 100644
--- a/mace/tools/torch_tools.py
+++ b/mace/tools/torch_tools.py
@@ -6,7 +6,7 @@
 import logging
 from contextlib import contextmanager
-from typing import Dict
+from typing import Dict, Union
 
 import numpy as np
 import torch
@@ -129,13 +129,18 @@ def init_wandb(project: str, entity: str, name: str, config: dict, directory: str):
 @contextmanager
-def default_dtype(dtype: torch.dtype):
+def default_dtype(dtype: Union[torch.dtype, str]):
     """Context manager for configuring the default_dtype used by torch
 
     Args:
-        dtype (torch.dtype): the default dtype to use within this context manager
+        dtype (torch.dtype|str): the default dtype to use within this context manager
     """
     init = torch.get_default_dtype()
-    torch.set_default_dtype(dtype)
+    if isinstance(dtype, str):
+        set_default_dtype(dtype)
+    else:
+        torch.set_default_dtype(dtype)
+
+    yield
+    torch.set_default_dtype(init)
diff --git a/pyproject.toml b/pyproject.toml
index 489bc6e5e..c7644f784 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,3 +39,6 @@ ignore-paths = [
     "^mace/tools/torch_geometric/.*$",
     "^mace/tools/scatter.py$",
 ]
+
+[tool.pylint.FORMAT]
+max-module-lines = 1500
diff --git a/scripts/run_checks.sh b/scripts/run_checks.sh
old mode 100755
new mode 100644
diff --git a/setup.cfg b/setup.cfg
index 76467fdab..be4804c40 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -29,6 +29,7 @@ install_requires =
     GitPython
     pyYAML
     tqdm
+    cuequivariance-torch
     # for plotting:
     matplotlib
     pandas
@@ -44,6 +45,8 @@ console_scripts =
     mace_finetuning = mace.cli.fine_tuning_select:main
     mace_convert_device = mace.cli.convert_device:main
     mace_select_head = mace.cli.select_head:main
+    mace_e3nn_cueq = mace.cli.convert_e3nn_cueq:main
+    mace_cueq_to_e3nn = mace.cli.convert_cueq_e3nn:main
 
 [options.extras_require]
 wandb = wandb
@@ -57,3 +60,5 @@ dev =
     pytest-benchmark
     pylint
 schedulefree = schedulefree
+cueq-cuda-11 = cuequivariance-ops-torch-cu11
+cueq-cuda-12 = cuequivariance-ops-torch-cu12
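The widened default_dtype context manager is what lets callers pass plain strings (the benchmark below does exactly that). A short usage sketch, assuming mace.tools.torch_tools as patched here:

    import torch

    from mace.tools import torch_tools

    # Strings are routed through mace's set_default_dtype; torch.dtype objects
    # pass straight through. The previous default is restored on exit either way.
    with torch_tools.default_dtype("float32"):
        assert torch.get_default_dtype() == torch.float32

    with torch_tools.default_dtype(torch.float64):
        assert torch.get_default_dtype() == torch.float64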
diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py
new file mode 100644
index 000000000..78b04ccdf
--- /dev/null
+++ b/tests/test_benchmark.py
@@ -0,0 +1,122 @@
+import json
+import os
+from typing import Optional
+
+import pandas as pd
+import pytest
+import torch
+from ase import build
+
+from mace import data
+from mace.calculators.foundations_models import mace_mp
+from mace.tools import AtomicNumberTable, torch_geometric, torch_tools
+
+
+def is_mace_full_bench():
+    return os.environ.get("MACE_FULL_BENCH", "0") == "1"
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
+@pytest.mark.benchmark(warmup=True, warmup_iterations=4, min_rounds=8)
+@pytest.mark.parametrize("size", (3, 5, 7, 9))
+@pytest.mark.parametrize("dtype", ["float32", "float64"])
+@pytest.mark.parametrize("compile_mode", [None, "default"])
+def test_inference(
+    benchmark, size: int, dtype: str, compile_mode: Optional[str], device: str = "cuda"
+):
+    if not is_mace_full_bench() and compile_mode is not None:
+        pytest.skip("Skipping long running benchmark, set MACE_FULL_BENCH=1 to execute")
+
+    with torch_tools.default_dtype(dtype):
+        model = load_mace_mp_medium(dtype, compile_mode, device)
+        batch = create_batch(size, model, device)
+        log_bench_info(benchmark, dtype, compile_mode, batch)
+
+        def func():
+            torch.cuda.synchronize()
+            model(batch, training=compile_mode is not None, compute_force=True)
+
+        torch.cuda.empty_cache()
+        benchmark(func)
+
+
+def load_mace_mp_medium(dtype, compile_mode, device):
+    calc = mace_mp(
+        model="medium",
+        default_dtype=dtype,
+        device=device,
+        compile_mode=compile_mode,
+        fullgraph=False,
+    )
+    model = calc.models[0].to(device)
+    return model
+
+
+def create_batch(size: int, model: torch.nn.Module, device: str) -> dict:
+    cutoff = model.r_max.item()
+    z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers])
+    atoms = build.bulk("C", "diamond", a=3.567, cubic=True)
+    atoms = atoms.repeat((size, size, size))
+    config = data.config_from_atoms(atoms)
+    dataset = [data.AtomicData.from_config(config, z_table=z_table, cutoff=cutoff)]
+    data_loader = torch_geometric.dataloader.DataLoader(
+        dataset=dataset,
+        batch_size=1,
+        shuffle=False,
+        drop_last=False,
+    )
+    batch = next(iter(data_loader))
+    batch.to(device)
+    return batch.to_dict()
+
+
+def log_bench_info(benchmark, dtype, compile_mode, batch):
+    benchmark.extra_info["num_atoms"] = int(batch["positions"].shape[0])
+    benchmark.extra_info["num_edges"] = int(batch["edge_index"].shape[1])
+    benchmark.extra_info["dtype"] = dtype
+    benchmark.extra_info["is_compiled"] = compile_mode is not None
+    benchmark.extra_info["device_name"] = torch.cuda.get_device_name()
+
+
+def read_bench_results(files: list[str]) -> pd.DataFrame:
+    def read(file):
+        with open(file, "r") as f:
+            data = json.load(f)
+
+        records = []
+        for bench in data["benchmarks"]:
+            record = {**bench["extra_info"], **bench["stats"]}
+            records.append(record)
+
+        df = pd.DataFrame(records)
+        df["ns/day (1 fs/step)"] = 0.086400 / df["median"]
+        df["Steps per day"] = df["ops"] * 86400
+        columns = [
+            "num_atoms",
+            "num_edges",
+            "dtype",
+            "is_compiled",
+            "device_name",
+            "median",
+            "Steps per day",
+            "ns/day (1 fs/step)",
+        ]
+        return df[columns]
+
+    return pd.concat([read(f) for f in files])
+
+
+if __name__ == "__main__":
+    # Print to stdout a csv of the benchmark metrics
+    import subprocess
+
+    result = subprocess.run(
+        ["pytest-benchmark", "list"], capture_output=True, text=True
+    )
+
+    if result.returncode != 0:
+        raise RuntimeError(f"Command failed with return code {result.returncode}")
+
+    files = result.stdout.strip().split("\n")
+    df = read_bench_results(files)
+    print(df.to_csv(index=False))
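As wired up above, a bare `pytest tests/test_benchmark.py` runs only the uncompiled configurations; exporting `MACE_FULL_BENCH=1` additionally enables the torch.compile variants. Saving runs with pytest-benchmark (for example via `--benchmark-autosave`) and then executing `python tests/test_benchmark.py` collects the saved JSON files through `pytest-benchmark list` and prints the CSV summary, provided the pytest-benchmark CLI is on PATH.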
diff --git a/tests/test_calculator.py b/tests/test_calculator.py
index 6590935cf..6c9e25635 100644
--- a/tests/test_calculator.py
+++ b/tests/test_calculator.py
@@ -16,6 +16,13 @@
 from mace.calculators.mace import MACECalculator
 from mace.modules.models import ScaleShiftMACE
 
+try:
+    import cuequivariance as cue  # pylint: disable=unused-import
+
+    CUET_AVAILABLE = True
+except ImportError:
+    CUET_AVAILABLE = False
+
 pytest_mace_dir = Path(__file__).parent.parent
 run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py"
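The fixtures below exercise the new calculator flag end to end; in user code the same switch is a single keyword argument (the checkpoint path is a placeholder):

    from mace.calculators.mace import MACECalculator

    # enable_cueq=True converts the loaded e3nn model to cuequivariance ops at
    # load time; per this PR the conversion is only supported for MACE models.
    calc = MACECalculator(
        model_paths="MACE.model",  # placeholder path
        device="cuda",
        enable_cueq=True,
    )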
@@ -177,6 +184,71 @@ def trained_model_equivariant_fixture(tmp_path_factory, fitting_configs):
     return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu")
 
 
+@pytest.fixture(scope="module", name="trained_equivariant_model_cueq")
+def trained_model_equivariant_fixture_cueq(tmp_path_factory, fitting_configs):
+    _mace_params = {
+        "name": "MACE",
+        "valid_fraction": 0.05,
+        "energy_weight": 1.0,
+        "forces_weight": 10.0,
+        "stress_weight": 1.0,
+        "model": "MACE",
+        "hidden_irreps": "16x0e+16x1o",
+        "r_max": 3.5,
+        "batch_size": 5,
+        "max_num_epochs": 10,
+        "swa": None,
+        "start_swa": 5,
+        "ema": None,
+        "ema_decay": 0.99,
+        "amsgrad": None,
+        "restart_latest": None,
+        "device": "cpu",
+        "seed": 5,
+        "loss": "stress",
+        "energy_key": "REF_energy",
+        "forces_key": "REF_forces",
+        "stress_key": "REF_stress",
+        "eval_interval": 2,
+    }
+
+    tmp_path = tmp_path_factory.mktemp("run_")
+
+    ase.io.write(tmp_path / "fit.xyz", fitting_configs)
+
+    mace_params = _mace_params.copy()
+    mace_params["checkpoints_dir"] = str(tmp_path)
+    mace_params["model_dir"] = str(tmp_path)
+    mace_params["train_file"] = tmp_path / "fit.xyz"
+
+    # make sure run_train.py is using the mace that is currently being tested
+    run_env = os.environ.copy()
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+    run_env["PYTHONPATH"] = ":".join(sys.path)
+    print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"])
+
+    cmd = (
+        sys.executable
+        + " "
+        + str(run_train)
+        + " "
+        + " ".join(
+            [
+                (f"--{k}={v}" if v is not None else f"--{k}")
+                for k, v in mace_params.items()
+            ]
+        )
+    )
+
+    p = subprocess.run(cmd.split(), env=run_env, check=True)
+
+    assert p.returncode == 0
+
+    return MACECalculator(
+        model_paths=tmp_path / "MACE.model", device="cpu", enable_cueq=True
+    )
+
+
 @pytest.fixture(scope="module", name="trained_dipole_model")
 def trained_dipole_fixture(tmp_path_factory, fitting_configs):
     _mace_params = {
@@ -468,18 +540,98 @@ def test_calculator_energy_dipole(fitting_configs, trained_energy_dipole_model):
 def test_calculator_descriptor(fitting_configs, trained_equivariant_model):
     at = fitting_configs[2].copy()
-    at.calc = trained_equivariant_model
-
-    desc_invariant = at.calc.get_descriptors(at, invariants_only=True)
-    desc_single_layer = at.calc.get_descriptors(at, invariants_only=True, num_layers=1)
-    desc = at.calc.get_descriptors(at, invariants_only=False)
+    at_rotated = fitting_configs[2].copy()
+    at_rotated.rotate(90, "x")
+    calc = trained_equivariant_model
+
+    desc_invariant = calc.get_descriptors(at, invariants_only=True)
+    desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True)
+    desc_invariant_single_layer = calc.get_descriptors(
+        at, invariants_only=True, num_layers=1
+    )
+    desc_invariant_single_layer_rotated = calc.get_descriptors(
+        at_rotated, invariants_only=True, num_layers=1
+    )
+    desc = calc.get_descriptors(at, invariants_only=False)
+    desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1)
+    desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False)
+    desc_rotated_single_layer = calc.get_descriptors(
+        at_rotated, invariants_only=False, num_layers=1
+    )
 
     assert desc_invariant.shape[0] == 3
     assert desc_invariant.shape[1] == 32
+    assert desc_invariant_single_layer.shape[0] == 3
+    assert desc_invariant_single_layer.shape[1] == 16
+    assert desc.shape[0] == 3
+    assert desc.shape[1] == 80
     assert desc_single_layer.shape[0] == 3
-    assert desc_single_layer.shape[1] == 16
+    assert desc_single_layer.shape[1] == 16 * 4
+    assert desc_rotated_single_layer.shape[0] == 3
+    assert desc_rotated_single_layer.shape[1] == 16 * 4
+
+    np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6)
+    np.testing.assert_allclose(
+        desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6
+    )
+    np.testing.assert_allclose(
+        desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6
+    )
+    np.testing.assert_allclose(
+        desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6
+    )
+    assert not np.allclose(
+        desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6
+    )
+    assert not np.allclose(desc, desc_rotated, atol=1e-6)
+
+
+def test_calculator_descriptor_cueq(fitting_configs, trained_equivariant_model_cueq):
+    at = fitting_configs[2].copy()
+    at_rotated = fitting_configs[2].copy()
+    at_rotated.rotate(90, "x")
+    calc = trained_equivariant_model_cueq
+
+    desc_invariant = calc.get_descriptors(at, invariants_only=True)
+    desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True)
+    desc_invariant_single_layer = calc.get_descriptors(
+        at, invariants_only=True, num_layers=1
+    )
+    desc_invariant_single_layer_rotated = calc.get_descriptors(
+        at_rotated, invariants_only=True, num_layers=1
+    )
+    desc = calc.get_descriptors(at, invariants_only=False)
+    desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1)
+    desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False)
+    desc_rotated_single_layer = calc.get_descriptors(
+        at_rotated, invariants_only=False, num_layers=1
+    )
+
+    assert desc_invariant.shape[0] == 3
+    assert desc_invariant.shape[1] == 32
+    assert desc_invariant_single_layer.shape[0] == 3
+    assert desc_invariant_single_layer.shape[1] == 16
+    assert desc.shape[0] == 3
+    assert desc.shape[1] == 80
+    assert desc_single_layer.shape[0] == 3
+    assert desc_single_layer.shape[1] == 16 * 4
+    assert desc_rotated_single_layer.shape[0] == 3
+    assert desc_rotated_single_layer.shape[1] == 16 * 4
+
+    np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6)
+    np.testing.assert_allclose(
+        desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6
+    )
+    np.testing.assert_allclose(
+        desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6
+    )
+    np.testing.assert_allclose(
+        desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6
+    )
+    assert not np.allclose(
+        desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6
+    )
+    assert not np.allclose(desc, desc_rotated, atol=1e-6)
 
 
 def test_mace_mp(capsys: pytest.CaptureFixture):
@@ -506,3 +658,19 @@ def test_mace_off():
     E = atoms.get_potential_energy()
 
     assert np.allclose(E, -2081.116128586803, atol=1e-9)
+
+
+@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed")
+def test_mace_off_cueq(model="medium", device="cpu"):
+    mace_off_model = mace_off(model=model, device=device, enable_cueq=True)
+    assert isinstance(mace_off_model, MACECalculator)
+    assert mace_off_model.model_type == "MACE"
+    assert len(mace_off_model.models) == 1
+    assert isinstance(mace_off_model.models[0], ScaleShiftMACE)
+
+    atoms = build.molecule("H2O")
+    atoms.calc = mace_off_model
+
+    E = atoms.get_potential_energy()
+
+    assert np.allclose(E, -2081.116128586803, atol=1e-9)
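The shape assertions in the two descriptor tests encode the per-layer bookkeeping introduced in get_descriptors: with hidden_irreps 16x0e+16x1o, each of the two layers carries 16*1 + 16*3 = 64 features, but only the 16 invariants are produced for the last layer, so the full descriptor has 64 + 16 = 80 columns and the invariant-only one 16 + 16 = 32. The arithmetic can be checked directly with e3nn:

    from e3nn import o3

    irreps = o3.Irreps("16x0e + 16x1o")
    l_max = irreps.lmax                             # 1
    num_invariant = irreps.dim // (l_max + 1) ** 2  # 64 // 4 = 16
    per_layer = [irreps.dim, num_invariant]         # equivariants dropped on the last layer
    assert sum(per_layer) == 80
    assert 2 * num_invariant == 32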
diff --git a/tests/test_cueq.py b/tests/test_cueq.py
new file mode 100644
index 000000000..8d713c78f
--- /dev/null
+++ b/tests/test_cueq.py
@@ -0,0 +1,177 @@
+from copy import deepcopy
+from typing import Any, Dict
+
+import numpy as np
+import pytest
+import torch
+import torch.nn.functional as F
+from ase import build
+from e3nn import o3
+
+from mace import data, modules, tools
+from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn
+from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
+from mace.tools import torch_geometric
+
+try:
+    import cuequivariance as cue  # pylint: disable=unused-import
+
+    CUET_AVAILABLE = True
+except ImportError:
+    CUET_AVAILABLE = False
+
+CUDA_AVAILABLE = torch.cuda.is_available()
+
+
+@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed")
+class TestCueq:
+    @pytest.fixture
+    def model_config(self, interaction_cls_first, hidden_irreps) -> Dict[str, Any]:
+        table = tools.AtomicNumberTable([6])
+        return {
+            "r_max": 5.0,
+            "num_bessel": 8,
+            "num_polynomial_cutoff": 6,
+            "max_ell": 3,
+            "interaction_cls": modules.interaction_classes[
+                "RealAgnosticResidualInteractionBlock"
+            ],
+            "interaction_cls_first": interaction_cls_first,
+            "num_interactions": 2,
+            "num_elements": 1,
+            "hidden_irreps": hidden_irreps,
+            "MLP_irreps": o3.Irreps("16x0e"),
+            "gate": F.silu,
+            "atomic_energies": torch.tensor([1.0]),
+            "avg_num_neighbors": 8,
+            "atomic_numbers": table.zs,
+            "correlation": 3,
+            "radial_type": "bessel",
+            "atomic_inter_scale": 1.0,
+            "atomic_inter_shift": 0.0,
+        }
+
+    @pytest.fixture
+    def batch(self, device: str, default_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
+        torch.set_default_dtype(default_dtype)
+
+        table = tools.AtomicNumberTable([6])
+
+        atoms = build.bulk("C", "diamond", a=3.567, cubic=True)
+        displacement = np.random.uniform(-0.1, 0.1, size=atoms.positions.shape)
+        atoms.positions += displacement
+        atoms_list = [atoms.repeat((2, 2, 2))]
+
+        configs = [data.config_from_atoms(atoms) for atoms in atoms_list]
+        data_loader = torch_geometric.dataloader.DataLoader(
+            dataset=[
+                data.AtomicData.from_config(config, z_table=table, cutoff=5.0)
+                for config in configs
+            ],
+            batch_size=1,
+            shuffle=False,
+            drop_last=False,
+        )
+        batch = next(iter(data_loader))
+        return batch.to(device).to_dict()
+    @pytest.mark.parametrize(
+        "device",
+        ["cpu"] + (["cuda"] if CUDA_AVAILABLE else []),
+    )
+    @pytest.mark.parametrize(
+        "interaction_cls_first",
+        [
+            modules.interaction_classes["RealAgnosticResidualInteractionBlock"],
+            modules.interaction_classes["RealAgnosticInteractionBlock"],
+            modules.interaction_classes["RealAgnosticDensityInteractionBlock"],
+        ],
+    )
+    @pytest.mark.parametrize(
+        "hidden_irreps",
+        [
+            o3.Irreps("32x0e + 32x1o"),
+            o3.Irreps("32x0e + 32x1o + 32x2e"),
+            o3.Irreps("32x0e"),
+        ],
+    )
+    @pytest.mark.parametrize("default_dtype", [torch.float32, torch.float64])
+    def test_bidirectional_conversion(
+        self,
+        model_config: Dict[str, Any],
+        batch: Dict[str, torch.Tensor],
+        device: str,
+        default_dtype: torch.dtype,
+    ):
+        if device == "cuda" and not CUDA_AVAILABLE:
+            pytest.skip("CUDA not available")
+        torch.manual_seed(42)
+
+        # Create original E3nn model
+        model_e3nn = modules.ScaleShiftMACE(**model_config).to(device)
+
+        # Convert E3nn to CuEq
+        model_cueq = run_e3nn_to_cueq(model_e3nn).to(device)
+
+        # Convert CuEq back to E3nn
+        model_e3nn_back = run_cueq_to_e3nn(model_cueq).to(device)
+
+        # Test forward pass equivalence
+        out_e3nn = model_e3nn(deepcopy(batch), training=True, compute_stress=True)
+        out_cueq = model_cueq(deepcopy(batch), training=True, compute_stress=True)
+        out_e3nn_back = model_e3nn_back(
+            deepcopy(batch), training=True, compute_stress=True
+        )
+
+        # Check outputs match for both conversions
+        torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"])
+        torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"])
+        torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"])
+        torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"])
+        torch.testing.assert_close(out_e3nn["stress"], out_cueq["stress"])
+        torch.testing.assert_close(out_cueq["stress"], out_e3nn_back["stress"])
+
+        # Test backward pass equivalence
+        loss_e3nn = out_e3nn["energy"].sum()
+        loss_cueq = out_cueq["energy"].sum()
+        loss_e3nn_back = out_e3nn_back["energy"].sum()
+
+        loss_e3nn.backward()
+        loss_cueq.backward()
+        loss_e3nn_back.backward()
+
+        # Compare gradients for all conversions
+        tol = 1e-4 if default_dtype == torch.float32 else 1e-8
+
+        def print_gradient_diff(name1, p1, name2, p2, conv_type):
+            if p1.grad is not None and p1.grad.shape == p2.grad.shape:
+                if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]:
+                    error = torch.abs(p1.grad - p2.grad)
+                    print(
+                        f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}"
+                    )
+                    torch.testing.assert_close(p1.grad, p2.grad, atol=tol, rtol=1e-10)
+
+        # E3nn to CuEq gradients
+        for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip(
+            model_e3nn.named_parameters(), model_cueq.named_parameters()
+        ):
+            print_gradient_diff(name_e3nn, p_e3nn, name_cueq, p_cueq, "E3nn->CuEq")
+
+        # CuEq to E3nn gradients
+        for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip(
+            model_cueq.named_parameters(), model_e3nn_back.named_parameters()
+        ):
+            print_gradient_diff(
+                name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn"
+            )
+
+        # Full circle comparison (E3nn -> E3nn)
+        for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip(
+            model_e3nn.named_parameters(), model_e3nn_back.named_parameters()
+        ):
+            print_gradient_diff(
+                name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle"
+            )
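Outside the test-suite, the same round trip is available from Python via the converters' run helpers, or from the shell via the console scripts registered in setup.cfg above (mace_e3nn_cueq, mace_cueq_to_e3nn). A sketch with a placeholder checkpoint:

    import torch

    from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn
    from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq

    # Placeholder path; run() accepts an in-memory model, as in test_cueq.py above.
    model = torch.load("MACE.model", map_location="cpu")
    model_cueq = run_e3nn_to_cueq(model)
    model_back = run_cueq_to_e3nn(model_cueq)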
diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py
index e0258bd4f..1f3068ba4 100644
--- a/tests/test_preprocess.py
+++ b/tests/test_preprocess.py
@@ -6,6 +6,7 @@
 import ase.io
 import numpy as np
 import pytest
+import yaml
 from ase.atoms import Atoms
 
 pytest_mace_dir = Path(__file__).parent.parent
@@ -164,3 +165,42 @@ def test_preprocess_data(tmp_path, sample_configs):
     np.testing.assert_allclose(original_forces, h5_forces, rtol=1e-5, atol=1e-8)
 
     print("All checks passed successfully!")
+
+
+def test_preprocess_config(tmp_path, sample_configs):
+    ase.io.write(tmp_path / "sample.xyz", sample_configs)
+
+    preprocess_params = {
+        "train_file": str(tmp_path / "sample.xyz"),
+        "r_max": 5.0,
+        "config_type_weights": "{'Default':1.0}",
+        "num_process": 2,
+        "valid_fraction": 0.1,
+        "h5_prefix": str(tmp_path / "preprocessed_"),
+        "compute_statistics": None,
+        "seed": 42,
+        "energy_key": "REF_energy",
+        "forces_key": "REF_forces",
+        "stress_key": "REF_stress",
+    }
+    filename = tmp_path / "config.yaml"
+    with open(filename, "w", encoding="utf-8") as file:
+        yaml.dump(preprocess_params, file)
+
+    run_env = os.environ.copy()
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+    run_env["PYTHONPATH"] = ":".join(sys.path)
+    print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"])
+
+    cmd = (
+        sys.executable
+        + " "
+        + str(preprocess_data)
+        + " "
+        + "--config"
+        + " "
+        + str(filename)
+    )
+
+    p = subprocess.run(cmd.split(), env=run_env, check=True)
+    assert p.returncode == 0
diff --git a/tests/test_run_train.py b/tests/test_run_train.py
index ca196c476..2b56c10bd 100644
--- a/tests/test_run_train.py
+++ b/tests/test_run_train.py
@@ -7,10 +7,18 @@
 import ase.io
 import numpy as np
 import pytest
+import torch
 from ase.atoms import Atoms
 
 from mace.calculators.mace import MACECalculator
 
+try:
+    import cuequivariance as cue  # pylint: disable=unused-import
+
+    CUET_AVAILABLE = True
+except ImportError:
+    CUET_AVAILABLE = False
+
 run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py"
@@ -429,14 +437,10 @@ def test_run_train_foundation(tmp_path, fitting_configs):
     mace_params["num_radial_basis"] = 10
     mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock"
     mace_params["multiheads_finetuning"] = False
-    print("mace_params", mace_params)
-    # mace_params["num_samples_pt"] = 50
-    # mace_params["subselect_pt"] = "random"
-    # make sure run_train.py is using the mace that is currently being tested
+
     run_env = os.environ.copy()
     sys.path.insert(0, str(Path(__file__).parent.parent))
     run_env["PYTHONPATH"] = ":".join(sys.path)
-    print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"])
 
     cmd = (
         sys.executable
@@ -541,6 +545,7 @@ def test_run_train_foundation_multihead(tmp_path, fitting_configs):
     mace_params["valid_batch_size"] = 1
     mace_params["num_samples_pt"] = 50
     mace_params["subselect_pt"] = "random"
+    mace_params["force_mh_ft_lr"] = True
     # make sure run_train.py is using the mace that is currently being tested
     run_env = os.environ.copy()
     sys.path.insert(0, str(Path(__file__).parent.parent))
@@ -658,6 +663,7 @@ def test_run_train_foundation_multihead_json(tmp_path, fitting_configs):
     mace_params["valid_batch_size"] = 1
     mace_params["num_samples_pt"] = 50
     mace_params["subselect_pt"] = "random"
+    mace_params["force_mh_ft_lr"] = True
     # make sure run_train.py is using the mace that is currently being tested
     run_env = os.environ.copy()
     sys.path.insert(0, str(Path(__file__).parent.parent))
@@ -819,6 +825,7 @@ def test_run_train_multihead_replay_custum_finetuning(
         "pt_train_file": os.path.join(tmp_path, "pretrain.xyz"),
         "num_samples_pt": 3,
         "subselect_pt": "random",
+        "force_mh_ft_lr": True,
     }
 
     cmd = [sys.executable, str(run_train)]
@@ -847,3 +854,200 @@ def test_run_train_multihead_replay_custum_finetuning(
     assert len(Es) == len(fitting_configs)
     assert all(isinstance(E, float) for E in Es)
     assert len(set(Es)) > 1  # Ens
+
+
+@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed")
+def test_run_train_cueq(tmp_path, fitting_configs):
+    torch.set_default_dtype(torch.float64)
+    ase.io.write(tmp_path / "fit.xyz", fitting_configs)
+
+    mace_params = _mace_params.copy()
+    mace_params["checkpoints_dir"] = str(tmp_path)
+    mace_params["model_dir"] = str(tmp_path)
+    mace_params["train_file"] = tmp_path / "fit.xyz"
+    mace_params["enable_cueq"] = True
+    mace_params["device"] = "cpu"
+    mace_params["default_dtype"] = "float64"
+
+    # make sure run_train.py is using the mace that is currently being tested
+    run_env = os.environ.copy()
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+    run_env["PYTHONPATH"] = ":".join(sys.path)
+    print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"])
+
+    cmd = (
+        sys.executable
+        + " "
+        + str(run_train)
+        + " "
+        + " ".join(
+            [
+                (f"--{k}={v}" if v is not None else f"--{k}")
+                for k, v in mace_params.items()
+            ]
+        )
+    )
+
+    p = subprocess.run(cmd.split(), env=run_env, check=True)
+    assert p.returncode == 0
+
+    calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu")
+    Es = []
+    for at in fitting_configs[2:]:
+        at.calc = calc
+        Es.append(at.get_potential_energy())
+
+    calc = MACECalculator(
+        model_paths=tmp_path / "MACE.model", device="cpu", enable_cueq=True
+    )
+    Es_cueq = []
+    for at in fitting_configs[2:]:
+        at.calc = calc
+        Es_cueq.append(at.get_potential_energy())
+
+    # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7
+    ref_Es = [
+        -0.039181344585828524,
+        -0.0915223395136733,
+        -0.14953484236456582,
+        -0.06662480820063998,
+        -0.09983737353050133,
+        0.12477442296789745,
+        -0.06486086271762856,
+        -0.1460607988519944,
+        0.12886334908465508,
+        -0.14000990081920373,
+        -0.05319886578958313,
+        0.07780520158391,
+        -0.08895480281886901,
+        -0.15474719614734422,
+        0.007756765146527644,
+        -0.044879267197498685,
+        -0.036065736712447574,
+        -0.24413743841886623,
+        -0.0838104612106429,
+        -0.14751978636626545,
+    ]
+
+    assert np.allclose(Es, ref_Es)
+    assert np.allclose(ref_Es, Es_cueq)
+
+
+@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed")
+def test_run_train_foundation_multihead_json_cueq(tmp_path, fitting_configs):
+    fitting_configs_dft = []
+    fitting_configs_mp2 = []
+    for i, c in enumerate(fitting_configs):
+        if i in (0, 1):
+            continue  # skip isolated atoms, as energies are specified by the json files below
+        if i % 2 == 0:
+            c.info["head"] = "DFT"
+            fitting_configs_dft.append(c)
+        else:
+            c.info["head"] = "MP2"
+            fitting_configs_mp2.append(c)
+    ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft)
+    ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2)
+
+    # write E0s to json files
+    E0s = {1: 0.0, 8: 0.0}
+    with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f:
+        json.dump(E0s, f)
+    with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f:
+        json.dump(E0s, f)
+
+    heads = {
+        "DFT": {
+            "train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz",
+            "E0s": f"{str(tmp_path)}/fit_multihead_dft.json",
+        },
+        "MP2": {
+            "train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz",
+            "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json",
+        },
+    }
+    yaml_str = "heads:\n"
+    for key, value in heads.items():
+        yaml_str += f"  {key}:\n"
+        for sub_key, sub_value in value.items():
+            yaml_str += f"    {sub_key}: {sub_value}\n"
+    filename = tmp_path / "config.yaml"
+    with open(filename, "w", encoding="utf-8") as file:
+        file.write(yaml_str)
+    mace_params = _mace_params.copy()
+    mace_params["valid_fraction"] = 0.1
+    mace_params["checkpoints_dir"] = str(tmp_path)
+    mace_params["model_dir"] = str(tmp_path)
+    mace_params["config"] = tmp_path / "config.yaml"
+    mace_params["loss"] = "weighted"
+    mace_params["foundation_model"] = "small"
+    mace_params["hidden_irreps"] = "128x0e"
+    mace_params["r_max"] = 6.0
+    mace_params["default_dtype"] = "float64"
+    mace_params["num_radial_basis"] = 10
+    mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock"
+    mace_params["batch_size"] = 2
+    mace_params["valid_batch_size"] = 1
+    mace_params["num_samples_pt"] = 50
+    mace_params["subselect_pt"] = "random"
+    mace_params["enable_cueq"] = True
+    mace_params["force_mh_ft_lr"] = True
+    # make sure run_train.py is using the mace that is currently being tested
+    run_env = os.environ.copy()
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+    run_env["PYTHONPATH"] = ":".join(sys.path)
+    print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"])
+
+    cmd = (
+        sys.executable
+        + " "
+        + str(run_train)
+        + " "
+        + " ".join(
+            [
+                (f"--{k}={v}" if v is not None else f"--{k}")
+                for k, v in mace_params.items()
+            ]
+        )
+    )
+
+    p = subprocess.run(cmd.split(), env=run_env, check=True)
+    assert p.returncode == 0
+
+    calc = MACECalculator(
+        model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64"
+    )
+
+    Es = []
+    for at in fitting_configs:
+        at.calc = calc
+        Es.append(at.get_potential_energy())
+
+    print("Es", Es)
+    # from a run on 20/08/2024 on commit
+    ref_Es = [
+        1.654685616493225,
+        0.44693732261657715,
+        0.8741313815116882,
+        0.569085955619812,
+        0.7161882519721985,
+        0.8654778599739075,
+        0.8722733855247498,
+        0.49582308530807495,
+        0.814422607421875,
+        0.7027317881584167,
+        0.7196993827819824,
+        0.517953097820282,
+        0.8631765246391296,
+        0.4679797887802124,
+        0.8163984417915344,
+        0.4252359867095947,
+        1.0861445665359497,
+        0.6829671263694763,
+        0.7136879563331604,
+        0.5160345435142517,
+        0.7002358436584473,
+        0.5574042201042175,
+    ]
+    assert np.allclose(Es, ref_Es, atol=1e-1)
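Finally, with the setup.cfg changes above the CUDA kernels are an opt-in extra: assuming the published package name mace-torch, `pip install mace-torch[cueq-cuda-12]` (or `cueq-cuda-11` for CUDA 11 builds) pulls in cuequivariance-ops-torch alongside the new cuequivariance-torch dependency.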