Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pip install git+https://gitlab.com/ase/ase user advice on FrechetCellFilter ImportError and allow ase_filter to be str #104

Merged
merged 6 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: pip
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: 3.9
cache: pip
Expand All @@ -36,7 +36,7 @@ jobs:
- name: Install dependencies
run: |
pip install cython
# install ase from main branch until FrechetCellFilter is release
# install ase from main branch until FrechetCellFilter is released
# TODO remove pip install git+https://gitlab.com/ase/ase
pip install git+https://gitlab.com/ase/ase
python setup.py build_ext --inplace
Expand All @@ -61,7 +61,7 @@ jobs:
- name: Check out repo
uses: actions/checkout@v4

- uses: actions/setup-python@v4
- uses: actions/setup-python@v5
name: Install Python
with:
python-version: "3.10"
Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/composition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _get_energy(self, composition_feas: Tensor) -> Tensor:
prediction associated with each composition [batchsize].
"""
composition_feas = self.activation(self.fc1(composition_feas))
composition_feas = composition_feas + self.gated_mlp(composition_feas)
composition_feas += self.gated_mlp(composition_feas)
return self.fc2(composition_feas).view(-1)

def forward(self, graphs: list[CrystalGraph]) -> Tensor:
Expand Down
34 changes: 30 additions & 4 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import inspect
import io
import pickle
import sys
Expand All @@ -10,7 +11,6 @@
import torch
from ase import Atoms, units
from ase.calculators.calculator import Calculator, all_changes, all_properties
from ase.filters import Filter, FrechetCellFilter
from ase.md.npt import NPT
from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen
from ase.md.nvtberendsen import NVTBerendsen
Expand All @@ -33,6 +33,18 @@
from ase.io import Trajectory
from ase.optimize.optimize import Optimizer

try:
from ase.filters import Filter, FrechetCellFilter
except ImportError:
print(
"We recommend using ase's unreleased FrechetCellFilter over ExpCellFilter for "
"CHGNet structural relaxation. ExpCellFilter has a bug in its calculation "
"of cell gradients which was fixed in FrechetCellFilter. Otherwise the two "
"are identical. ExpCellFilter was kept only for backwards compatibility and "
"should no longer be used. Run pip install git+https://gitlab.com/ase/ase to "
"install from main branch."
)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BowenD-UCB We now print this advice when ase.filters.FrechetCellFilter can't be imported.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for this!

# We would like to thank M3GNet develop team for this module
# source: https://github.com/materialsvirtuallab/m3gnet

Expand Down Expand Up @@ -211,7 +223,7 @@ def relax(
fmax: float | None = 0.1,
steps: int | None = 500,
relax_cell: bool | None = True,
ase_filter: Filter = FrechetCellFilter,
ase_filter: str | Filter = FrechetCellFilter,
save_path: str | None = None,
loginterval: int | None = 1,
crystal_feas_save_path: str | None = None,
Expand All @@ -228,8 +240,8 @@ def relax(
Default = 500
relax_cell (bool | None): Whether to relax the cell as well.
Default = True
ase_filter (ase.filters.Filter): The filter to apply to the atoms object
for relaxation. Default = FrechetCellFilter
ase_filter (str | ase.filters.Filter): The filter to apply to the atoms
object for relaxation. Default = FrechetCellFilter
Used to default to ExpCellFilter but was removed due to bug reported in
https://gitlab.com/ase/ase/-/issues/1321 and fixed in
https://gitlab.com/ase/ase/-/merge_requests/3024.
Expand All @@ -248,6 +260,20 @@ def relax(
dict[str, Structure | TrajectoryObserver]:
A dictionary with 'final_structure' and 'trajectory'.
"""
if isinstance(ase_filter, str):
try:
import ase.filters

