Skip to content

Commit

Permalink
Add revert_default_dtype context manager to fix clashing global `to…
Browse files Browse the repository at this point in the history
…rch.dtype` between MACE and CHGNet (#766)

* add
revert_default_dtype context manager for restoring torch dtype

* use revert_default_dtype in MACE(Relax|Static)Makers to prevent MACE from clashing with other torch force fields running in the same session
  • Loading branch information
janosh authored Mar 6, 2024
1 parent 6592f7a commit 0ce2086
Show file tree
Hide file tree
Showing 26 changed files with 64 additions and 27 deletions.
1 change: 1 addition & 0 deletions src/atomate2/aims/files.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Functions dealing with FHI-aims files."""

from __future__ import annotations

import logging
Expand Down
1 change: 1 addition & 0 deletions src/atomate2/aims/flows/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""(Work)flows for FHI-aims."""

from __future__ import annotations

from copy import deepcopy
Expand Down
1 change: 1 addition & 0 deletions src/atomate2/aims/flows/gw.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""GW workflows for FHI-aims with automatic convergence."""

from dataclasses import dataclass, field

from atomate2.aims.jobs.convergence import ConvergenceMaker
Expand Down
1 change: 1 addition & 0 deletions src/atomate2/aims/flows/phonons.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Defines the phonon workflows for FHI-aims."""

from __future__ import annotations

from dataclasses import dataclass, field
Expand Down
1 change: 1 addition & 0 deletions src/atomate2/aims/jobs/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Define all Core FHI-aims jobs."""

from __future__ import annotations

import logging
Expand Down
1 change: 1 addition & 0 deletions src/atomate2/aims/jobs/phonons.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Define the PhononDisplacementMakers for FHI-aims."""

from dataclasses import dataclass, field

from pymatgen.io.aims.sets.base import AimsInputGenerator
Expand Down
1 change: 1 addition & 0 deletions src/atomate2/aims/run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""An FHI-aims jobflow runner."""

from __future__ import annotations

import json
Expand Down
1 change: 1 addition & 0 deletions src/atomate2/aims/schemas/calculation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Schemas for FHI-aims calculation objects."""

from __future__ import annotations

import os
Expand Down
1 change: 1 addition & 0 deletions src/atomate2/aims/schemas/task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""A definition of a MSON document representing an FHI-aims task."""

from __future__ import annotations

import json
Expand Down
1 change: 0 additions & 1 deletion src/atomate2/aims/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""A collection of helper utils found in atomate2 package."""


from datetime import datetime


Expand Down
1 change: 1 addition & 0 deletions src/atomate2/aims/utils/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Copied from GIMS as of now; should be in its own dedicated FHI-aims python package.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, TypedDict
Expand Down
1 change: 1 addition & 0 deletions src/atomate2/aims/utils/units.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Define the Units for FHI-aims calculations."""

from numpy import pi

PI = pi
Expand Down
1 change: 1 addition & 0 deletions src/atomate2/common/flows/eos.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Define common EOS flow agnostic to electronic-structure code."""

from __future__ import annotations

import contextlib
Expand Down
6 changes: 3 additions & 3 deletions src/atomate2/common/jobs/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,9 @@ def eval(self) -> None:

self.results[jobtype]["EOS"] = {}
if ierr not in (1, 2, 3, 4):
self.results[jobtype]["EOS"][
"exception"
] = "Optimal EOS parameters not found."
self.results[jobtype]["EOS"]["exception"] = (
"Optimal EOS parameters not found."
)
else:
for i, key in enumerate(["b0", "b1", "v0"]):
self.results[jobtype]["EOS"][key] = eos_params[i]
Expand Down
1 change: 1 addition & 0 deletions src/atomate2/forcefields/flows/eos.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Flows to generate EOS fits using CHGNet, M3GNet, or MACE."""

from __future__ import annotations

from dataclasses import dataclass, field
Expand Down
20 changes: 11 additions & 9 deletions src/atomate2/forcefields/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from atomate2.forcefields import MLFF
from atomate2.forcefields.schemas import ForceFieldTaskDocument
from atomate2.forcefields.utils import Relaxer
from atomate2.forcefields.utils import Relaxer, revert_default_dtype

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -331,11 +331,12 @@ class MACERelaxMaker(ForceFieldRelaxMaker):
def _relax(self, structure: Structure) -> dict:
from mace.calculators import mace_mp

calculator = mace_mp(model=self.model, **self.model_kwargs)
relaxer = Relaxer(
calculator, relax_cell=self.relax_cell, **self.optimizer_kwargs
)
return relaxer.relax(structure, steps=self.steps, **self.relax_kwargs)
with revert_default_dtype():
calculator = mace_mp(model=self.model, **self.model_kwargs)
relaxer = Relaxer(
calculator, relax_cell=self.relax_cell, **self.optimizer_kwargs
)
return relaxer.relax(structure, steps=self.steps, **self.relax_kwargs)


@dataclass
Expand Down Expand Up @@ -370,9 +371,10 @@ class MACEStaticMaker(ForceFieldStaticMaker):
def _evaluate_static(self, structure: Structure) -> dict:
from mace.calculators import mace_mp

calculator = mace_mp(model=self.model, **self.model_kwargs)
relaxer = Relaxer(calculator, relax_cell=False)
return relaxer.relax(structure, steps=1)
with revert_default_dtype():
calculator = mace_mp(model=self.model, **self.model_kwargs)
relaxer = Relaxer(calculator, relax_cell=False)
return relaxer.relax(structure, steps=1)


@dataclass
Expand Down
18 changes: 18 additions & 0 deletions src/atomate2/forcefields/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pickle
import sys
import warnings
from contextlib import contextmanager
from typing import TYPE_CHECKING

from ase.optimize import BFGS, FIRE, LBFGS, BFGSLineSearch, LBFGSLineSearch, MDMin
Expand All @@ -36,6 +37,7 @@
)

if TYPE_CHECKING:
from collections.abc import Generator
from os import PathLike
from typing import Any

Expand Down Expand Up @@ -211,3 +213,19 @@ def relax(

struct = self.ase_adaptor.get_structure(atoms)
return {"final_structure": struct, "trajectory": obs}


@contextmanager
def revert_default_dtype() -> Generator[None, None, None]:
"""Context manager for torch.default_dtype.
Reverts it to whatever torch.get_default_dtype() was when entering the context.
Originally added for use with MACE(Relax|Static)Maker.
https://github.com/ACEsuit/mace/issues/328
"""
import torch

orig = torch.get_default_dtype()
yield
torch.set_default_dtype(orig)
6 changes: 3 additions & 3 deletions src/atomate2/vasp/flows/electrode.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,6 @@ def update_static_maker(self) -> None:
self.static_maker.task_document_kwargs.get("store_volumetric_data", [])
)
store_volumetric_data.extend(["aeccar0", "aeccar2"])
self.static_maker.task_document_kwargs[
"store_volumetric_data"
] = store_volumetric_data
self.static_maker.task_document_kwargs["store_volumetric_data"] = (
store_volumetric_data
)
1 change: 1 addition & 0 deletions src/atomate2/vasp/flows/phonons.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Define the VASP PhononMaker."""

from __future__ import annotations

from dataclasses import dataclass, field
Expand Down
1 change: 1 addition & 0 deletions src/atomate2/vasp/jobs/phonons.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Define the PhononDisplacementMaker for VASP."""

from dataclasses import dataclass, field

from atomate2.vasp.jobs.base import BaseVaspMaker
Expand Down
1 change: 1 addition & 0 deletions tests/aims/test_flows/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test core FHI-aims workflows"""

import os

import pytest
Expand Down
1 change: 1 addition & 0 deletions tests/aims/test_flows/test_gw_convergence.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""A test for GW workflows for FHI-aims."""

import pytest

# from atomate2.aims.utils.msonable_atoms import MSONableAtoms
Expand Down
1 change: 1 addition & 0 deletions tests/aims/test_flows/test_phonon_workflow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test various makers"""

import json
import os

Expand Down
3 changes: 1 addition & 2 deletions tests/aims/test_makers/test_convergence.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
""" A test for AIMS convergence maker (used for GW, for instance)
"""
"""A test for AIMS convergence maker (used for GW, for instance)"""

import os

Expand Down
1 change: 1 addition & 0 deletions tests/aims/test_makers/test_static.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test various makers"""

import os

import pytest
Expand Down
18 changes: 9 additions & 9 deletions tests/vasp/flows/test_eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,18 @@ def test_mp_eos_maker(
"EOS MP GGA relax 2": expected_incar_relax,
}

for i in range(2):
ref_paths[f"EOS MP GGA relax {1+i}"] = f"mp-149-PBE-EOS_MP_GGA_relax_{1+i}"
for idx in range(2):
ref_paths[f"EOS MP GGA relax {1+idx}"] = f"mp-149-PBE-EOS_MP_GGA_relax_{1+idx}"

for i in range(nframes):
ref_paths[
f"EOS MP GGA relax deformation {i}"
] = f"mp-149-PBE-EOS_Deformation_Relax_{i}"
expected_incars[f"EOS MP GGA relax deformation {i}"] = expected_incar_deform
for idx in range(nframes):
ref_paths[f"EOS MP GGA relax deformation {idx}"] = (
f"mp-149-PBE-EOS_Deformation_Relax_{idx}"
)
expected_incars[f"EOS MP GGA relax deformation {idx}"] = expected_incar_deform

if do_statics:
ref_paths[f"EOS MP GGA static {i}"] = f"mp-149-PBE-EOS_Static_{i}"
expected_incars[f"EOS MP GGA static {i}"] = expected_incar_static
ref_paths[f"EOS MP GGA static {idx}"] = f"mp-149-PBE-EOS_Static_{idx}"
expected_incars[f"EOS MP GGA static {idx}"] = expected_incar_static

if do_statics:
ref_paths["EOS equilibrium static"] = "mp-149-PBE-EOS_equilibrium_static"
Expand Down

0 comments on commit 0ce2086

Please sign in to comment.