diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d8802a8a..f6df8f1d 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -16,7 +16,7 @@ jobs: - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.11" cache: pip diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index acaf8436..3f0c747f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,7 +27,7 @@ jobs: uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.9 cache: pip @@ -36,7 +36,7 @@ jobs: - name: Install dependencies run: | pip install cython - # install ase from main branch until FrechetCellFilter is release + # install ase from main branch until FrechetCellFilter is released # TODO remove pip install git+https://gitlab.com/ase/ase pip install git+https://gitlab.com/ase/ase python setup.py build_ext --inplace @@ -61,7 +61,7 @@ jobs: - name: Check out repo uses: actions/checkout@v4 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 name: Install Python with: python-version: "3.10" diff --git a/chgnet/model/composition_model.py b/chgnet/model/composition_model.py index 7f2c2825..d9d6f544 100644 --- a/chgnet/model/composition_model.py +++ b/chgnet/model/composition_model.py @@ -53,7 +53,7 @@ def _get_energy(self, composition_feas: Tensor) -> Tensor: prediction associated with each composition [batchsize]. """ composition_feas = self.activation(self.fc1(composition_feas)) - composition_feas = composition_feas + self.gated_mlp(composition_feas) + composition_feas += self.gated_mlp(composition_feas) return self.fc2(composition_feas).view(-1) def forward(self, graphs: list[CrystalGraph]) -> Tensor: diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index 95ea7c20..0e6f1fcf 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import inspect import io import pickle import sys @@ -10,7 +11,6 @@ import torch from ase import Atoms, units from ase.calculators.calculator import Calculator, all_changes, all_properties -from ase.filters import Filter, FrechetCellFilter from ase.md.npt import NPT from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen from ase.md.nvtberendsen import NVTBerendsen @@ -33,6 +33,18 @@ from ase.io import Trajectory from ase.optimize.optimize import Optimizer +try: + from ase.filters import Filter, FrechetCellFilter +except ImportError: + print( + "We recommend using ase's unreleased FrechetCellFilter over ExpCellFilter for " + "CHGNet structural relaxation. ExpCellFilter has a bug in its calculation " + "of cell gradients which was fixed in FrechetCellFilter. Otherwise the two " + "are identical. ExpCellFilter was kept only for backwards compatibility and " + "should no longer be used. Run pip install git+https://gitlab.com/ase/ase to " + "install from main branch." + ) + # We would like to thank M3GNet develop team for this module # source: https://github.com/materialsvirtuallab/m3gnet @@ -211,7 +223,7 @@ def relax( fmax: float | None = 0.1, steps: int | None = 500, relax_cell: bool | None = True, - ase_filter: Filter = FrechetCellFilter, + ase_filter: str | Filter = FrechetCellFilter, save_path: str | None = None, loginterval: int | None = 1, crystal_feas_save_path: str | None = None, @@ -228,8 +240,8 @@ def relax( Default = 500 relax_cell (bool | None): Whether to relax the cell as well. Default = True - ase_filter (ase.filters.Filter): The filter to apply to the atoms object - for relaxation. Default = FrechetCellFilter + ase_filter (str | ase.filters.Filter): The filter to apply to the atoms + object for relaxation. Default = FrechetCellFilter Used to default to ExpCellFilter but was removed due to bug reported in https://gitlab.com/ase/ase/-/issues/1321 and fixed in https://gitlab.com/ase/ase/-/merge_requests/3024. @@ -248,6 +260,20 @@ def relax( dict[str, Structure | TrajectoryObserver]: A dictionary with 'final_structure' and 'trajectory'. """ + if isinstance(ase_filter, str): + try: + import ase.filters + + ase_filter = getattr(ase.filters, ase_filter) + except AttributeError as exc: + valid_filter_names = [ + name + for name, cls in inspect.getmembers(ase.filters, inspect.isclass) + if issubclass(cls, Filter) + ] + raise ValueError( + f"Invalid {ase_filter=}, must be one of {valid_filter_names}. " + ) from exc if isinstance(atoms, Structure): atoms = atoms.to_ase_atoms() diff --git a/chgnet/model/layers.py b/chgnet/model/layers.py index f087ecc4..725bba5e 100644 --- a/chgnet/model/layers.py +++ b/chgnet/model/layers.py @@ -117,7 +117,7 @@ def forward( # smooth out message by bond_weights bond_weight = torch.index_select(bond_weights, 0, directed2undirected) - messages = messages * bond_weight + messages *= bond_weight # Aggregate messages new_atom_feas = aggregate( diff --git a/chgnet/model/model.py b/chgnet/model/model.py index e337c9d7..537a380d 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -521,7 +521,7 @@ def _compute( # Normalize energy if model is intensive if self.is_intensive: - energy = energy / atoms_per_graph + energy /= atoms_per_graph prediction["e"] = energy return prediction diff --git a/examples/basics.ipynb b/examples/basics.ipynb index cb348136..cedde93b 100644 --- a/examples/basics.ipynb +++ b/examples/basics.ipynb @@ -19,7 +19,7 @@ " from chgnet.model import CHGNet\n", "except ImportError:\n", " # install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)\n", - " !pip install chgnet\n" + " !pip install chgnet" ] }, { @@ -37,7 +37,7 @@ "# If the above line fails in Google Colab due to numpy version issue,\n", "# please restart the runtime, and the problem will be solved\n", "\n", - "np.set_printoptions(precision=4, suppress=True)\n" + "np.set_printoptions(precision=4, suppress=True)" ] }, { @@ -89,7 +89,7 @@ " cif = urlopen(url).read().decode(\"utf-8\")\n", " structure = Structure.from_str(cif, fmt=\"cif\")\n", "\n", - "print(structure)\n" + "print(structure)" ] }, { @@ -110,7 +110,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "CHGNet initialized with 400,438 parameters\n" + "CHGNet v0.3.0 initialized with 412,525 parameters\n" ] } ], @@ -118,7 +118,7 @@ "chgnet = CHGNet.load()\n", "\n", "# Alternatively you can read your own model\n", - "# chgnet = CHGNet.from_file(model_path)\n" + "# chgnet = CHGNet.from_file(model_path)" ] }, { @@ -176,7 +176,7 @@ " (\"stress\", \"GPa\"),\n", " (\"magmom\", \"mu_B\"),\n", "]:\n", - " print(f\"CHGNet-predicted {key} ({unit}):\\n{prediction[key[0]]}\\n\")\n" + " print(f\"CHGNet-predicted {key} ({unit}):\\n{prediction[key[0]]}\\n\")" ] }, { @@ -197,7 +197,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "CHGNet initialized with 400,438 parameters\n", + "CHGNet v0.3.0 initialized with 412,525 parameters\n", "CHGNet will run on cpu\n" ] } @@ -205,7 +205,7 @@ "source": [ "from chgnet.model import StructOptimizer\n", "\n", - "relaxer = StructOptimizer()\n" + "relaxer = StructOptimizer()" ] }, { @@ -219,31 +219,53 @@ "output_type": "stream", "text": [ "\n", - "CHGNet took 29 steps. Relaxed structure:\n", + "FrechetCellFilter took 49 steps. Relaxed structure:\n", + "\n", + "Full Formula (Li2 Mn2 O4)\n", + "Reduced Formula: LiMnO2\n", + "abc : 2.876179 4.609830 5.862965\n", + "angles: 89.863012 89.706707 89.946402\n", + "pbc : True True True\n", + "Sites (8)\n", + " # SP a b c magmom\n", + "--- ---- -------- -------- -------- ----------\n", + " 0 Li+ 0.481659 0.517177 0.377268 0.00253999\n", + " 1 Li+ 0.012506 0.04361 0.611305 0.00241017\n", + " 2 Mn3+ 0.508999 0.531361 0.860983 3.88035\n", + " 3 Mn3+ 1.00639 0.032734 0.130496 3.87339\n", + " 4 O2- 0.506676 0.039479 0.353671 0.0450131\n", + " 5 O2- 0.009236 0.531155 0.094449 0.0397426\n", + " 6 O2- 0.503994 0.029776 0.896556 0.0407182\n", + " 7 O2- 0.009769 0.521232 0.636952 0.0473042\n", + "\n", + "ExpCellFilter took 83 steps. Relaxed structure:\n", + "\n", "Full Formula (Li2 Mn2 O4)\n", "Reduced Formula: LiMnO2\n", - "abc : 2.865864 4.648716 5.827764\n", - "angles: 89.917211 90.239405 89.975425\n", + "abc : 2.874395 4.611958 5.852410\n", + "angles: 89.943237 89.910969 89.994579\n", "pbc : True True True\n", "Sites (8)\n", - " # SP a b c magmom\n", - "--- ---- --------- --------- -------- ----------\n", - " 0 Li+ 0.494018 0.479737 0.387171 0.00498427\n", - " 1 Li+ 0.008464 0.006131 0.625817 0.00512926\n", - " 2 Mn3+ 0.50073 0.502478 0.869608 3.85374\n", - " 3 Mn3+ 0.997815 -0.000319 0.139344 3.859\n", - " 4 O2- 0.502142 0.009453 0.363411 0.0253105\n", - " 5 O2- 1.00293 0.502592 0.104559 0.0366638\n", - " 6 O2- 0.493749 0.998092 0.903592 0.0365367\n", - " 7 O2- -0.002278 0.495108 0.645655 0.0248522\n" + " # SP a b c magmom\n", + "--- ---- -------- -------- -------- ----------\n", + " 0 Li+ 0.474099 0.522936 0.375014 0.00291404\n", + " 1 Li+ 0.007464 0.033067 0.61184 0.00261617\n", + " 2 Mn3+ 0.512206 0.531325 0.861133 3.87057\n", + " 3 Mn3+ 1.00718 0.030874 0.130145 3.86706\n", + " 4 O2- 0.504485 0.035636 0.353984 0.0443497\n", + " 5 O2- 0.010645 0.532059 0.095251 0.0381828\n", + " 6 O2- 0.510762 0.030743 0.896625 0.0382355\n", + " 7 O2- 0.012389 0.529884 0.637688 0.0455911\n" ] } ], "source": [ "structure.perturb(0.1)\n", - "result = relaxer.relax(structure, verbose=False)\n", - "print(f\"\\nCHGNet took {len(result['trajectory'])} steps. Relaxed structure:\")\n", - "print(result[\"final_structure\"])\n" + "for ase_filter in (\"FrechetCellFilter\", \"ExpCellFilter\"):\n", + " result = relaxer.relax(structure, verbose=False, ase_filter=ase_filter)\n", + " n_steps = len(result[\"trajectory\"])\n", + " print(f\"\\n{ase_filter} took {n_steps} steps. Relaxed structure:\\n\")\n", + " print(result[\"final_structure\"])" ] }, { @@ -289,7 +311,7 @@ " logfile=\"md_out.log\",\n", " loginterval=100,\n", ")\n", - "md.run(50) # run a 0.1 ps MD simulation\n" + "md.run(50) # run a 0.1 ps MD simulation" ] }, { @@ -316,7 +338,7 @@ ], "source": [ "supercell = structure.make_supercell([2, 2, 2], in_place=False)\n", - "print(supercell.composition)\n" + "print(supercell.composition)" ] }, { @@ -340,7 +362,7 @@ "remove_ids = random.sample(list(range(n_Li)), n_Li // 2)\n", "\n", "supercell.remove_sites(remove_ids)\n", - "print(supercell.composition)\n" + "print(supercell.composition)" ] }, { @@ -756,7 +778,7 @@ } ], "source": [ - "result = relaxer.relax(supercell)\n" + "result = relaxer.relax(supercell)" ] }, { @@ -769,7 +791,7 @@ "import pandas as pd\n", "\n", "df_magmom = pd.DataFrame({\"Unrelaxed\": chgnet.predict_structure(supercell)[\"m\"]})\n", - "df_magmom[\"CHGNet relaxed\"] = result[\"final_structure\"].site_properties[\"magmom\"]\n" + "df_magmom[\"CHGNet relaxed\"] = result[\"final_structure\"].site_properties[\"magmom\"]" ] }, { @@ -1820,7 +1842,7 @@ ")\n", "fig.layout.legend.update(title=\"\", x=1, y=1, xanchor=\"right\", yanchor=\"top\")\n", "fig.layout.xaxis.title = \"Magnetic moment\"\n", - "fig.show()\n" + "fig.show()" ] } ],