ase_filter = getattr(ase.filters, ase_filter)
except AttributeError as exc:
valid_filter_names = [
name
for name, cls in inspect.getmembers(ase.filters, inspect.isclass)
if issubclass(cls, Filter)
]
raise ValueError(
f"Invalid {ase_filter=}, must be one of {valid_filter_names}. "
) from exc
if isinstance(atoms, Structure):
atoms = atoms.to_ase_atoms()

Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def forward(
bond_vectors (Tensor): normalized bond vectors, for tracking the bond
directions [n_bond, 3]
"""
neighbor = neighbor + image @ lattice
neighbor += image @ lattice
bond_vectors = center - neighbor
bond_lengths = torch.norm(bond_vectors, dim=1)
# Normalize the bond vectors
Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def forward(

# smooth out message by bond_weights
bond_weight = torch.index_select(bond_weights, 0, directed2undirected)
messages = messages * bond_weight
messages *= bond_weight

# Aggregate messages
new_atom_feas = aggregate(
Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def _compute(

# Normalize energy if model is intensive
if self.is_intensive:
energy = energy / atoms_per_graph
energy /= atoms_per_graph
prediction["e"] = energy

return prediction
Expand Down
82 changes: 52 additions & 30 deletions examples/basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
" from chgnet.model import CHGNet\n",
"except ImportError:\n",
" # install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)\n",
" !pip install chgnet\n"
" !pip install chgnet"
]
},
{
Expand All @@ -37,7 +37,7 @@
"# If the above line fails in Google Colab due to numpy version issue,\n",
"# please restart the runtime, and the problem will be solved\n",
"\n",
"np.set_printoptions(precision=4, suppress=True)\n"
"np.set_printoptions(precision=4, suppress=True)"
]
},
{
Expand Down Expand Up @@ -89,7 +89,7 @@
" cif = urlopen(url).read().decode(\"utf-8\")\n",
" structure = Structure.from_str(cif, fmt=\"cif\")\n",
"\n",
"print(structure)\n"
"print(structure)"
]
},
{
Expand All @@ -110,15 +110,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CHGNet initialized with 400,438 parameters\n"
"CHGNet v0.3.0 initialized with 412,525 parameters\n"
]
}
],
"source": [
"chgnet = CHGNet.load()\n",
"\n",
"# Alternatively you can read your own model\n",
"# chgnet = CHGNet.from_file(model_path)\n"
"# chgnet = CHGNet.from_file(model_path)"
]
},
{
Expand Down Expand Up @@ -176,7 +176,7 @@
" (\"stress\", \"GPa\"),\n",
" (\"magmom\", \"mu_B\"),\n",
"]:\n",
" print(f\"CHGNet-predicted {key} ({unit}):\\n{prediction[key[0]]}\\n\")\n"
" print(f\"CHGNet-predicted {key} ({unit}):\\n{prediction[key[0]]}\\n\")"
]
},
{
Expand All @@ -197,15 +197,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CHGNet initialized with 400,438 parameters\n",
"CHGNet v0.3.0 initialized with 412,525 parameters\n",
"CHGNet will run on cpu\n"
]
}
],
"source": [
"from chgnet.model import StructOptimizer\n",
"\n",
"relaxer = StructOptimizer()\n"
"relaxer = StructOptimizer()"
]
},
{
Expand All @@ -219,31 +219,53 @@
"output_type": "stream",
"text": [
"\n",
"CHGNet took 29 steps. Relaxed structure:\n",
"FrechetCellFilter took 49 steps. Relaxed structure:\n",
"\n",
"Full Formula (Li2 Mn2 O4)\n",
"Reduced Formula: LiMnO2\n",
"abc : 2.876179 4.609830 5.862965\n",
"angles: 89.863012 89.706707 89.946402\n",
"pbc : True True True\n",
"Sites (8)\n",
" # SP a b c magmom\n",
"--- ---- -------- -------- -------- ----------\n",
" 0 Li+ 0.481659 0.517177 0.377268 0.00253999\n",
" 1 Li+ 0.012506 0.04361 0.611305 0.00241017\n",
" 2 Mn3+ 0.508999 0.531361 0.860983 3.88035\n",
" 3 Mn3+ 1.00639 0.032734 0.130496 3.87339\n",
" 4 O2- 0.506676 0.039479 0.353671 0.0450131\n",
" 5 O2- 0.009236 0.531155 0.094449 0.0397426\n",
" 6 O2- 0.503994 0.029776 0.896556 0.0407182\n",
" 7 O2- 0.009769 0.521232 0.636952 0.0473042\n",
"\n",
"ExpCellFilter took 83 steps. Relaxed structure:\n",
"\n",
"Full Formula (Li2 Mn2 O4)\n",
"Reduced Formula: LiMnO2\n",
"abc : 2.865864 4.648716 5.827764\n",
"angles: 89.917211 90.239405 89.975425\n",
"abc : 2.874395 4.611958 5.852410\n",
"angles: 89.943237 89.910969 89.994579\n",
"pbc : True True True\n",
"Sites (8)\n",
" # SP a b c magmom\n",
"--- ---- --------- --------- -------- ----------\n",
" 0 Li+ 0.494018 0.479737 0.387171 0.00498427\n",
" 1 Li+ 0.008464 0.006131 0.625817 0.00512926\n",
" 2 Mn3+ 0.50073 0.502478 0.869608 3.85374\n",
" 3 Mn3+ 0.997815 -0.000319 0.139344 3.859\n",
" 4 O2- 0.502142 0.009453 0.363411 0.0253105\n",
" 5 O2- 1.00293 0.502592 0.104559 0.0366638\n",
" 6 O2- 0.493749 0.998092 0.903592 0.0365367\n",
" 7 O2- -0.002278 0.495108 0.645655 0.0248522\n"
" # SP a b c magmom\n",
"--- ---- -------- -------- -------- ----------\n",
" 0 Li+ 0.474099 0.522936 0.375014 0.00291404\n",
" 1 Li+ 0.007464 0.033067 0.61184 0.00261617\n",
" 2 Mn3+ 0.512206 0.531325 0.861133 3.87057\n",
" 3 Mn3+ 1.00718 0.030874 0.130145 3.86706\n",
" 4 O2- 0.504485 0.035636 0.353984 0.0443497\n",
" 5 O2- 0.010645 0.532059 0.095251 0.0381828\n",
" 6 O2- 0.510762 0.030743 0.896625 0.0382355\n",
" 7 O2- 0.012389 0.529884 0.637688 0.0455911\n"
]
}
],
"source": [
"structure.perturb(0.1)\n",
"result = relaxer.relax(structure, verbose=False)\n",
"print(f\"\\nCHGNet took {len(result['trajectory'])} steps. Relaxed structure:\")\n",
"print(result[\"final_structure\"])\n"
"for ase_filter in (\"FrechetCellFilter\", \"ExpCellFilter\"):\n",
" result = relaxer.relax(structure, verbose=False, ase_filter=ase_filter)\n",
" n_steps = len(result[\"trajectory\"])\n",
" print(f\"\\n{ase_filter} took {n_steps} steps. Relaxed structure:\\n\")\n",
" print(result[\"final_structure\"])"
]
},
{
Expand Down Expand Up @@ -289,7 +311,7 @@
" logfile=\"md_out.log\",\n",
" loginterval=100,\n",
")\n",
"md.run(50) # run a 0.1 ps MD simulation\n"
"md.run(50) # run a 0.1 ps MD simulation"
]
},
{
Expand All @@ -316,7 +338,7 @@
],
"source": [
"supercell = structure.make_supercell([2, 2, 2], in_place=False)\n",
"print(supercell.composition)\n"
"print(supercell.composition)"
]
},
{
Expand All @@ -340,7 +362,7 @@
"remove_ids = random.sample(list(range(n_Li)), n_Li // 2)\n",
"\n",
"supercell.remove_sites(remove_ids)\n",
"print(supercell.composition)\n"
"print(supercell.composition)"
]
},
{
Expand Down Expand Up @@ -756,7 +778,7 @@
}
],
"source": [
"result = relaxer.relax(supercell)\n"
"result = relaxer.relax(supercell)"
]
},
{
Expand All @@ -769,7 +791,7 @@
"import pandas as pd\n",
"\n",
"df_magmom = pd.DataFrame({\"Unrelaxed\": chgnet.predict_structure(supercell)[\"m\"]})\n",
"df_magmom[\"CHGNet relaxed\"] = result[\"final_structure\"].site_properties[\"magmom\"]\n"
"df_magmom[\"CHGNet relaxed\"] = result[\"final_structure\"].site_properties[\"magmom\"]"
]
},
{
Expand Down Expand Up @@ -1820,7 +1842,7 @@
")\n",
"fig.layout.legend.update(title=\"\", x=1, y=1, xanchor=\"right\", yanchor=\"top\")\n",
"fig.layout.xaxis.title = \"Magnetic moment\"\n",
"fig.show()\n"
"fig.show()"
]
}
],
Expand Down
Loading