Skip to content

Commit

Permalink
Add pip install git+https://gitlab.com/ase/ase user advice on Frec…
Browse files Browse the repository at this point in the history
…hetCellFilter` `ImportError` and allow `ase_filter` to be `str` (#104)

* use in-place arithmetic ops

* bump actions/setup-python to v5

* fix torch non-differentiable in-place op errors

* add ase git install advice and recommend to use FrechetCellFilter for CHGNet structural relaxation

* allow StructOptimizer ase_filter keyword to be str allowed values err msg on invalid name

* revert BondEncoder neighbor calc
  • Loading branch information
janosh authored Dec 11, 2023
1 parent ec953df commit 3165efe
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 41 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: pip
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: 3.9
cache: pip
Expand All @@ -36,7 +36,7 @@ jobs:
- name: Install dependencies
run: |
pip install cython
# install ase from main branch until FrechetCellFilter is release
# install ase from main branch until FrechetCellFilter is released
# TODO remove pip install git+https://gitlab.com/ase/ase
pip install git+https://gitlab.com/ase/ase
python setup.py build_ext --inplace
Expand All @@ -61,7 +61,7 @@ jobs:
- name: Check out repo
uses: actions/checkout@v4

- uses: actions/setup-python@v4
- uses: actions/setup-python@v5
name: Install Python
with:
python-version: "3.10"
Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/composition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _get_energy(self, composition_feas: Tensor) -> Tensor:
prediction associated with each composition [batchsize].
"""
composition_feas = self.activation(self.fc1(composition_feas))
composition_feas = composition_feas + self.gated_mlp(composition_feas)
composition_feas += self.gated_mlp(composition_feas)
return self.fc2(composition_feas).view(-1)

def forward(self, graphs: list[CrystalGraph]) -> Tensor:
Expand Down
34 changes: 30 additions & 4 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import inspect
import io
import pickle
import sys
Expand All @@ -10,7 +11,6 @@
import torch
from ase import Atoms, units
from ase.calculators.calculator import Calculator, all_changes, all_properties
from ase.filters import Filter, FrechetCellFilter
from ase.md.npt import NPT
from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen
from ase.md.nvtberendsen import NVTBerendsen
Expand All @@ -33,6 +33,18 @@
from ase.io import Trajectory
from ase.optimize.optimize import Optimizer

try:
from ase.filters import Filter, FrechetCellFilter
except ImportError:
print(
"We recommend using ase's unreleased FrechetCellFilter over ExpCellFilter for "
"CHGNet structural relaxation. ExpCellFilter has a bug in its calculation "
"of cell gradients which was fixed in FrechetCellFilter. Otherwise the two "
"are identical. ExpCellFilter was kept only for backwards compatibility and "
"should no longer be used. Run pip install git+https://gitlab.com/ase/ase to "
"install from main branch."
)

# We would like to thank M3GNet develop team for this module
# source: https://github.com/materialsvirtuallab/m3gnet

Expand Down Expand Up @@ -211,7 +223,7 @@ def relax(
fmax: float | None = 0.1,
steps: int | None = 500,
relax_cell: bool | None = True,
ase_filter: Filter = FrechetCellFilter,
ase_filter: str | Filter = FrechetCellFilter,
save_path: str | None = None,
loginterval: int | None = 1,
crystal_feas_save_path: str | None = None,
Expand All @@ -228,8 +240,8 @@ def relax(
Default = 500
relax_cell (bool | None): Whether to relax the cell as well.
Default = True
ase_filter (ase.filters.Filter): The filter to apply to the atoms object
for relaxation. Default = FrechetCellFilter
ase_filter (str | ase.filters.Filter): The filter to apply to the atoms
object for relaxation. Default = FrechetCellFilter
Used to default to ExpCellFilter but was removed due to bug reported in
https://gitlab.com/ase/ase/-/issues/1321 and fixed in
https://gitlab.com/ase/ase/-/merge_requests/3024.
Expand All @@ -248,6 +260,20 @@ def relax(
dict[str, Structure | TrajectoryObserver]:
A dictionary with 'final_structure' and 'trajectory'.
"""
if isinstance(ase_filter, str):
try:
import ase.filters

ase_filter = getattr(ase.filters, ase_filter)
except AttributeError as exc:
valid_filter_names = [
name
for name, cls in inspect.getmembers(ase.filters, inspect.isclass)
if issubclass(cls, Filter)
]
raise ValueError(
f"Invalid {ase_filter=}, must be one of {valid_filter_names}. "
) from exc
if isinstance(atoms, Structure):
atoms = atoms.to_ase_atoms()

Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def forward(

# smooth out message by bond_weights
bond_weight = torch.index_select(bond_weights, 0, directed2undirected)
messages = messages * bond_weight
messages *= bond_weight

# Aggregate messages
new_atom_feas = aggregate(
Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def _compute(

# Normalize energy if model is intensive
if self.is_intensive:
energy = energy / atoms_per_graph
energy /= atoms_per_graph
prediction["e"] = energy

return prediction
Expand Down
82 changes: 52 additions & 30 deletions examples/basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
" from chgnet.model import CHGNet\n",
"except ImportError:\n",
" # install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)\n",
" !pip install chgnet\n"
" !pip install chgnet"
]
},
{
Expand All @@ -37,7 +37,7 @@
"# If the above line fails in Google Colab due to numpy version issue,\n",
"# please restart the runtime, and the problem will be solved\n",
"\n",
"np.set_printoptions(precision=4, suppress=True)\n"
"np.set_printoptions(precision=4, suppress=True)"
]
},
{
Expand Down Expand Up @@ -89,7 +89,7 @@
" cif = urlopen(url).read().decode(\"utf-8\")\n",
" structure = Structure.from_str(cif, fmt=\"cif\")\n",
"\n",
"print(structure)\n"
"print(structure)"
]
},
{
Expand All @@ -110,15 +110,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CHGNet initialized with 400,438 parameters\n"
"CHGNet v0.3.0 initialized with 412,525 parameters\n"
]
}
],
"source": [
"chgnet = CHGNet.load()\n",
"\n",
"# Alternatively you can read your own model\n",
"# chgnet = CHGNet.from_file(model_path)\n"
"# chgnet = CHGNet.from_file(model_path)"
]
},
{
Expand Down Expand Up @@ -176,7 +176,7 @@
" (\"stress\", \"GPa\"),\n",
" (\"magmom\", \"mu_B\"),\n",
"]:\n",
" print(f\"CHGNet-predicted {key} ({unit}):\\n{prediction[key[0]]}\\n\")\n"
" print(f\"CHGNet-predicted {key} ({unit}):\\n{prediction[key[0]]}\\n\")"
]
},
{
Expand All @@ -197,15 +197,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CHGNet initialized with 400,438 parameters\n",
"CHGNet v0.3.0 initialized with 412,525 parameters\n",
"CHGNet will run on cpu\n"
]
}
],
"source": [
"from chgnet.model import StructOptimizer\n",
"\n",
"relaxer = StructOptimizer()\n"
"relaxer = StructOptimizer()"
]
},
{
Expand All @@ -219,31 +219,53 @@
"output_type": "stream",
"text": [
"\n",
"CHGNet took 29 steps. Relaxed structure:\n",
"FrechetCellFilter took 49 steps. Relaxed structure:\n",
"\n",
"Full Formula (Li2 Mn2 O4)\n",
"Reduced Formula: LiMnO2\n",
"abc : 2.876179 4.609830 5.862965\n",
"angles: 89.863012 89.706707 89.946402\n",
"pbc : True True True\n",
"Sites (8)\n",
" # SP a b c magmom\n",
"--- ---- -------- -------- -------- ----------\n",
" 0 Li+ 0.481659 0.517177 0.377268 0.00253999\n",
" 1 Li+ 0.012506 0.04361 0.611305 0.00241017\n",
" 2 Mn3+ 0.508999 0.531361 0.860983 3.88035\n",
" 3 Mn3+ 1.00639 0.032734 0.130496 3.87339\n",
" 4 O2- 0.506676 0.039479 0.353671 0.0450131\n",
" 5 O2- 0.009236 0.531155 0.094449 0.0397426\n",
" 6 O2- 0.503994 0.029776 0.896556 0.0407182\n",
" 7 O2- 0.009769 0.521232 0.636952 0.0473042\n",
"\n",
"ExpCellFilter took 83 steps. Relaxed structure:\n",
"\n",
"Full Formula (Li2 Mn2 O4)\n",
"Reduced Formula: LiMnO2\n",
"abc : 2.865864 4.648716 5.827764\n",
"angles: 89.917211 90.239405 89.975425\n",
"abc : 2.874395 4.611958 5.852410\n",
"angles: 89.943237 89.910969 89.994579\n",
"pbc : True True True\n",
"Sites (8)\n",
" # SP a b c magmom\n",
"--- ---- --------- --------- -------- ----------\n",
" 0 Li+ 0.494018 0.479737 0.387171 0.00498427\n",
" 1 Li+ 0.008464 0.006131 0.625817 0.00512926\n",
" 2 Mn3+ 0.50073 0.502478 0.869608 3.85374\n",
" 3 Mn3+ 0.997815 -0.000319 0.139344 3.859\n",
" 4 O2- 0.502142 0.009453 0.363411 0.0253105\n",
" 5 O2- 1.00293 0.502592 0.104559 0.0366638\n",
" 6 O2- 0.493749 0.998092 0.903592 0.0365367\n",
" 7 O2- -0.002278 0.495108 0.645655 0.0248522\n"
" # SP a b c magmom\n",
"--- ---- -------- -------- -------- ----------\n",
" 0 Li+ 0.474099 0.522936 0.375014 0.00291404\n",
" 1 Li+ 0.007464 0.033067 0.61184 0.00261617\n",
" 2 Mn3+ 0.512206 0.531325 0.861133 3.87057\n",
" 3 Mn3+ 1.00718 0.030874 0.130145 3.86706\n",
" 4 O2- 0.504485 0.035636 0.353984 0.0443497\n",
" 5 O2- 0.010645 0.532059 0.095251 0.0381828\n",
" 6 O2- 0.510762 0.030743 0.896625 0.0382355\n",
" 7 O2- 0.012389 0.529884 0.637688 0.0455911\n"
]
}
],
"source": [
"structure.perturb(0.1)\n",
"result = relaxer.relax(structure, verbose=False)\n",
"print(f\"\\nCHGNet took {len(result['trajectory'])} steps. Relaxed structure:\")\n",
"print(result[\"final_structure\"])\n"
"for ase_filter in (\"FrechetCellFilter\", \"ExpCellFilter\"):\n",
" result = relaxer.relax(structure, verbose=False, ase_filter=ase_filter)\n",
" n_steps = len(result[\"trajectory\"])\n",
" print(f\"\\n{ase_filter} took {n_steps} steps. Relaxed structure:\\n\")\n",
" print(result[\"final_structure\"])"
]
},
{
Expand Down Expand Up @@ -289,7 +311,7 @@
" logfile=\"md_out.log\",\n",
" loginterval=100,\n",
")\n",
"md.run(50) # run a 0.1 ps MD simulation\n"
"md.run(50) # run a 0.1 ps MD simulation"
]
},
{
Expand All @@ -316,7 +338,7 @@
],
"source": [
"supercell = structure.make_supercell([2, 2, 2], in_place=False)\n",
"print(supercell.composition)\n"
"print(supercell.composition)"
]
},
{
Expand All @@ -340,7 +362,7 @@
"remove_ids = random.sample(list(range(n_Li)), n_Li // 2)\n",
"\n",
"supercell.remove_sites(remove_ids)\n",
"print(supercell.composition)\n"
"print(supercell.composition)"
]
},
{
Expand Down Expand Up @@ -756,7 +778,7 @@
}
],
"source": [
"result = relaxer.relax(supercell)\n"
"result = relaxer.relax(supercell)"
]
},
{
Expand All @@ -769,7 +791,7 @@
"import pandas as pd\n",
"\n",
"df_magmom = pd.DataFrame({\"Unrelaxed\": chgnet.predict_structure(supercell)[\"m\"]})\n",
"df_magmom[\"CHGNet relaxed\"] = result[\"final_structure\"].site_properties[\"magmom\"]\n"
"df_magmom[\"CHGNet relaxed\"] = result[\"final_structure\"].site_properties[\"magmom\"]"
]
},
{
Expand Down Expand Up @@ -1820,7 +1842,7 @@
")\n",
"fig.layout.legend.update(title=\"\", x=1, y=1, xanchor=\"right\", yanchor=\"top\")\n",
"fig.layout.xaxis.title = \"Magnetic moment\"\n",
"fig.show()\n"
"fig.show()"
]
}
],
Expand Down

0 comments on commit 3165efe

Please sign in to comment.