Merge the changes from the branch tko (#227)
* major changes including new model architectures, refactoring variables and unifying training frameworks. (#226)

* model version for Potential class is added

* preliminary TensorNet and SO3Net DGL implementations are added

* Minor warning fixes and more descriptions added

* refactor the potential and dataset class for other GNN model architectures

* refactor the ase-calculator and dataset class for other GNN model architectures

* Added unit tests for TensorNet and SO3Net training, Unified the training framework for all model architectures

* Fixed black and modified some unit tests

* Modified unit tests to improve the coverage score

* fixed a minor bug in test_readout.py

* M3GNetCalculator is added for backward compatibility
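
The last item keeps an old class name available after a rename. A minimal sketch of one common way to do that in Python follows. It is not taken from this commit: `NewCalculator` and `LegacyCalculator` are hypothetical stand-ins, and the constructor arguments are assumptions made only for illustration.

```python
import warnings


class NewCalculator:
    """Hypothetical stand-in for a generalized calculator class."""

    def __init__(self, potential, stress_weight=1.0):
        self.potential = potential
        self.stress_weight = stress_weight


class LegacyCalculator(NewCalculator):
    """Hypothetical old name, kept so that existing imports keep working."""

    def __init__(self, *args, **kwargs):
        # Warn once per call site, then defer entirely to the new class.
        warnings.warn(
            "LegacyCalculator is deprecated; use NewCalculator instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
```

Subclassing, rather than a plain alias, leaves room for the deprecation warning while `isinstance` checks against the new class still pass.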
kenko911 authored Feb 14, 2024
1 parent 2fd7eba commit 1677d7e
Showing 47 changed files with 4,469 additions and 1,170 deletions.
10 changes: 5 additions & 5 deletions docs/genindex.html
@@ -250,7 +250,7 @@ <h2 id="G">G</h2>
<h2 id="H">H</h2>
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="matgl.graph.html#matgl.graph.data.M3GNetDataset.has_cache">has_cache() (matgl.graph.data.M3GNetDataset method)</a>
<li><a href="matgl.graph.html#matgl.graph.data.MGLDataset.has_cache">has_cache() (matgl.graph.data.MGLDataset method)</a>

<ul>
<li><a href="matgl.graph.html#matgl.graph.data.MEGNetDataset.has_cache">(matgl.graph.data.MEGNetDataset method)</a>
@@ -286,7 +286,7 @@ <h2 id="L">L</h2>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="matgl.layers.html#matgl.layers._core.MLP.last_linear">last_linear (matgl.layers._core.MLP property)</a>
</li>
<li><a href="matgl.graph.html#matgl.graph.data.M3GNetDataset.load">load() (matgl.graph.data.M3GNetDataset method)</a>
<li><a href="matgl.graph.html#matgl.graph.data.MGLDataset.load">load() (matgl.graph.data.MGLDataset method)</a>

<ul>
<li><a href="matgl.graph.html#matgl.graph.data.MEGNetDataset.load">(matgl.graph.data.MEGNetDataset method)</a>
@@ -318,7 +318,7 @@ <h2 id="M">M</h2>
</li>
<li><a href="matgl.ext.html#matgl.ext.ase.M3GNetCalculator">M3GNetCalculator (class in matgl.ext.ase)</a>
</li>
<li><a href="matgl.graph.html#matgl.graph.data.M3GNetDataset">M3GNetDataset (class in matgl.graph.data)</a>
<li><a href="matgl.graph.html#matgl.graph.data.MGLDataset">MGLDataset (class in matgl.graph.data)</a>
</li>
<li><a href="matgl.layers.html#matgl.layers._graph_convolution.M3GNetGraphConv">M3GNetGraphConv (class in matgl.layers._graph_convolution)</a>
</li>
@@ -678,7 +678,7 @@ <h2 id="P">P</h2>
<li><a href="matgl.models.html#matgl.models._wrappers.TransformedTargetModel.predict_structure">(matgl.models._wrappers.TransformedTargetModel method)</a>
</li>
</ul></li>
<li><a href="matgl.graph.html#matgl.graph.data.M3GNetDataset.process">process() (matgl.graph.data.M3GNetDataset method)</a>
<li><a href="matgl.graph.html#matgl.graph.data.MGLDataset.process">process() (matgl.graph.data.MGLDataset method)</a>

<ul>
<li><a href="matgl.graph.html#matgl.graph.data.MEGNetDataset.process">(matgl.graph.data.MEGNetDataset method)</a>
@@ -721,7 +721,7 @@ <h2 id="S">S</h2>
<li><a href="matgl.ext.html#matgl.ext.ase.TrajectoryObserver.save">save() (matgl.ext.ase.TrajectoryObserver method)</a>

<ul>
<li><a href="matgl.graph.html#matgl.graph.data.M3GNetDataset.save">(matgl.graph.data.M3GNetDataset method)</a>
<li><a href="matgl.graph.html#matgl.graph.data.MGLDataset.save">(matgl.graph.data.MGLDataset method)</a>
</li>
<li><a href="matgl.graph.html#matgl.graph.data.MEGNetDataset.save">(matgl.graph.data.MEGNetDataset method)</a>
</li>
2 changes: 1 addition & 1 deletion docs/matgl.graph.md
@@ -111,7 +111,7 @@ Construct a dgl graph from processed structure and bond information.

Tools to construct a dataset of DGL graphs.

### *class* matgl.graph.data.M3GNetDataset(converter: GraphConverter, threebody_cutoff: float, structures: list, energies: list | None = None, forces: list | None = None, stresses: list | None = None, labels: list | None = None, name=’M3GNETDataset’, label_name: str | None = None, graph_labels: list | None = None)
### *class* matgl.graph.data.MGLDataset(converter: GraphConverter, threebody_cutoff: float, structures: list, energies: list | None = None, forces: list | None = None, stresses: list | None = None, labels: list | None = None, name=’M3GNETDataset’, label_name: str | None = None, graph_labels: list | None = None)

Bases: `DGLDataset`

10 changes: 5 additions & 5 deletions docs/matgl.md
@@ -68,11 +68,11 @@ MatGL (Materials Graph Library) is a graph deep learning library for materials s
* `GraphConverter.get_graph()`
* `GraphConverter.get_graph_from_processed_structure()`
* matgl.graph.data module
* `M3GNetDataset`
* `M3GNetDataset.has_cache()`
* `M3GNetDataset.load()`
* `M3GNetDataset.process()`
* `M3GNetDataset.save()`
* `MGLDataset`
* `MGLDataset.has_cache()`
* `MGLDataset.load()`
* `MGLDataset.process()`
* `MGLDataset.save()`
* `MEGNetDataset`
* `MEGNetDataset.has_cache()`
* `MEGNetDataset.load()`
@@ -115,7 +115,7 @@
"data = data[data[\"Material\"] != \"Kr\"]\n",
"data = data[data[\"Material\"] != \"Rb\"]\n",
"data = data.set_index(\"Material\")\n",
"print(data)"
"print(data[61:80])"
]
},
{
@@ -136,7 +136,7 @@
"predicted = []\n",
"mp = []\n",
"os.environ[\"MPRESTER_MUTE_PROGRESS_BARS\"] = \"true\"\n",
"mpr = MPRester()\n",
"mpr = MPRester(\"FwTXcju8unkI2VbInEgZDTN8coDB6S6U\")\n",
"\n",
"# Load the pre-trained M3GNet Potential\n",
"pot = matgl.load_model(\"M3GNet-MP-2021.2.8-PES\")\n",
@@ -242,10 +242,10 @@
"source": [
"# This generates a pretty markdown table output.\n",
"\n",
"#df = data.sort_values(\"% error vs MP\", key=abs).replace([np.inf, -np.inf], np.nan).dropna()\n",
"#df[\"% error vs MP\"] = [f\"{v*100:.3f}%\" for v in df[\"% error vs MP\"]]\n",
"#df[\"% error vs Expt\"] = [f\"{v*100:.3f}%\" for v in df[\"% error vs Expt\"]]\n",
"#print(df.to_markdown())"
"# df = data.sort_values(\"% error vs MP\", key=abs).replace([np.inf, -np.inf], np.nan).dropna()\n",
"# df[\"% error vs MP\"] = [f\"{v*100:.3f}%\" for v in df[\"% error vs MP\"]]\n",
"# df[\"% error vs Expt\"] = [f\"{v*100:.3f}%\" for v in df[\"% error vs Expt\"]]\n",
"# print(df.to_markdown())"
]
}
],
@@ -47,33 +47,11 @@
"execution_count": null,
"id": "83bf12c4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Full Formula (Sr1 Ti1 O3)\n",
"Reduced Formula: SrTiO3\n",
"abc : 4.500000 4.500000 4.500000\n",
"angles: 90.000000 90.000000 90.000000\n",
"pbc : True True True\n",
"Sites (5)\n",
" # SP a b c\n",
"--- ---- --- --- ---\n",
" 0 Sr 0 0 0\n",
" 1 Ti 0.5 0.5 0.5\n",
" 2 O 0.5 0 0.5\n",
" 3 O 0 0.5 0.5\n",
" 4 O 0.5 0.5 0\n"
]
}
],
"outputs": [],
"source": [
"sto = Structure.from_spacegroup(\n",
" \"Pm-3m\",\n",
" Lattice.cubic(4.5),\n",
" [\"Sr\", \"Ti\", \"O\"],\n",
" [[0, 0, 0], [0.5, 0.5, 0.5], [0.5, 0.5, 0]])\n",
" \"Pm-3m\", Lattice.cubic(4.5), [\"Sr\", \"Ti\", \"O\"], [[0, 0, 0], [0.5, 0.5, 0.5], [0.5, 0.5, 0]]\n",
")\n",
"print(sto)"
]
},
@@ -90,22 +68,7 @@
"execution_count": null,
"id": "8078bd8d",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1827af8bbaf94a06a1d620b464e6b6a3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Retrieving SummaryDoc documents: 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"mpr = MPRester()\n",
"doc = mpr.summary.search(material_ids=[\"mp-5229\"])[0]\n",
@@ -137,27 +100,7 @@
"execution_count": null,
"id": "2bb72614",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Full Formula (Sr1 Ti1 O3)\n",
"Reduced Formula: SrTiO3\n",
"abc : 3.937988 3.937988 3.937988\n",
"angles: 90.000001 90.000000 90.000000\n",
"pbc : True True True\n",
"Sites (5)\n",
" # SP a b c\n",
"--- ---- ---- --- ----\n",
" 0 Sr 0 0 -0\n",
" 1 Ti 0.5 0.5 0.5\n",
" 2 O 0.5 0 0.5\n",
" 3 O -0 0.5 0.5\n",
" 4 O 0.5 0.5 0\n"
]
}
],
"outputs": [],
"source": [
"relaxer = Relaxer(potential=pot)\n",
"relax_results = relaxer.relax(sto, fmax=0.01)\n",
@@ -178,27 +121,7 @@
"execution_count": null,
"id": "814f82d9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Full Formula (Sr1 Ti1 O3)\n",
"Reduced Formula: SrTiO3\n",
"abc : 3.912701 3.912701 3.912701\n",
"angles: 90.000000 90.000000 90.000000\n",
"pbc : True True True\n",
"Sites (5)\n",
" # SP a b c magmom\n",
"--- ---- ---- ---- ---- --------\n",
" 0 Sr -0 -0 -0 -0\n",
" 1 Ti 0.5 0.5 0.5 -0\n",
" 2 O 0.5 -0 0.5 0\n",
" 3 O 0.5 0.5 -0 0\n",
" 4 O -0 0.5 0.5 0\n"
]
}
],
"outputs": [],
"source": [
"print(sto_dft)"
]
@@ -218,17 +141,7 @@
"execution_count": null,
"id": "bb9276bd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The predicted formation energy for the unrelaxed SrTiO3 is -2.240 eV/atom.\n",
"The predicted formation energy for the relaxed SrTiO3 is -3.495 eV/atom.\n",
"The Materials Project formation energy for DFT-relaxed SrTiO3 is -3.552 eV/atom.\n"
]
}
],
"outputs": [],
"source": [
"# Load the pre-trained MEGNet formation energy model.\n",
"model = matgl.load_model(\"M3GNet-MP-2018.6.1-Eform\")\n",
@@ -263,40 +176,15 @@
"execution_count": null,
"id": "ab890308",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PBE band gap\n",
"\tUnrelaxed STO = 0.16 eV.\n",
"\tRelaxed STO = 1.72 eV.\n",
"GLLB-SC band gap\n",
"\tUnrelaxed STO = 2.36 eV.\n",
"\tRelaxed STO = 3.22 eV.\n",
"HSE band gap\n",
"\tUnrelaxed STO = 0.53 eV.\n",
"\tRelaxed STO = 3.07 eV.\n",
"SCAN band gap\n",
"\tUnrelaxed STO = 0.72 eV.\n",
"\tRelaxed STO = 2.19 eV.\n",
"The PBE band gap for STO from Materials Project is 1.77 eV.\n"
]
}
],
"outputs": [],
"source": [
"model = matgl.load_model(\"MEGNet-MP-2019.4.1-BandGap-mfi\")\n",
"\n",
"# For multi-fidelity models, we need to define graph label (\"0\": PBE, \"1\": GLLB-SC, \"2\": HSE, \"3\": SCAN)\n",
"for i, method in ((0, \"PBE\"), (1, \"GLLB-SC\"), (2, \"HSE\"), (3, \"SCAN\")):\n",
"\n",
" graph_attrs = torch.tensor([i])\n",
" bandgap_sto = model.predict_structure(\n",
" structure=sto, state_feats=graph_attrs\n",
" )\n",
" bandgap_relaxed_sto = model.predict_structure(\n",
" structure=relaxed_sto, state_feats=graph_attrs\n",
" )\n",
" bandgap_sto = model.predict_structure(structure=sto, state_attr=graph_attrs)\n",
" bandgap_relaxed_sto = model.predict_structure(structure=relaxed_sto, state_attr=graph_attrs)\n",
"\n",
" print(f\"{method} band gap\")\n",
" print(f\"\\tUnrelaxed STO = {float(bandgap_sto):.2f} eV.\")\n",
38 changes: 4 additions & 34 deletions examples/Property Predictions using MEGNet or M3GNet Models.ipynb
@@ -74,15 +74,7 @@
"execution_count": null,
"id": "ce4e9336",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The predicted formation energy for CsCl is -2.272 eV/atom.\n"
]
}
],
"outputs": [],
"source": [
"# Load the pre-trained MEGNet formation energy model.\n",
"model = matgl.load_model(\"MEGNet-MP-2018.6.1-Eform\")\n",
@@ -103,15 +95,7 @@
"execution_count": null,
"id": "4aef3fec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The predicted formation energy for CsCl is -2.220 eV/atom.\n"
]
}
],
"outputs": [],
"source": [
"# Load the pre-trained M3GNet formation energy model\n",
"model = matgl.load_model(\"M3GNet-MP-2018.6.1-Eform\")\n",
@@ -135,28 +119,14 @@
"execution_count": null,
"id": "f2ec796f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The predicted PBE band gap for CsCl is 4.825 eV.\n",
"The predicted GLLB-SC band gap for CsCl is 8.323 eV.\n",
"The predicted HSE band gap for CsCl is 6.317 eV.\n",
"The predicted SCAN band gap for CsCl is 5.965 eV.\n"
]
}
],
"outputs": [],
"source": [
"model = matgl.load_model(\"MEGNet-MP-2019.4.1-BandGap-mfi\")\n",
"\n",
"# For multi-fidelity models, we need to define graph label (\"0\": PBE, \"1\": GLLB-SC, \"2\": HSE, \"3\": SCAN)\n",
"for i, method in ((0, \"PBE\"), (1, \"GLLB-SC\"), (2, \"HSE\"), (3, \"SCAN\")):\n",
"\n",
" graph_attrs = torch.tensor([i])\n",
" bandgap = model.predict_structure(\n",
" structure=struct, state_feats=graph_attrs\n",
" )\n",
" bandgap = model.predict_structure(structure=struct, state_attr=graph_attrs)\n",
" print(f\"The predicted {method} band gap for CsCl is {float(bandgap):.3f} eV.\")"
]
}
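
The notebook hunks above change the `predict_structure` calls from `state_feats=` to `state_attr=`. For code that still passes the old keyword, a generic shim like the sketch below can ease the transition. It is an illustration only, not part of this commit or of matgl: `renamed_kwarg` is a hypothetical helper, and the `predict_structure` defined here is a stand-in that echoes its inputs rather than running a model.

```python
import functools
import warnings


def renamed_kwarg(old, new):
    """Hypothetical decorator: accept `old` as a deprecated alias for `new`."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old in kwargs:
                if new in kwargs:
                    raise TypeError(f"Pass only one of {old!r} and {new!r}.")
                warnings.warn(
                    f"{old!r} is deprecated; use {new!r} instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                # Forward the value under the new keyword name.
                kwargs[new] = kwargs.pop(old)
            return func(*args, **kwargs)

        return wrapper

    return decorator


@renamed_kwarg("state_feats", "state_attr")
def predict_structure(structure, state_attr=None):
    # Stand-in for illustration: returns its inputs instead of a prediction.
    return structure, state_attr
```

Calling `predict_structure(structure=s, state_feats=x)` then behaves like `predict_structure(structure=s, state_attr=x)` and emits a `DeprecationWarning`; passing both keywords raises a `TypeError`.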