From 0ce2086792c78d3294ba71fac309d1f51607c027 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 6 Mar 2024 14:00:50 +0100 Subject: [PATCH] Add `revert_default_dtype` context manager to fix clashing global `torch.dtype` between MACE and CHGNet (#766) * add revert_default_dtype context manager for restoring torch dtype * use revert_default_dtype in MACE(Relax|Static)Makers to prevent MACE from clashing with other torch force fields running in the same session --- src/atomate2/aims/files.py | 1 + src/atomate2/aims/flows/core.py | 1 + src/atomate2/aims/flows/gw.py | 1 + src/atomate2/aims/flows/phonons.py | 1 + src/atomate2/aims/jobs/core.py | 1 + src/atomate2/aims/jobs/phonons.py | 1 + src/atomate2/aims/run.py | 1 + src/atomate2/aims/schemas/calculation.py | 1 + src/atomate2/aims/schemas/task.py | 1 + src/atomate2/aims/utils/__init__.py | 1 - src/atomate2/aims/utils/bands.py | 1 + src/atomate2/aims/utils/units.py | 1 + src/atomate2/common/flows/eos.py | 1 + src/atomate2/common/jobs/eos.py | 6 +++--- src/atomate2/forcefields/flows/eos.py | 1 + src/atomate2/forcefields/jobs.py | 20 ++++++++++--------- src/atomate2/forcefields/utils.py | 18 +++++++++++++++++ src/atomate2/vasp/flows/electrode.py | 6 +++--- src/atomate2/vasp/flows/phonons.py | 1 + src/atomate2/vasp/jobs/phonons.py | 1 + tests/aims/test_flows/test_core.py | 1 + tests/aims/test_flows/test_gw_convergence.py | 1 + tests/aims/test_flows/test_phonon_workflow.py | 1 + tests/aims/test_makers/test_convergence.py | 3 +-- tests/aims/test_makers/test_static.py | 1 + tests/vasp/flows/test_eos.py | 18 ++++++++--------- 26 files changed, 64 insertions(+), 27 deletions(-) diff --git a/src/atomate2/aims/files.py b/src/atomate2/aims/files.py index 5e1a9d149f..fd330a004a 100644 --- a/src/atomate2/aims/files.py +++ b/src/atomate2/aims/files.py @@ -1,4 +1,5 @@ """Functions dealing with FHI-aims files.""" + from __future__ import annotations import logging diff --git a/src/atomate2/aims/flows/core.py b/src/atomate2/aims/flows/core.py index a47c0124a7..283417887c 100644 --- a/src/atomate2/aims/flows/core.py +++ b/src/atomate2/aims/flows/core.py @@ -1,4 +1,5 @@ """(Work)flows for FHI-aims.""" + from __future__ import annotations from copy import deepcopy diff --git a/src/atomate2/aims/flows/gw.py b/src/atomate2/aims/flows/gw.py index 75c26f3c48..d63d103cf5 100644 --- a/src/atomate2/aims/flows/gw.py +++ b/src/atomate2/aims/flows/gw.py @@ -1,4 +1,5 @@ """GW workflows for FHI-aims with automatic convergence.""" + from dataclasses import dataclass, field from atomate2.aims.jobs.convergence import ConvergenceMaker diff --git a/src/atomate2/aims/flows/phonons.py b/src/atomate2/aims/flows/phonons.py index 817ee9a842..d3bc4d532e 100644 --- a/src/atomate2/aims/flows/phonons.py +++ b/src/atomate2/aims/flows/phonons.py @@ -1,4 +1,5 @@ """Defines the phonon workflows for FHI-aims.""" + from __future__ import annotations from dataclasses import dataclass, field diff --git a/src/atomate2/aims/jobs/core.py b/src/atomate2/aims/jobs/core.py index a4f403b1d0..4bfca85377 100644 --- a/src/atomate2/aims/jobs/core.py +++ b/src/atomate2/aims/jobs/core.py @@ -1,4 +1,5 @@ """Define all Core FHI-aims jobs.""" + from __future__ import annotations import logging diff --git a/src/atomate2/aims/jobs/phonons.py b/src/atomate2/aims/jobs/phonons.py index 13b11dac81..e19faca568 100644 --- a/src/atomate2/aims/jobs/phonons.py +++ b/src/atomate2/aims/jobs/phonons.py @@ -1,4 +1,5 @@ """Define the PhononDisplacementMakers for FHI-aims.""" + from dataclasses import dataclass, field from pymatgen.io.aims.sets.base import AimsInputGenerator diff --git a/src/atomate2/aims/run.py b/src/atomate2/aims/run.py index 93b3218644..1d2294f55e 100644 --- a/src/atomate2/aims/run.py +++ b/src/atomate2/aims/run.py @@ -1,4 +1,5 @@ """An FHI-aims jobflow runner.""" + from __future__ import annotations import json diff --git a/src/atomate2/aims/schemas/calculation.py b/src/atomate2/aims/schemas/calculation.py index 6636314d58..8b99d24438 100644 --- a/src/atomate2/aims/schemas/calculation.py +++ b/src/atomate2/aims/schemas/calculation.py @@ -1,4 +1,5 @@ """Schemas for FHI-aims calculation objects.""" + from __future__ import annotations import os diff --git a/src/atomate2/aims/schemas/task.py b/src/atomate2/aims/schemas/task.py index 6628737453..6dc5e6f860 100644 --- a/src/atomate2/aims/schemas/task.py +++ b/src/atomate2/aims/schemas/task.py @@ -1,4 +1,5 @@ """A definition of a MSON document representing an FHI-aims task.""" + from __future__ import annotations import json diff --git a/src/atomate2/aims/utils/__init__.py b/src/atomate2/aims/utils/__init__.py index fd0d7dc2a3..c59cdb9718 100644 --- a/src/atomate2/aims/utils/__init__.py +++ b/src/atomate2/aims/utils/__init__.py @@ -1,6 +1,5 @@ """A collection of helper utils found in atomate2 package.""" - from datetime import datetime diff --git a/src/atomate2/aims/utils/bands.py b/src/atomate2/aims/utils/bands.py index 20b690d249..ead5651a26 100644 --- a/src/atomate2/aims/utils/bands.py +++ b/src/atomate2/aims/utils/bands.py @@ -2,6 +2,7 @@ Copied from GIMS as of now; should be in its own dedicated FHI-aims python package. """ + from __future__ import annotations from typing import TYPE_CHECKING, TypedDict diff --git a/src/atomate2/aims/utils/units.py b/src/atomate2/aims/utils/units.py index 72a8d81e8c..51ea034922 100644 --- a/src/atomate2/aims/utils/units.py +++ b/src/atomate2/aims/utils/units.py @@ -1,4 +1,5 @@ """Define the Units for FHI-aims calculations.""" + from numpy import pi PI = pi diff --git a/src/atomate2/common/flows/eos.py b/src/atomate2/common/flows/eos.py index b919f5b78a..6dc3d4531d 100644 --- a/src/atomate2/common/flows/eos.py +++ b/src/atomate2/common/flows/eos.py @@ -1,4 +1,5 @@ """Define common EOS flow agnostic to electronic-structure code.""" + from __future__ import annotations import contextlib diff --git a/src/atomate2/common/jobs/eos.py b/src/atomate2/common/jobs/eos.py index b5c0b5cbad..ffc4d0e54d 100644 --- a/src/atomate2/common/jobs/eos.py +++ b/src/atomate2/common/jobs/eos.py @@ -320,9 +320,9 @@ def eval(self) -> None: self.results[jobtype]["EOS"] = {} if ierr not in (1, 2, 3, 4): - self.results[jobtype]["EOS"][ - "exception" - ] = "Optimal EOS parameters not found." + self.results[jobtype]["EOS"]["exception"] = ( + "Optimal EOS parameters not found." + ) else: for i, key in enumerate(["b0", "b1", "v0"]): self.results[jobtype]["EOS"][key] = eos_params[i] diff --git a/src/atomate2/forcefields/flows/eos.py b/src/atomate2/forcefields/flows/eos.py index 9dc8756fc4..2208e34f82 100644 --- a/src/atomate2/forcefields/flows/eos.py +++ b/src/atomate2/forcefields/flows/eos.py @@ -1,4 +1,5 @@ """Flows to generate EOS fits using CHGNet, M3GNet, or MACE.""" + from __future__ import annotations from dataclasses import dataclass, field diff --git a/src/atomate2/forcefields/jobs.py b/src/atomate2/forcefields/jobs.py index e372c1d2eb..047694654c 100644 --- a/src/atomate2/forcefields/jobs.py +++ b/src/atomate2/forcefields/jobs.py @@ -10,7 +10,7 @@ from atomate2.forcefields import MLFF from atomate2.forcefields.schemas import ForceFieldTaskDocument -from atomate2.forcefields.utils import Relaxer +from atomate2.forcefields.utils import Relaxer, revert_default_dtype if TYPE_CHECKING: from collections.abc import Sequence @@ -331,11 +331,12 @@ class MACERelaxMaker(ForceFieldRelaxMaker): def _relax(self, structure: Structure) -> dict: from mace.calculators import mace_mp - calculator = mace_mp(model=self.model, **self.model_kwargs) - relaxer = Relaxer( - calculator, relax_cell=self.relax_cell, **self.optimizer_kwargs - ) - return relaxer.relax(structure, steps=self.steps, **self.relax_kwargs) + with revert_default_dtype(): + calculator = mace_mp(model=self.model, **self.model_kwargs) + relaxer = Relaxer( + calculator, relax_cell=self.relax_cell, **self.optimizer_kwargs + ) + return relaxer.relax(structure, steps=self.steps, **self.relax_kwargs) @dataclass @@ -370,9 +371,10 @@ class MACEStaticMaker(ForceFieldStaticMaker): def _evaluate_static(self, structure: Structure) -> dict: from mace.calculators import mace_mp - calculator = mace_mp(model=self.model, **self.model_kwargs) - relaxer = Relaxer(calculator, relax_cell=False) - return relaxer.relax(structure, steps=1) + with revert_default_dtype(): + calculator = mace_mp(model=self.model, **self.model_kwargs) + relaxer = Relaxer(calculator, relax_cell=False) + return relaxer.relax(structure, steps=1) @dataclass diff --git a/src/atomate2/forcefields/utils.py b/src/atomate2/forcefields/utils.py index e0875e0d80..c02fef56a7 100644 --- a/src/atomate2/forcefields/utils.py +++ b/src/atomate2/forcefields/utils.py @@ -14,6 +14,7 @@ import pickle import sys import warnings +from contextlib import contextmanager from typing import TYPE_CHECKING from ase.optimize import BFGS, FIRE, LBFGS, BFGSLineSearch, LBFGSLineSearch, MDMin @@ -36,6 +37,7 @@ ) if TYPE_CHECKING: + from collections.abc import Generator from os import PathLike from typing import Any @@ -211,3 +213,19 @@ def relax( struct = self.ase_adaptor.get_structure(atoms) return {"final_structure": struct, "trajectory": obs} + + +@contextmanager +def revert_default_dtype() -> Generator[None, None, None]: + """Context manager for torch.default_dtype. + + Reverts it to whatever torch.get_default_dtype() was when entering the context. + + Originally added for use with MACE(Relax|Static)Maker. + https://github.com/ACEsuit/mace/issues/328 + """ + import torch + + orig = torch.get_default_dtype() + yield + torch.set_default_dtype(orig) diff --git a/src/atomate2/vasp/flows/electrode.py b/src/atomate2/vasp/flows/electrode.py index 85a40ef19d..b416a02722 100644 --- a/src/atomate2/vasp/flows/electrode.py +++ b/src/atomate2/vasp/flows/electrode.py @@ -78,6 +78,6 @@ def update_static_maker(self) -> None: self.static_maker.task_document_kwargs.get("store_volumetric_data", []) ) store_volumetric_data.extend(["aeccar0", "aeccar2"]) - self.static_maker.task_document_kwargs[ - "store_volumetric_data" - ] = store_volumetric_data + self.static_maker.task_document_kwargs["store_volumetric_data"] = ( + store_volumetric_data + ) diff --git a/src/atomate2/vasp/flows/phonons.py b/src/atomate2/vasp/flows/phonons.py index 33a2a3e4e3..e07748db61 100644 --- a/src/atomate2/vasp/flows/phonons.py +++ b/src/atomate2/vasp/flows/phonons.py @@ -1,4 +1,5 @@ """Define the VASP PhononMaker.""" + from __future__ import annotations from dataclasses import dataclass, field diff --git a/src/atomate2/vasp/jobs/phonons.py b/src/atomate2/vasp/jobs/phonons.py index b4f8938966..8b8a169560 100644 --- a/src/atomate2/vasp/jobs/phonons.py +++ b/src/atomate2/vasp/jobs/phonons.py @@ -1,4 +1,5 @@ """Define the PhononDisplacementMaker for VASP.""" + from dataclasses import dataclass, field from atomate2.vasp.jobs.base import BaseVaspMaker diff --git a/tests/aims/test_flows/test_core.py b/tests/aims/test_flows/test_core.py index 6a67d792ed..4c467a39dc 100644 --- a/tests/aims/test_flows/test_core.py +++ b/tests/aims/test_flows/test_core.py @@ -1,4 +1,5 @@ """Test core FHI-aims workflows""" + import os import pytest diff --git a/tests/aims/test_flows/test_gw_convergence.py b/tests/aims/test_flows/test_gw_convergence.py index f5f6fa4a0e..d3c5c0fb2e 100644 --- a/tests/aims/test_flows/test_gw_convergence.py +++ b/tests/aims/test_flows/test_gw_convergence.py @@ -1,4 +1,5 @@ """A test for GW workflows for FHI-aims.""" + import pytest # from atomate2.aims.utils.msonable_atoms import MSONableAtoms diff --git a/tests/aims/test_flows/test_phonon_workflow.py b/tests/aims/test_flows/test_phonon_workflow.py index 9aad38cc22..1270905fb9 100644 --- a/tests/aims/test_flows/test_phonon_workflow.py +++ b/tests/aims/test_flows/test_phonon_workflow.py @@ -1,4 +1,5 @@ """Test various makers""" + import json import os diff --git a/tests/aims/test_makers/test_convergence.py b/tests/aims/test_makers/test_convergence.py index 451ace7d48..be767eac49 100644 --- a/tests/aims/test_makers/test_convergence.py +++ b/tests/aims/test_makers/test_convergence.py @@ -1,5 +1,4 @@ -""" A test for AIMS convergence maker (used for GW, for instance) -""" +"""A test for AIMS convergence maker (used for GW, for instance)""" import os diff --git a/tests/aims/test_makers/test_static.py b/tests/aims/test_makers/test_static.py index fdaf4ee680..24a3c02ff1 100644 --- a/tests/aims/test_makers/test_static.py +++ b/tests/aims/test_makers/test_static.py @@ -1,4 +1,5 @@ """Test various makers""" + import os import pytest diff --git a/tests/vasp/flows/test_eos.py b/tests/vasp/flows/test_eos.py index af0bd4463c..1dfe500c0b 100644 --- a/tests/vasp/flows/test_eos.py +++ b/tests/vasp/flows/test_eos.py @@ -85,18 +85,18 @@ def test_mp_eos_maker( "EOS MP GGA relax 2": expected_incar_relax, } - for i in range(2): - ref_paths[f"EOS MP GGA relax {1+i}"] = f"mp-149-PBE-EOS_MP_GGA_relax_{1+i}" + for idx in range(2): + ref_paths[f"EOS MP GGA relax {1+idx}"] = f"mp-149-PBE-EOS_MP_GGA_relax_{1+idx}" - for i in range(nframes): - ref_paths[ - f"EOS MP GGA relax deformation {i}" - ] = f"mp-149-PBE-EOS_Deformation_Relax_{i}" - expected_incars[f"EOS MP GGA relax deformation {i}"] = expected_incar_deform + for idx in range(nframes): + ref_paths[f"EOS MP GGA relax deformation {idx}"] = ( + f"mp-149-PBE-EOS_Deformation_Relax_{idx}" + ) + expected_incars[f"EOS MP GGA relax deformation {idx}"] = expected_incar_deform if do_statics: - ref_paths[f"EOS MP GGA static {i}"] = f"mp-149-PBE-EOS_Static_{i}" - expected_incars[f"EOS MP GGA static {i}"] = expected_incar_static + ref_paths[f"EOS MP GGA static {idx}"] = f"mp-149-PBE-EOS_Static_{idx}" + expected_incars[f"EOS MP GGA static {idx}"] = expected_incar_static if do_statics: ref_paths["EOS equilibrium static"] = "mp-149-PBE-EOS_equilibrium_static"