Skip to content

Commit

Permalink
Define MLFF Enum to ensure consistent force field names (#729)
Browse files Browse the repository at this point in the history
* use np.eye(3) for identity matrix

* add MLFF enum to ensure consistent force field names

also support mace-torch in phonon schema pkg version provenance

* fix pydantic

* move MLFF to __init__ to avoid circular import

* fix ElectronPhononRenormalisationDoc.from_band_structures "band structure is metallic" warning

* i->idx
  • Loading branch information
janosh authored Feb 18, 2024
1 parent db340fe commit eb6c547
Show file tree
Hide file tree
Showing 20 changed files with 127 additions and 110 deletions.
10 changes: 5 additions & 5 deletions src/atomate2/common/jobs/defect.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def get_charged_structures(structure: Structure, charges: Iterable) -> list[Stru
A dictionary with the two structures with the charge states added.
"""
structs_out = [structure.copy() for _ in charges]
for i, q in enumerate(charges):
structs_out[i].set_charge(q)
for idx, q in enumerate(charges):
structs_out[idx].set_charge(q)
return structs_out


Expand Down Expand Up @@ -108,15 +108,15 @@ def spawn_energy_curve_calcs(
distorted_structure, nimages=s_distortions
)
# add all the distorted structures
for i, d_struct in enumerate(distorted_structures):
for idx, d_struct in enumerate(distorted_structures):
static_job = static_maker.make(d_struct, prev_dir=prev_dir)
suffix = f" {i}" if add_name == "" else f" {add_name} {i}"
suffix = f" {idx}" if add_name == "" else f" {add_name} {idx}"

# write some provenances data in info.json file
info = {
"relaxed_structure": relaxed_structure,
"distorted_structure": distorted_structure,
"distortion": s_distortions[i],
"distortion": s_distortions[idx],
}
if add_info is not None:
info.update(add_info)
Expand Down
4 changes: 2 additions & 2 deletions src/atomate2/common/jobs/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def run_elastic_deformations(
"""
relaxations = []
outputs = []
for i, deformation in enumerate(deformations):
for idx, deformation in enumerate(deformations):
# deform the structure
dst = DeformStructureTransformation(deformation=deformation)
ts = TransformedStructure(structure, transformations=[dst])
Expand All @@ -146,7 +146,7 @@ def run_elastic_deformations(
elastic_job_kwargs[prev_dir_argname] = prev_dir
# create the job
relax_job = elastic_relax_maker.make(deformed_structure, **elastic_job_kwargs)
relax_job.append_name(f" {i + 1}/{len(deformations)}")
relax_job.append_name(f" {idx + 1}/{len(deformations)}")
relaxations.append(relax_job)

# extract the outputs we want
Expand Down
8 changes: 2 additions & 6 deletions src/atomate2/common/jobs/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import numpy as np
from jobflow import Flow, Response, job
from phonopy import Phonopy
from phonopy.units import VaspToTHz
Expand All @@ -25,7 +26,6 @@
if TYPE_CHECKING:
from pathlib import Path

import numpy as np
from emmet.core.math import Matrix3D

from atomate2.forcefields.jobs import ForceFieldStaticMaker
Expand Down Expand Up @@ -156,11 +156,7 @@ def generate_phonon_displacements(
# a bit of code repetition here as I currently
# do not see how to pass the phonopy object?
if use_symmetrized_structure == "primitive" and kpath_scheme != "seekpath":
primitive_matrix: list[list[float]] | str = [
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
]
primitive_matrix: np.ndarray | str = np.eye(3)
else:
primitive_matrix = "auto"
phonon = Phonopy(
Expand Down
12 changes: 7 additions & 5 deletions src/atomate2/common/schemas/cclib.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,13 +376,15 @@ def _get_homos_lumos(
The HOMO-LUMO gaps (eV), calculated as LUMO_alpha-HOMO_alpha and
LUMO_beta-HOMO_beta
"""
homo_energies = [mo_energies[i][h] for i, h in enumerate(homo_indices)]
# Make sure that the HOMO+1 (i.e. LUMO) is in moenergies (sometimes virtual
homo_energies = [mo_energies[idx][homo] for idx, homo in enumerate(homo_indices)]
# Make sure that the HOMO+1 (i.e. LUMO) is in MO energies (sometimes virtual
# orbitals aren't printed in the output)
for i, h in enumerate(homo_indices):
if len(mo_energies[i]) < h + 2:
for idx, homo in enumerate(homo_indices):
if len(mo_energies[idx]) < homo + 2:
return homo_energies, None, None
lumo_energies = [mo_energies[i][h + 1] for i, h in enumerate(homo_indices)]
lumo_energies = [
mo_energies[idx][homo + 1] for idx, homo in enumerate(homo_indices)
]
homo_lumo_gaps = [
lumo_energies[i] - homo_energies[i] for i in range(len(homo_energies))
]
Expand Down
14 changes: 7 additions & 7 deletions src/atomate2/common/schemas/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,17 +279,17 @@ def _expand_strains(
`generate_elastic_deformations()`.
"""
sga = SpacegroupAnalyzer(structure, symprec=symprec)
symmops = sga.get_symmetry_operations(cartesian=True)
symm_ops = sga.get_symmetry_operations(cartesian=True)

full_strains = deepcopy(strains)
full_stresses = deepcopy(stresses)
full_uuids = deepcopy(uuids)
full_job_dirs = deepcopy(job_dirs)

mapping = TensorMapping(full_strains, [True for _ in full_strains])
for i, strain in enumerate(strains):
for symmop in symmops:
rotated_strain = strain.transform(symmop)
for idx, strain in enumerate(strains):
for symm_op in symm_ops:
rotated_strain = strain.transform(symm_op)

# check if we have more than one perturbed strain component
if sum(np.abs(rotated_strain.voigt) > tol) > 1:
Expand All @@ -304,8 +304,8 @@ def _expand_strains(

# expand the other properties
full_strains.append(rotated_strain)
full_stresses.append(stresses[i].transform(symmop))
full_uuids.append(uuids[i])
full_job_dirs.append(job_dirs[i])
full_stresses.append(stresses[idx].transform(symm_op))
full_uuids.append(uuids[idx])
full_job_dirs.append(job_dirs[idx])

return full_strains, full_stresses, full_uuids, full_job_dirs
6 changes: 1 addition & 5 deletions src/atomate2/common/schemas/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,7 @@ def from_forces_born(
cell = get_phonopy_structure(structure)

if use_symmetrized_structure == "primitive" and kpath_scheme != "seekpath":
primitive_matrix: Union[list[list[float]], str] = [
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
]
primitive_matrix: Union[np.ndarray, str] = np.eye(3)
else:
primitive_matrix = "auto"
phonon = Phonopy(
Expand Down
10 changes: 10 additions & 0 deletions src/atomate2/forcefields/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,11 @@
"""Tools and functions common to all forcefields."""
from enum import Enum


class MLFF(Enum): # TODO inherit from StrEnum when 3.11+
"""Names of ML force fields."""

MACE = "MACE"
GAP = "GAP"
M3GNet = "M3GNet"
CHGNet = "CHGNet"
9 changes: 5 additions & 4 deletions src/atomate2/forcefields/flows/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from jobflow import Flow, Maker

from atomate2.forcefields import MLFF
from atomate2.forcefields.jobs import CHGNetRelaxMaker, M3GNetRelaxMaker
from atomate2.vasp.jobs.core import RelaxMaker

Expand All @@ -31,7 +32,7 @@ class CHGNetVaspRelaxMaker(Maker):
Maker to generate a VASP relaxation job.
"""

name: str = "CHGNet relax followed by a VASP relax"
name: str = f"{MLFF.CHGNet} relax followed by a VASP relax"
chgnet_maker: CHGNetRelaxMaker = field(default_factory=CHGNetRelaxMaker)
vasp_maker: BaseVaspMaker = field(default_factory=RelaxMaker)

Expand All @@ -50,7 +51,7 @@ def make(self, structure: Structure) -> Flow:
A flow containing a CHGNet relaxation followed by a VASP relaxation
"""
chgnet_relax_job = self.chgnet_maker.make(structure)
chgnet_relax_job.name = "CHGNet pre-relax"
chgnet_relax_job.name = f"{MLFF.CHGNet} pre-relax"

vasp_job = self.vasp_maker.make(chgnet_relax_job.output.structure)

Expand All @@ -72,7 +73,7 @@ class M3GNetVaspRelaxMaker(Maker):
Maker to generate a VASP relaxation job.
"""

name: str = "M3GNet relax followed by a VASP relax"
name: str = f"{MLFF.M3GNet} relax followed by a VASP relax"
m3gnet_maker: M3GNetRelaxMaker = field(default_factory=M3GNetRelaxMaker)
vasp_maker: BaseVaspMaker = field(default_factory=RelaxMaker)

Expand All @@ -91,7 +92,7 @@ def make(self, structure: Structure) -> Flow:
A flow containing a M3GNet relaxation followed by a VASP relaxation
"""
m3gnet_relax_job = self.m3gnet_maker.make(structure)
m3gnet_relax_job.name = "M3GNet pre-relax"
m3gnet_relax_job.name = f"{MLFF.M3GNet} pre-relax"

vasp_job = self.vasp_maker.make(m3gnet_relax_job.output.structure)

Expand Down
43 changes: 23 additions & 20 deletions src/atomate2/forcefields/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from jobflow import Maker, job

from atomate2.forcefields import MLFF
from atomate2.forcefields.schemas import ForceFieldTaskDocument
from atomate2.forcefields.utils import Relaxer

Expand All @@ -25,6 +26,8 @@ class ForceFieldRelaxMaker(Maker):
"""
Base Maker to calculate forces and stresses using any force field.
Should be subclassed to use a specific force field.
Parameters
----------
name : str
Expand All @@ -43,8 +46,8 @@ class ForceFieldRelaxMaker(Maker):
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""

name: str = "Forcefield relax"
force_field_name: str = "Forcefield"
name: str = "Force field relax"
force_field_name: str = "Force field"
relax_cell: bool = True
steps: int = 500
relax_kwargs: dict = field(default_factory=dict)
Expand Down Expand Up @@ -103,8 +106,8 @@ class ForceFieldStaticMaker(ForceFieldRelaxMaker):
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""

name: str = "ForceField static"
force_field_name: str = "Forcefield"
name: str = "Force field static"
force_field_name: str = "Force field"
task_document_kwargs: dict = field(default_factory=dict)

@job(output_schema=ForceFieldTaskDocument)
Expand Down Expand Up @@ -165,8 +168,8 @@ class CHGNetRelaxMaker(ForceFieldRelaxMaker):
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""

name: str = "CHGNet relax"
force_field_name = "CHGNet"
name: str = f"{MLFF.CHGNet} relax"
force_field_name = f"{MLFF.CHGNet}"
relax_cell: bool = True
steps: int = 500
relax_kwargs: dict = field(default_factory=dict)
Expand Down Expand Up @@ -195,8 +198,8 @@ class CHGNetStaticMaker(ForceFieldStaticMaker):
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""

name: str = "CHGNet static"
force_field_name = "CHGNet"
name: str = f"{MLFF.CHGNet} static"
force_field_name = f"{MLFF.CHGNet}"
task_document_kwargs: dict = field(default_factory=dict)

def _evaluate_static(self, structure: Structure) -> dict:
Expand Down Expand Up @@ -229,8 +232,8 @@ class M3GNetRelaxMaker(ForceFieldRelaxMaker):
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""

name: str = "M3GNet relax"
force_field_name: str = "M3GNet"
name: str = f"{MLFF.M3GNet} relax"
force_field_name: str = f"{MLFF.M3GNet}"
relax_cell: bool = True
steps: int = 500
relax_kwargs: dict = field(default_factory=dict)
Expand Down Expand Up @@ -267,8 +270,8 @@ class M3GNetStaticMaker(ForceFieldStaticMaker):
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""

name: str = "M3GNet static"
force_field_name: str = "M3GNet"
name: str = f"{MLFF.M3GNet} static"
force_field_name: str = f"{MLFF.M3GNet}"
task_document_kwargs: dict = field(default_factory=dict)

def _evaluate_static(self, structure: Structure) -> dict:
Expand Down Expand Up @@ -315,8 +318,8 @@ class MACERelaxMaker(ForceFieldRelaxMaker):
:obj:`mace.calculators.MACECalculator()'`.
"""

name: str = "MACE relax"
force_field_name: str = "MACE"
name: str = f"{MLFF.MACE} relax"
force_field_name: str = f"{MLFF.MACE}"
relax_cell: bool = True
steps: int = 500
relax_kwargs: dict = field(default_factory=dict)
Expand Down Expand Up @@ -358,8 +361,8 @@ class MACEStaticMaker(ForceFieldStaticMaker):
:obj:`mace.calculators.MACECalculator()'`.
"""

name: str = "MACE static"
force_field_name: str = "MACE"
name: str = f"{MLFF.MACE} static"
force_field_name: str = f"{MLFF.MACE}"
task_document_kwargs: dict = field(default_factory=dict)
model: str | Path | Sequence[str | Path] | None = None
model_kwargs: dict = field(default_factory=dict)
Expand Down Expand Up @@ -401,8 +404,8 @@ class GAPRelaxMaker(ForceFieldRelaxMaker):
Further keywords for :obj:`quippy.potential.Potential()'`.
"""

name: str = "GAP relax"
force_field_name: str = "GAP"
name: str = f"{MLFF.GAP} relax"
force_field_name: str = f"{MLFF.GAP}"
relax_cell: bool = True
steps: int = 500
relax_kwargs: dict = field(default_factory=dict)
Expand Down Expand Up @@ -447,8 +450,8 @@ class GAPStaticMaker(ForceFieldStaticMaker):
Further keywords for :obj:`quippy.potential.Potential()'`.
"""

name: str = "GAP static"
force_field_name: str = "GAP"
name: str = f"{MLFF.GAP} static"
force_field_name: str = f"{MLFF.GAP}"
task_document_kwargs: dict = field(default_factory=dict)
potential_args_str: str = "IP GAP"
potential_param_file_name: str | Path = "gap.xml"
Expand Down
27 changes: 15 additions & 12 deletions src/atomate2/forcefields/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from pymatgen.core.structure import Structure
from pymatgen.io.ase import AseAtomsAdaptor

from atomate2.forcefields import MLFF


class IonicStep(BaseModel, extra="allow"): # type: ignore[call-arg]
"""Document defining the information at each ionic step."""
Expand Down Expand Up @@ -135,12 +137,11 @@ def from_ase_compatible_result(
"""
trajectory = result["trajectory"].__dict__

# NOTE: units for stresses were converted from eV/Angstrom³ to kBar
# (* -1 from standard output)
# NOTE: convert stress units from eV/A³ to kBar (* -1 from standard output)
# and to 3x3 matrix to comply with MP convention
for i in range(len(trajectory["stresses"])):
trajectory["stresses"][i] = voigt_6_to_full_3x3_stress(
trajectory["stresses"][i] * -10 / GPa
for idx in range(len(trajectory["stresses"])):
trajectory["stresses"][idx] = voigt_6_to_full_3x3_stress(
trajectory["stresses"][idx] * -10 / GPa
)

species = AseAtomsAdaptor.get_structure(trajectory["atoms"]).species
Expand Down Expand Up @@ -238,14 +239,16 @@ def from_ase_compatible_result(
n_steps=n_steps,
)

if forcefield_name == "M3GNet":
import matgl

version = matgl.__version__
elif forcefield_name == "CHGNet":
import chgnet
# map force field name to its package name
pkg_name = {
MLFF.M3GNet: "matgl",
MLFF.CHGNet: "chgnet",
MLFF.MACE: "mace-torch",
}.get(forcefield_name) # type: ignore[call-overload]
if pkg_name:
import importlib.metadata

version = chgnet.__version__
version = importlib.metadata.version(pkg_name)
else:
version = "Unknown"
return cls.from_structure(
Expand Down
Loading

0 comments on commit eb6c547

Please sign in to comment.