Skip to content

Commit

Permalink
Add possibility to use your own M3GNet potential (#911)
Browse files Browse the repository at this point in the history
* allow the possibility to use your own M3GNet potential

allow the possibility to use your own M3GNet potential, instead of the pretrained model only.

* added a unit test

* added a unit test

* test_dir not needed

* change kwargs passing
  • Loading branch information
QuantumChemist authored Jul 10, 2024
1 parent def6a4e commit d30890f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/atomate2/forcefields/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,8 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N
import matgl
from matgl.ext.ase import PESCalculator

potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
path = kwargs.get("path", "M3GNet-MP-2021.2.8-PES")
potential = matgl.load_model(path)
calculator = PESCalculator(potential, **kwargs)

elif calculator_name == MLFF.MACE:
Expand Down
22 changes: 22 additions & 0 deletions tests/forcefields/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,25 @@ def test_fix_symmetry(fix_symmetry):
assert symmetry_init["number"] == symmetry_final["number"] == 229
else:
assert symmetry_init["number"] != symmetry_final["number"] == 99


def test_m3gnet_pot():
import matgl
from matgl.ext.ase import PESCalculator

kwargs_calc = {"path": "M3GNet-MP-2021.2.8-DIRECT-PES", "stress_weight": 2.0}
kwargs_default = {"stress_weight": 2.0}

m3gnet_calculator = ase_calculator(calculator_meta="MLFF.M3GNet", **kwargs_calc)

# uses "M3GNet-MP-2021.2.8-PES" per default
m3gnet_default = ase_calculator(calculator_meta="MLFF.M3GNet", **kwargs_default)

potential = matgl.load_model("M3GNet-MP-2021.2.8-DIRECT-PES")
m3gnet_pes_calc = PESCalculator(potential=potential, stress_weight=2.0)

assert str(m3gnet_pes_calc.potential) == str(m3gnet_calculator.potential)
# casting necessary because <class 'matgl.apps.pes.Potential'> can't be compared
assert str(m3gnet_pes_calc.potential) != str(m3gnet_default.potential)
assert m3gnet_pes_calc.stress_weight == m3gnet_calculator.stress_weight
assert m3gnet_pes_calc.stress_weight == m3gnet_default.stress_weight

0 comments on commit d30890f

Please sign in to comment.