Skip to content

Commit

Permalink
hotfix mattersim pickling issue; add phonon task (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
chiang-yuan authored Jan 17, 2025
1 parent 8cb1d3b commit 5716d3b
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 53 deletions.
22 changes: 15 additions & 7 deletions mlip_arena/models/externals/mattersim.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,21 @@ def __init__(
load_path=checkpoint, device=str(device or get_freer_device()), **kwargs
)

def calculate(
self,
atoms: Atoms | None = None,
properties: list | None = None,
system_changes: list | None = None,
):
super().calculate(atoms, properties, system_changes)
def __getstate__(self):
    """Return the picklable state: the instance ``__dict__`` minus ``potential``.

    The ``potential`` attribute was flagged by the original author as not
    picklable, so it is left out of the returned state. Every other
    attribute is carried over unchanged, and ``self`` is not modified.

    NOTE(review): no matching ``__setstate__`` is visible in this view, so
    an unpickled instance presumably comes back without ``potential`` --
    confirm that it is reloaded before the calculator is used again.
    """
    return {
        key: value for key, value in self.__dict__.items() if key != "potential"
    }

# def calculate(
# self,
# atoms: Atoms | None = None,
# properties: list | None = None,
# system_changes: list | None = None,
# ):
# super().calculate(atoms, properties, system_changes)

# # convert unpicklizable atoms back to picklizable atoms to avoid prefect pickling error
# if isinstance(self.atoms, MSONAtoms):
Expand Down
73 changes: 37 additions & 36 deletions mlip_arena/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import yaml
from huggingface_hub import HfApi, HfFileSystem, hf_hub_download

from mlip_arena.models import MLIP
from mlip_arena.models import REGISTRY as MODEL_REGISTRY
# from mlip_arena.models import MLIP
# from mlip_arena.models import REGISTRY as MODEL_REGISTRY

try:
from .elasticity import run as ELASTICITY
Expand All @@ -13,52 +13,53 @@
from .neb import run as NEB
from .neb import run_from_endpoints as NEB_FROM_ENDPOINTS
from .optimize import run as OPT
from .phonon import run as PHONON

__all__ = ["OPT", "EOS", "MD", "NEB", "NEB_FROM_ENDPOINTS", "ELASTICITY"]
__all__ = ["OPT", "EOS", "MD", "NEB", "NEB_FROM_ENDPOINTS", "ELASTICITY", "PHONON"]
except ImportError:
pass

with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
REGISTRY = yaml.safe_load(f)


class Task:
def __init__(self):
self.name: str = self.__class__.__name__ # display name on the leaderboard
# class Task:
# def __init__(self):
# self.name: str = self.__class__.__name__ # display name on the leaderboard

def run_local(self, model: MLIP):
"""Run the task using the given model and return the results."""
raise NotImplementedError
# def run_local(self, model: MLIP):
# """Run the task using the given model and return the results."""
# raise NotImplementedError

def run_hf(self, model: MLIP):
"""Run the task using the given model and return the results."""
raise NotImplementedError
# def run_hf(self, model: MLIP):
# """Run the task using the given model and return the results."""
# raise NotImplementedError

# Calculate evaluation metrics and postprocessed data
api = HfApi()
api.upload_file(
path_or_fileobj="results.json",
path_in_repo=f"{self.__class__.__name__}/{model.__class__.__name__}/results.json", # Upload to a specific folder
repo_id="atomind/mlip-arena",
repo_type="dataset",
)
# # Calculate evaluation metrics and postprocessed data
# api = HfApi()
# api.upload_file(
# path_or_fileobj="results.json",
# path_in_repo=f"{self.__class__.__name__}/{model.__class__.__name__}/results.json", # Upload to a specific folder
# repo_id="atomind/mlip-arena",
# repo_type="dataset",
# )

def run_nersc(self, model: MLIP):
"""Run the task using the given model and return the results."""
raise NotImplementedError
# def run_nersc(self, model: MLIP):
# """Run the task using the given model and return the results."""
# raise NotImplementedError

def get_results(self):
"""Get the results from the task."""
# fs = HfFileSystem()
# files = fs.glob(f"datasets/atomind/mlip-arena/{self.__class__.__name__}/*/*.json")
# def get_results(self):
# """Get the results from the task."""
# # fs = HfFileSystem()
# # files = fs.glob(f"datasets/atomind/mlip-arena/{self.__class__.__name__}/*/*.json")

for model, metadata in MODEL_REGISTRY.items():
results = hf_hub_download(
repo_id="atomind/mlip-arena",
filename="results.json",
subfolder=f"{self.__class__.__name__}/{model}",
repo_type="dataset",
revision=None,
)
# for model, metadata in MODEL_REGISTRY.items():
# results = hf_hub_download(
# repo_id="atomind/mlip-arena",
# filename="results.json",
# subfolder=f"{self.__class__.__name__}/{model}",
# repo_type="dataset",
# revision=None,
# )

return results
# return results
162 changes: 162 additions & 0 deletions mlip_arena/tasks/phonon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""
This module has been adapted from Quacc (https://github.com/Quantum-Accelerators/quacc). By using this software, you agree to the Quacc license agreement: https://github.com/Quantum-Accelerators/quacc/blob/main/LICENSE.md
BSD 3-Clause License
Copyright (c) 2025, Andrew S. Rosen.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
- Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
- Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

from pathlib import Path

import numpy as np
from phonopy import Phonopy
from phonopy.structure.atoms import PhonopyAtoms
from prefect import task
from prefect.cache_policies import INPUTS, TASK_SOURCE
from prefect.runtime import task_run

from ase import Atoms
from ase.calculators.calculator import BaseCalculator


@task(cache_policy=TASK_SOURCE + INPUTS)
def get_phonopy(
    atoms: Atoms,
    supercell_matrix: list[int] | None = None,
    min_lengths: float | tuple[float, float, float] | None = None,
    symprec: float = 1e-5,
    distance: float = 0.01,
    phonopy_kwargs: dict | None = None,
) -> Phonopy:
    """Build a ``Phonopy`` object for ``atoms`` and generate its displacements.

    Args:
        atoms: ASE structure whose symbols, cell, wrapped scaled positions
            and masses are copied into a ``PhonopyAtoms`` unit cell.
        supercell_matrix: Supercell matrix forwarded to ``Phonopy``. When
            given, ``min_lengths`` is ignored.
        min_lengths: Minimum supercell length(s), in the same length unit as
            ``atoms.cell``. Only used when ``supercell_matrix`` is ``None``:
            a diagonal supercell matrix is built from
            ``ceil(min_lengths / cell_lengths)``.
        symprec: Symmetry tolerance forwarded to ``Phonopy``.
        distance: Displacement distance forwarded to
            ``Phonopy.generate_displacements``.
        phonopy_kwargs: Extra keyword arguments forwarded to the ``Phonopy``
            constructor. ``None`` (the default) means no extra arguments.

    Returns:
        The ``Phonopy`` object, with displacements already generated.
    """
    # Use None instead of a shared mutable ``{}`` default.
    extra_kwargs = {} if phonopy_kwargs is None else phonopy_kwargs

    if supercell_matrix is None and min_lengths is not None:
        # Smallest integer repeat along each cell vector that reaches the
        # requested length; asarray lets scalars, tuples and lists all work.
        repeats = np.ceil(
            np.asarray(min_lengths, dtype=float) / atoms.cell.lengths()
        )
        supercell_matrix = np.diag(repeats)

    # NOTE(review): when both supercell_matrix and min_lengths are None the
    # value None is handed to Phonopy as-is -- confirm Phonopy's default.
    phonon = Phonopy(
        PhonopyAtoms(
            symbols=atoms.get_chemical_symbols(),
            cell=atoms.get_cell(),
            scaled_positions=atoms.get_scaled_positions(wrap=True),
            masses=atoms.get_masses(),
        ),
        symprec=symprec,
        supercell_matrix=supercell_matrix,
        **extra_kwargs,
    )
    phonon.generate_displacements(distance=distance)

    return phonon


def _get_forces(
    phononpy_atoms: PhonopyAtoms,
    calculator: BaseCalculator,
) -> np.ndarray:
    """Evaluate the forces on a Phonopy structure with an ASE calculator.

    A new, fully periodic ASE ``Atoms`` object is assembled from the symbols,
    cell and scaled positions of ``phononpy_atoms``. ``calculator`` is
    attached to it and the result of ``Atoms.get_forces()`` is returned.

    Args:
        phononpy_atoms: Phonopy structure to evaluate.
        calculator: ASE calculator used to compute the forces.

    Returns:
        Whatever ``Atoms.get_forces()`` returns for the rebuilt structure.
    """
    structure_kwargs = {
        "symbols": phononpy_atoms.symbols,
        "cell": phononpy_atoms.cell,
        "scaled_positions": phononpy_atoms.scaled_positions,
        "pbc": True,
    }
    displaced = Atoms(**structure_kwargs)
    displaced.calc = calculator

    forces = displaced.get_forces()
    return forces


def _generate_task_run_name():
    """Compose the run name for the current Prefect task run.

    Reads the task name and call parameters from ``prefect.runtime.task_run``
    and formats them as ``"<task name>: <chemical formula> - <calculator
    class name>"``. The parameters must contain ``"atoms"`` and
    ``"calculator"`` entries.
    """
    name = task_run.task_name
    params = task_run.parameters

    structure, calc = params["atoms"], params["calculator"]
    formula = structure.get_chemical_formula()
    calc_label = calc.__class__.__name__

    return f"{name}: {formula} - {calc_label}"


@task(
    name="PHONON",
    task_run_name=_generate_task_run_name,
    cache_policy=TASK_SOURCE + INPUTS,
)
def run(
    atoms: Atoms,
    calculator: BaseCalculator,
    supercell_matrix: list[int] | None = None,
    min_lengths: float | tuple[float, float, float] | None = None,
    symprec: float = 1e-5,
    distance: float = 0.01,
    phonopy_kwargs: dict | None = None,
    symmetry: bool = False,
    t_min: float = 0.0,
    t_max: float = 1000.0,
    t_step: float = 10.0,
    outdir: str | None = None,
):
    """Run a finite-displacement phonon calculation with Phonopy.

    The steps are: build the Phonopy object and its displaced supercells,
    compute forces for each displaced supercell with ``calculator``, produce
    (and optionally symmetrize) the force constants, then run the mesh, total
    DOS, thermal properties and automatic band structure.

    Args:
        atoms: Structure to compute phonons for.
        calculator: ASE calculator used for every force evaluation.
        supercell_matrix: Forwarded to ``get_phonopy``.
        min_lengths: Forwarded to ``get_phonopy``.
        symprec: Forwarded to ``get_phonopy``.
        distance: Forwarded to ``get_phonopy``.
        phonopy_kwargs: Extra keyword arguments for the ``Phonopy``
            constructor, forwarded to ``get_phonopy``. ``None`` (the
            default) means no extra arguments.
        symmetry: If true, symmetrize the force constants, both generally
            and by space group, before the post-processing steps.
        t_min: Forwarded to ``Phonopy.run_thermal_properties``.
        t_max: Forwarded to ``Phonopy.run_thermal_properties``.
        t_step: Forwarded to ``Phonopy.run_thermal_properties``.
        outdir: Directory for ``band.yaml`` and ``phonopy.yaml``. It is
            created if missing. ``None`` (the default) writes no files.

    Returns:
        A dict with the single key ``"phonon"`` holding the ``Phonopy`` object.
    """
    # One consistent "write output?" test; created up front so an unusable
    # output location fails before the force calculations are run.
    out_path = Path(outdir) if outdir is not None else None
    if out_path is not None:
        out_path.mkdir(parents=True, exist_ok=True)

    phonon = get_phonopy(
        atoms=atoms,
        supercell_matrix=supercell_matrix,
        min_lengths=min_lengths,
        symprec=symprec,
        distance=distance,
        # Always hand over a dict so get_phonopy can unpack it with **.
        phonopy_kwargs={} if phonopy_kwargs is None else phonopy_kwargs,
    )

    supercells_with_displacements = phonon.supercells_with_displacements

    phonon.forces = [
        _get_forces(supercell, calculator)
        for supercell in supercells_with_displacements
        if supercell is not None
    ]
    phonon.produce_force_constants()

    if symmetry:
        phonon.symmetrize_force_constants()
        phonon.symmetrize_force_constants_by_space_group()

    phonon.run_mesh(with_eigenvectors=True)
    phonon.run_total_dos()
    phonon.run_thermal_properties(t_step=t_step, t_max=t_max, t_min=t_min)  # type: ignore
    phonon.auto_band_structure(
        write_yaml=out_path is not None,
        filename=out_path / "band.yaml" if out_path is not None else "band.yaml",
    )
    if out_path is not None:
        phonon.save(
            out_path / "phonopy.yaml", settings={"force_constants": True}
        )

    return {
        "phonon": phonon,
    }
58 changes: 48 additions & 10 deletions mlip_arena/tasks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from __future__ import annotations

from pprint import pformat

import torch
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator

from ase import units
from ase.calculators.calculator import Calculator, BaseCalculator
from ase.calculators.calculator import BaseCalculator
from ase.calculators.mixing import SumCalculator
from mlip_arena.models import MLIPEnum
from mlip_arena.models.utils import get_freer_device

try:
from prefect.logging import get_run_logger
Expand All @@ -17,16 +19,48 @@
except (ImportError, RuntimeError):
from loguru import logger

from pprint import pformat

def get_freer_device() -> torch.device:
    """Select the torch device to run on, preferring the freest CUDA GPU.

    Selection order:

    1. If CUDA devices are present, the one with the largest value of
       ``total_memory - torch.cuda.memory_allocated(i)`` (the first such
       device on ties).
    2. Otherwise MPS, if ``torch.backends.mps.is_available()``.
    3. Otherwise the CPU.

    This function does not raise when no accelerator is found; it falls
    back to the CPU.

    NOTE(review): ``torch.cuda.memory_allocated`` presumably only reflects
    allocations made by this process, so memory held by other processes on
    the same GPU may not be accounted for -- confirm if that matters.

    Returns:
        torch.device: The selected CUDA, MPS or CPU device.
    """
    device_count = torch.cuda.device_count()

    if device_count > 0:
        # If CUDA GPUs are available, select the one with the most free memory
        mem_free = [
            torch.cuda.get_device_properties(i).total_memory
            - torch.cuda.memory_allocated(i)
            for i in range(device_count)
        ]
        # max() returns the first maximal index, like list.index(max(...)).
        free_gpu_index = max(range(device_count), key=mem_free.__getitem__)
        device = torch.device(f"cuda:{free_gpu_index}")
        logger.info(
            f"Selected GPU {device} with {mem_free[free_gpu_index] / 1024**2:.2f} MB free memory from {device_count} GPUs"
        )
        return device

    if torch.backends.mps.is_available():
        # If no CUDA GPUs are available but MPS is, use MPS
        logger.info("No GPU available. Using MPS.")
        return torch.device("mps")

    # Fallback to CPU if neither CUDA GPUs nor MPS are available
    logger.info("No GPU or MPS available. Using CPU.")
    return torch.device("cpu")


def get_calculator(
calculator_name: str | MLIPEnum | Calculator | SumCalculator,
calculator_kwargs: dict | None,
calculator_name: str | MLIPEnum | BaseCalculator,
calculator_kwargs: dict | None = None,
dispersion: bool = False,
dispersion_kwargs: dict | None = None,
device: str | None = None,
) -> Calculator | SumCalculator:
) -> BaseCalculator:
"""Get a calculator with optional dispersion correction."""

device = device or str(get_freer_device())
Expand All @@ -40,11 +74,15 @@ def get_calculator(
calc = calculator_name.value(**calculator_kwargs)
elif isinstance(calculator_name, str) and hasattr(MLIPEnum, calculator_name):
calc = MLIPEnum[calculator_name].value(**calculator_kwargs)
elif isinstance(calculator_name, type) and issubclass(calculator_name, BaseCalculator):
elif isinstance(calculator_name, type) and issubclass(
calculator_name, BaseCalculator
):
logger.warning(f"Using custom calculator class: {calculator_name}")
calc = calculator_name(**calculator_kwargs)
elif isinstance(calculator_name, Calculator | SumCalculator):
logger.warning(f"Using custom calculator object (kwargs are ignored): {calculator_name}")
elif isinstance(calculator_name, BaseCalculator):
logger.warning(
f"Using custom calculator object (kwargs are ignored): {calculator_name}"
)
calc = calculator_name
else:
raise ValueError(f"Invalid calculator: {calculator_name}")
Expand All @@ -69,5 +107,5 @@ def get_calculator(
if dispersion_kwargs:
logger.info(pformat(dispersion_kwargs))

assert isinstance(calc, Calculator | SumCalculator)
assert isinstance(calc, BaseCalculator)
return calc

0 comments on commit 5716d3b

Please sign in to comment.