Skip to content

Commit

Permalink
Add nequip force field relax- and staticmaker and corresponding tests (
Browse files Browse the repository at this point in the history
…#764)

* add nequip force field relax- and staticmaker and corresponding tests

* fix model_kwargs doc string

* change tests to smaller force field

* add nequip and torch-runstats to testing.yml for CI

* Add nequip and torch-runstats to Install dependencies in deploy.yml

* Revert "Add nequip and torch-runstats to Install dependencies in deploy.yml"

This reverts commit 7080af5.

* minor refactor

* rename nequip_sto_test_ff.pth to nequip_ff_sr_ti_o3.pth

---------

Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
JonathanSchmidt1 and janosh authored Mar 7, 2024
1 parent ee0f525 commit 573184c
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ jobs:
# ase needed to get FrechetCellFilter used by ML force fields
pip install git+https://gitlab.com/ase/ase
pip install .[strict,tests]
pip install torch-runstats
pip install --no-deps nequip
- name: Test
env:
Expand Down
1 change: 1 addition & 0 deletions src/atomate2/forcefields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ class MLFF(Enum): # TODO inherit from StrEnum when 3.11+
GAP = "GAP"
M3GNet = "M3GNet"
CHGNet = "CHGNet"
Nequip = "Nequip"
92 changes: 92 additions & 0 deletions src/atomate2/forcefields/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,98 @@ def _relax(self, structure: Structure) -> dict:
return relaxer.relax(structure, steps=self.steps, **self.relax_kwargs)


@dataclass
class NequipRelaxMaker(ForceFieldRelaxMaker):
"""
Maker to perform a relaxation using a Nequip force field.
Parameters
----------
name : str
The job name.
force_field_name : str
The name of the force field.
relax_cell : bool = True
Whether to allow the cell shape/volume to change during relaxation.
steps : int
Maximum number of ionic steps allowed during relaxation.
relax_kwargs : dict
Keyword arguments that will get passed to :obj:`Relaxer.relax`.
optimizer_kwargs : dict
Keyword arguments that will get passed to :obj:`Relaxer()`.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
model_path: str | Path
deployed model checkpoint to load with
:obj:`nequip.calculators.NequipCalculator.from_deployed_model()'`.
model_kwargs: dict[str, Any]
Further keywords (e.g. device: Union[str, torch.device],
species_to_type_name: Optional[Dict[str, str]] = None) for
:obj:`nequip.calculators.NequipCalculator()'`.
"""

name: str = f"{MLFF.Nequip} relax"
force_field_name: str = f"{MLFF.Nequip}"
relax_cell: bool = True
steps: int = 500
relax_kwargs: dict = field(default_factory=dict)
optimizer_kwargs: dict = field(default_factory=dict)
task_document_kwargs: dict = field(default_factory=dict)
model_path: str | Path = ""
model_kwargs: dict = field(default_factory=dict)

def _relax(self, structure: Structure) -> dict:
from nequip.ase import NequIPCalculator

calculator = NequIPCalculator.from_deployed_model(
self.model_path, **self.model_kwargs
)
relaxer = Relaxer(
calculator, relax_cell=self.relax_cell, **self.optimizer_kwargs
)

return relaxer.relax(structure, steps=self.steps, **self.relax_kwargs)


@dataclass
class NequipStaticMaker(ForceFieldStaticMaker):
"""
Maker to calculate energies, forces and stresses using a nequip force field.
Parameters
----------
name : str
The job name.
force_field_name : str
The name of the force field.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
model_path: str | Path
deployed model checkpoint to load with
:obj:`nequip.calculators.NequipCalculator()'`.
model_kwargs: dict[str, Any]
Further keywords (e.g. device: Union[str, torch.device],
species_to_type_name: Optional[Dict[str, str]] = None) for
:obj:`nequip.calculators.NequipCalculator()'`.
"""

name: str = f"{MLFF.Nequip} static"
force_field_name: str = f"{MLFF.Nequip}"
task_document_kwargs: dict = field(default_factory=dict)
model_path: str | Path = ""
model_kwargs: dict = field(default_factory=dict)

def _evaluate_static(self, structure: Structure) -> dict:
from nequip.ase import NequIPCalculator

calculator = NequIPCalculator.from_deployed_model(
self.model_path, **self.model_kwargs
)
relaxer = Relaxer(calculator, relax_cell=False)

return relaxer.relax(structure, steps=1)


@dataclass
class M3GNetStaticMaker(ForceFieldStaticMaker):
"""
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ def si_structure(test_dir):
return Structure.from_file(test_dir / "structures" / "Si.cif")


@pytest.fixture()
def sr_ti_o3_structure(test_dir):
return Structure.from_file(test_dir / "structures" / "SrTiO3.cif")


@pytest.fixture(autouse=True)
def mock_jobflow_settings(memory_jobstore):
"""Mock the jobflow settings to use our specific jobstore (with data store)."""
Expand Down
63 changes: 58 additions & 5 deletions tests/forcefields/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from jobflow import run_locally
from pymatgen.core import Structure
from pytest import approx, importorskip

from atomate2.forcefields.jobs import (
Expand All @@ -13,6 +14,8 @@
M3GNetStaticMaker,
MACERelaxMaker,
MACEStaticMaker,
NequipRelaxMaker,
NequipStaticMaker,
)
from atomate2.forcefields.schemas import ForceFieldTaskDocument

Expand All @@ -35,7 +38,7 @@ def test_chgnet_static_maker(si_structure):


@pytest.mark.parametrize("relax_cell", [True, False])
def test_chgnet_relax_maker(si_structure, relax_cell: bool):
def test_chgnet_relax_maker(si_structure: Structure, relax_cell: bool):
# translate one atom to ensure a small number of relaxation steps are taken
si_structure.translate_sites(0, [0, 0, 0.1])

Expand Down Expand Up @@ -101,7 +104,7 @@ def test_m3gnet_relax_maker(si_structure):


@mace_paths
def test_mace_static_maker(si_structure, test_dir, model):
def test_mace_static_maker(si_structure: Structure, test_dir: Path, model):
task_doc_kwargs = {"ionic_step_data": ("structure", "energy")}

# generate job
Expand All @@ -122,7 +125,9 @@ def test_mace_static_maker(si_structure, test_dir, model):

@pytest.mark.parametrize("relax_cell", [True, False])
@mace_paths
def test_mace_relax_maker(si_structure, test_dir, model, relax_cell: bool):
def test_mace_relax_maker(
si_structure: Structure, test_dir: Path, model, relax_cell: bool
):
# translate one atom to ensure a small number of relaxation steps are taken
si_structure.translate_sites(0, [0, 0, 0.1])

Expand All @@ -149,7 +154,7 @@ def test_mace_relax_maker(si_structure, test_dir, model, relax_cell: bool):
assert output1.output.n_steps == 4


def test_gap_static_maker(si_structure, test_dir):
def test_gap_static_maker(si_structure: Structure, test_dir):
importorskip("quippy")

task_doc_kwargs = {"ionic_step_data": ("structure", "energy")}
Expand All @@ -172,8 +177,56 @@ def test_gap_static_maker(si_structure, test_dir):
assert output1.output.n_steps == 1


def test_nequip_static_maker(sr_ti_o3_structure: Structure, test_dir: Path):
task_doc_kwargs = {"ionic_step_data": ("structure", "energy")}

# generate job
# NOTE the test model is not trained on Si, so the energy is not accurate
job = NequipStaticMaker(
task_document_kwargs=task_doc_kwargs,
model_path=test_dir / "forcefields" / "nequip" / "nequip_ff_sr_ti_o3.pth",
).make(sr_ti_o3_structure)

# run the flow or job and ensure that it finished running successfully
responses = run_locally(job, ensure_success=True)

# validation the outputs of the job
output1 = responses[job.uuid][1].output
assert isinstance(output1, ForceFieldTaskDocument)
assert output1.output.energy == approx(-44.40017, rel=1e-4)
assert output1.output.n_steps == 1


@pytest.mark.parametrize("relax_cell", [True, False])
def test_nequip_relax_maker(
sr_ti_o3_structure: Structure, test_dir: Path, relax_cell: bool
):
# translate one atom to ensure a small number of relaxation steps are taken
sr_ti_o3_structure.translate_sites(0, [0, 0, 0.2])
# generate job
job = NequipRelaxMaker(
steps=25,
optimizer_kwargs={"optimizer": "BFGSLineSearch"},
relax_cell=relax_cell,
model_path=test_dir / "forcefields" / "nequip" / "nequip_ff_sr_ti_o3.pth",
).make(sr_ti_o3_structure)

# run the flow or job and ensure that it finished running successfully
responses = run_locally(job, ensure_success=True)

# validation the outputs of the job
output1 = responses[job.uuid][1].output
assert isinstance(output1, ForceFieldTaskDocument)
if relax_cell:
assert output1.output.energy == approx(-44.407, rel=1e-3)
assert output1.output.n_steps == 5
else:
assert output1.output.energy == approx(-44.40015, rel=1e-4)
assert output1.output.n_steps == 5


@pytest.mark.parametrize("relax_cell", [True, False])
def test_gap_relax_maker(si_structure, test_dir, relax_cell: bool):
def test_gap_relax_maker(si_structure: Structure, test_dir: Path, relax_cell: bool):
importorskip("quippy")

# translate one atom to ensure a small number of relaxation steps are taken
Expand Down
Binary file not shown.
81 changes: 81 additions & 0 deletions tests/test_data/structures/SrTiO3.cif
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# generated using pymatgen
data_SrTiO3
_symmetry_space_group_name_H-M Pm-3m
_cell_length_a 3.91270131
_cell_length_b 3.91270131
_cell_length_c 3.91270131
_cell_angle_alpha 90.00000000
_cell_angle_beta 90.00000000
_cell_angle_gamma 90.00000000
_symmetry_Int_Tables_number 221
_chemical_formula_structural SrTiO3
_chemical_formula_sum 'Sr1 Ti1 O3'
_cell_volume 59.90045031
_cell_formula_units_Z 1
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x, y, z'
2 '-x, -y, -z'
3 '-y, x, z'
4 'y, -x, -z'
5 '-x, -y, z'
6 'x, y, -z'
7 'y, -x, z'
8 '-y, x, -z'
9 'x, -y, -z'
10 '-x, y, z'
11 '-y, -x, -z'
12 'y, x, z'
13 '-x, y, -z'
14 'x, -y, z'
15 'y, x, -z'
16 '-y, -x, z'
17 'z, x, y'
18 '-z, -x, -y'
19 'z, -y, x'
20 '-z, y, -x'
21 'z, -x, -y'
22 '-z, x, y'
23 'z, y, -x'
24 '-z, -y, x'
25 '-z, x, -y'
26 'z, -x, y'
27 '-z, -y, -x'
28 'z, y, x'
29 '-z, -x, y'
30 'z, x, -y'
31 '-z, y, x'
32 'z, -y, -x'
33 'y, z, x'
34 '-y, -z, -x'
35 'x, z, -y'
36 '-x, -z, y'
37 '-y, z, -x'
38 'y, -z, x'
39 '-x, z, y'
40 'x, -z, -y'
41 '-y, -z, x'
42 'y, z, -x'
43 '-x, -z, -y'
44 'x, z, y'
45 'y, -z, -x'
46 '-y, z, x'
47 'x, -z, y'
48 '-x, z, -y'
loop_
_atom_type_symbol
Sr
Ti
O2
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Sr Sr 1 0.00000000 0.00000000 0.00000000 1
Ti Ti 1 0.50000000 0.50000000 0.50000000 1
O O 3 0.00000000 0.50000000 0.50000000 1

0 comments on commit 573184c

Please sign in to comment.