Skip to content

Commit

Permalink
Add missing calculator_kwargs and remove outdated model/`model_kw…
Browse files Browse the repository at this point in the history
…args` in `ForceFieldRelaxMaker` doc strings (#830)

* refactor

* document calculator_kwargs in all ForceFieldRelaxMaker subclasses

* document calculator_kwargs in Mace Makers and remove no-longer-existent model and model_kwargs from doc string

* improve test_ext_load
  • Loading branch information
janosh authored May 3, 2024
1 parent 3486792 commit 7e22064
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 36 deletions.
46 changes: 30 additions & 16 deletions src/atomate2/forcefields/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ class ForceFieldStaticMaker(ForceFieldRelaxMaker):
The job name.
force_field_name : str
The name of the force field.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down Expand Up @@ -209,6 +211,8 @@ class CHGNetRelaxMaker(ForceFieldRelaxMaker):
Keyword arguments that will get passed to :obj:`Relaxer.relax`.
optimizer_kwargs : dict
Keyword arguments that will get passed to :obj:`Relaxer()`.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down Expand Up @@ -236,6 +240,8 @@ class CHGNetStaticMaker(ForceFieldStaticMaker):
----------
name : str
The job name.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down Expand Up @@ -272,6 +278,8 @@ class M3GNetRelaxMaker(ForceFieldRelaxMaker):
Keyword arguments that will get passed to :obj:`Relaxer.relax`.
optimizer_kwargs : dict
Keyword arguments that will get passed to :obj:`Relaxer()`.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down Expand Up @@ -314,6 +322,8 @@ class NequipRelaxMaker(ForceFieldRelaxMaker):
Keyword arguments that will get passed to :obj:`Relaxer.relax`.
optimizer_kwargs : dict
Keyword arguments that will get passed to :obj:`Relaxer()`.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand All @@ -340,6 +350,8 @@ class NequipStaticMaker(ForceFieldStaticMaker):
The job name.
force_field_name : str
The name of the force field.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand All @@ -360,6 +372,8 @@ class M3GNetStaticMaker(ForceFieldStaticMaker):
The job name.
force_field_name : str
The name of the force field.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down Expand Up @@ -396,16 +410,14 @@ class MACERelaxMaker(ForceFieldRelaxMaker):
Keyword arguments that will get passed to :obj:`Relaxer.relax`.
optimizer_kwargs : dict
Keyword arguments that will get passed to :obj:`Relaxer()`.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator. E.g. the "model"
key configures which checkpoint to load with mace.calculators.MACECalculator().
Can be a URL starting with https://. If not set, loads the universal MACE-MP
trained for Matbench Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
model: str | Path | None
Checkpoint to load with :obj:`mace.calculators.MACECalculator()'`. Can be a URL
starting with https://. If None, loads the universal MACE trained for Matbench
Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
model_kwargs: dict[str, Any]
Further keywords (e.g. device, default_dtype, model) for
:obj:`mace.calculators.MACECalculator()'`.
"""

name: str = f"{MLFF.MACE} relax"
Expand All @@ -430,16 +442,14 @@ class MACEStaticMaker(ForceFieldStaticMaker):
The job name.
force_field_name : str
The name of the force field.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator. E.g. the "model"
key configures which checkpoint to load with mace.calculators.MACECalculator().
Can be a URL starting with https://. If not set, loads the universal MACE-MP
trained for Matbench Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
model: str | Path | None
Checkpoint to load with :obj:`mace.calculators.MACECalculator()'`. Can be a URL
starting with https://. If None, loads the universal MACE trained for Matbench
Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
model_kwargs: dict[str, Any]
Further keywords (e.g. device, default_dtype, model) for
:obj:`mace.calculators.MACECalculator()'`.
"""

name: str = f"{MLFF.MACE} static"
Expand Down Expand Up @@ -471,6 +481,8 @@ class GAPRelaxMaker(ForceFieldRelaxMaker):
Keyword arguments that will get passed to :obj:`Relaxer.relax`.
optimizer_kwargs : dict
Keyword arguments that will get passed to :obj:`Relaxer()`.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down Expand Up @@ -503,6 +515,8 @@ class GAPStaticMaker(ForceFieldStaticMaker):
The job name.
force_field_name : str
The name of the force field.
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
"""
Expand Down
4 changes: 1 addition & 3 deletions src/atomate2/forcefields/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,7 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N
"""
calculator = None

if isinstance(calculator_meta, str) and calculator_meta in [
f"{name}" for name in MLFF
]:
if isinstance(calculator_meta, str) and calculator_meta in map(str, MLFF):
calculator_name = MLFF(calculator_meta.split("MLFF.")[-1])

if calculator_name == MLFF.CHGNet:
Expand Down
10 changes: 5 additions & 5 deletions src/atomate2/vasp/flows/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def make(
"""
md_job = None
md_jobs = []
for idx, maker in enumerate(self.md_makers, 1):
if md_job is None:
md_structure = structure
md_prev_dir = prev_dir
else:
md_structure = structure
md_prev_dir = prev_dir

for idx, maker in enumerate(self.md_makers, start=1):
if md_job is not None:
md_structure = md_job.output.structure
md_prev_dir = md_job.output.dir_name
md_job = maker.make(md_structure, prev_dir=md_prev_dir)
Expand Down
10 changes: 5 additions & 5 deletions src/atomate2/vasp/jobs/lobster.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,14 @@ def get_basis_infos(
address_basis_file_min=address_min_basis,
)

nband_list = []
n_band_list: list[int] = []
for dict_for_basis in list_basis_dict:
basis = [f"{key} {value}" for key, value in dict_for_basis.items()]
lobsterin = Lobsterin(settingsdict={"basisfunctions": basis})
nbands = lobsterin._get_nbands(structure=structure)
nband_list.append(nbands)
n_bands = lobsterin._get_nbands(structure=structure)
n_band_list.append(n_bands)

return {"nbands": max(nband_list), "basis_dict": list_basis_dict}
return {"nbands": max(n_band_list), "basis_dict": list_basis_dict}


@job
Expand All @@ -143,7 +143,7 @@ def update_user_incar_settings_maker(
Parameters
----------
vasp_maker : .BaseVaspMaker
A maker for the static run with all parammeters
A maker for the static run with all parameters
relevant for Lobster.
nbands : int
integer indicating the correct number of bands
Expand Down
16 changes: 9 additions & 7 deletions tests/forcefields/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,17 @@ def test_relaxer(si_structure, test_dir, tmp_dir, optimizer, traj_file):
assert os.path.isfile(traj_file)


def test_ext_load():
force_field_to_callable = {
@pytest.mark.parametrize(("force_field"), ["CHGNet", "MACE"])
def test_ext_load(force_field: str):
decode_dict = {
"CHGNet": {"@module": "chgnet.model.dynamics", "@callable": "CHGNetCalculator"},
"MACE": {"@module": "mace.calculators", "@callable": "mace_mp"},
}
for force_field in ("CHGNet", "MACE"):
calc_from_decode = ase_calculator(force_field_to_callable[force_field])
calc_from_preset = ase_calculator(f"{MLFF(force_field)}")
assert isinstance(calc_from_decode, type(calc_from_preset))
}[force_field]
calc_from_decode = ase_calculator(decode_dict)
calc_from_preset = ase_calculator(str(MLFF(force_field)))
assert type(calc_from_decode) == type(calc_from_preset)
assert calc_from_decode.name == calc_from_preset.name
assert calc_from_decode.parameters == calc_from_preset.parameters == {}


@pytest.mark.parametrize(("fix_symmetry"), [True, False])
Expand Down

0 comments on commit 7e22064

Please sign in to comment.