Skip to content

Commit

Permalink
Fix jupyter notebook for M3GNet property model and MLIP training (#247)
Browse files Browse the repository at this point in the history
* model version for Potential class is added

* model version for Potential class is modified

* Enable the smooth version of Spherical Bessel function in TensorNet

* max_n, max_l for SphericalBessel radial basis functions are included in TensorNet class

* adding united tests for improving the coverage score

* little clean up in _so3.py and so3.py

* remove unnecessary data storage in dgl graphs

* update pymatgen version to fix the bug

* refractor all include_states into include_state for consistency

* change include_states into include_state in test_graph_conv.py

* Ensure the state attr from molecule graph is consistent with matgl.float_th and including linear layer in TensorNet to match the original implementations

* Fix the jupyter-notebook for M3GNet training
  • Loading branch information
kenko911 authored Apr 1, 2024
1 parent cb41e60 commit 46653e4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,13 @@
"elem_list = get_element_list(structures)\n",
"# setup a graph converter\n",
"converter = Structure2Graph(element_types=elem_list, cutoff=4.0)\n",
"# convert the raw dataset into MEGNetDataset\n",
"# convert the raw dataset into M3GNetDataset\n",
"mp_dataset = MGLDataset(\n",
" threebody_cutoff=4.0, structures=structures, converter=converter, labels={\"eform\": eform_per_atom}\n",
" threebody_cutoff=4.0,\n",
" structures=structures,\n",
" converter=converter,\n",
" labels={\"eform\": eform_per_atom},\n",
" include_line_graph=True,\n",
")"
]
},
Expand Down Expand Up @@ -170,7 +174,7 @@
"source": [
"# Model setup\n",
"\n",
"In the next step, we setup the model and the ModelLightningModule. Here, we have initialized a MEGNet model from scratch. Alternatively, you can also load one of the pre-trained models for transfer learning, which may speed up the training."
"In the next step, we setup the model and the ModelLightningModule. Here, we have initialized a M3GNet model from scratch. Alternatively, you can also load one of the pre-trained models for transfer learning, which may speed up the training."
]
},
{
Expand All @@ -180,14 +184,14 @@
"metadata": {},
"outputs": [],
"source": [
"# setup the architecture of MEGNet model\n",
"# setup the architecture of M3GNet model\n",
"model = M3GNet(\n",
" element_types=elem_list,\n",
" is_intensive=True,\n",
" readout_type=\"set2set\",\n",
")\n",
"# setup the MEGNetTrainer\n",
"lit_module = ModelLightningModule(model=model)"
"# setup the M3GNetTrainer\n",
"lit_module = ModelLightningModule(model=model, include_line_graph=True)"
]
},
{
Expand Down Expand Up @@ -221,7 +225,7 @@
"source": [
"# Visualizing the convergence\n",
"\n",
"Finally, we can plot the convergence plot for the loss metrics. You can see that the MAE is already going down nicely with 20 epochs. Obviously, this is nowhere state of the art performance for the formation energies, but a longer training time should lead to results consistent with what was reported in the original MEGNet work."
"Finally, we can plot the convergence plot for the loss metrics. You can see that the MAE is already going down nicely with 20 epochs. Obviously, this is nowhere state of the art performance for the formation energies, but a longer training time should lead to results consistent with what was reported in the original M3GNet work."
]
},
{
Expand Down Expand Up @@ -273,7 +277,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"\n",
"import numpy as np\n",
"import pytorch_lightning as pl\n",
"from functools import partial\n",
"from dgl.data.utils import split_dataset\n",
"from mp_api.client import MPRester\n",
"from pytorch_lightning.loggers import CSVLogger\n",
Expand Down Expand Up @@ -126,30 +127,28 @@
"element_types = get_element_list(structures)\n",
"converter = Structure2Graph(element_types=element_types, cutoff=5.0)\n",
"dataset = MGLDataset(\n",
" threebody_cutoff=4.0,\n",
" structures=structures,\n",
" converter=converter,\n",
" labels=labels,\n",
" threebody_cutoff=4.0, structures=structures, converter=converter, labels=labels, include_line_graph=True\n",
")\n",
"train_data, val_data, test_data = split_dataset(\n",
" dataset,\n",
" frac_list=[0.8, 0.1, 0.1],\n",
" shuffle=True,\n",
" random_state=42,\n",
")\n",
"my_collate_fn = partial(collate_fn_efs, include_line_graph=True)\n",
"train_loader, val_loader, test_loader = MGLDataLoader(\n",
" train_data=train_data,\n",
" val_data=val_data,\n",
" test_data=test_data,\n",
" collate_fn=collate_fn_efs,\n",
" collate_fn=my_collate_fn,\n",
" batch_size=2,\n",
" num_workers=0,\n",
")\n",
"model = M3GNet(\n",
" element_types=element_types,\n",
" is_intensive=False,\n",
")\n",
"lit_module = PotentialLightningModule(model=model)"
"lit_module = PotentialLightningModule(model=model, include_line_graph=True)"
]
},
{
Expand Down Expand Up @@ -268,7 +267,7 @@
"# download a pre-trained M3GNet\n",
"m3gnet_nnp = matgl.load_model(\"M3GNet-MP-2021.2.8-PES\")\n",
"model_pretrained = m3gnet_nnp.model\n",
"lit_module_finetune = PotentialLightningModule(model=model_pretrained, lr=1e-4)"
"lit_module_finetune = PotentialLightningModule(model=model_pretrained, lr=1e-4, include_line_graph=True)"
]
},
{
Expand Down Expand Up @@ -384,7 +383,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 46653e4

Please sign in to comment.