diff --git a/examples/Training a M3GNet Formation Energy Model with PyTorch Lightning.ipynb b/examples/Training a M3GNet Formation Energy Model with PyTorch Lightning.ipynb index 51732ae7..05f23eb6 100644 --- a/examples/Training a M3GNet Formation Energy Model with PyTorch Lightning.ipynb +++ b/examples/Training a M3GNet Formation Energy Model with PyTorch Lightning.ipynb @@ -125,9 +125,13 @@ "elem_list = get_element_list(structures)\n", "# setup a graph converter\n", "converter = Structure2Graph(element_types=elem_list, cutoff=4.0)\n", - "# convert the raw dataset into MEGNetDataset\n", + "# convert the raw dataset into an MGLDataset\n", "mp_dataset = MGLDataset(\n", - " threebody_cutoff=4.0, structures=structures, converter=converter, labels={\"eform\": eform_per_atom}\n", + " threebody_cutoff=4.0,\n", + " structures=structures,\n", + " converter=converter,\n", + " labels={\"eform\": eform_per_atom},\n", + " include_line_graph=True,\n", ")" ] }, @@ -170,7 +174,7 @@ "source": [ "# Model setup\n", "\n", - "In the next step, we setup the model and the ModelLightningModule. Here, we have initialized a MEGNet model from scratch. Alternatively, you can also load one of the pre-trained models for transfer learning, which may speed up the training." + "In the next step, we set up the model and the ModelLightningModule. Here, we have initialized an M3GNet model from scratch. Alternatively, you can also load one of the pre-trained models for transfer learning, which may speed up the training." 
] }, { @@ -180,14 +184,14 @@ "metadata": {}, "outputs": [], "source": [ - "# setup the architecture of MEGNet model\n", + "# set up the architecture of the M3GNet model\n", "model = M3GNet(\n", " element_types=elem_list,\n", " is_intensive=True,\n", " readout_type=\"set2set\",\n", ")\n", - "# setup the MEGNetTrainer\n", - "lit_module = ModelLightningModule(model=model)" + "# set up the ModelLightningModule trainer\n", + "lit_module = ModelLightningModule(model=model, include_line_graph=True)" ] }, { @@ -221,7 +225,7 @@ "source": [ "# Visualizing the convergence\n", "\n", - "Finally, we can plot the convergence plot for the loss metrics. You can see that the MAE is already going down nicely with 20 epochs. Obviously, this is nowhere state of the art performance for the formation energies, but a longer training time should lead to results consistent with what was reported in the original MEGNet work." + "Finally, we can plot the convergence curve for the loss metrics. You can see that the MAE is already going down nicely with 20 epochs. Obviously, this is nowhere near state-of-the-art performance for the formation energies, but a longer training time should lead to results consistent with what was reported in the original M3GNet work." 
] }, { @@ -273,7 +277,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.9.18" } }, "nbformat": 4, diff --git a/examples/Training a M3GNet Potential with PyTorch Lightning.ipynb b/examples/Training a M3GNet Potential with PyTorch Lightning.ipynb index 6766d505..6e8ab53a 100644 --- a/examples/Training a M3GNet Potential with PyTorch Lightning.ipynb +++ b/examples/Training a M3GNet Potential with PyTorch Lightning.ipynb @@ -30,6 +30,7 @@ "\n", "import numpy as np\n", "import pytorch_lightning as pl\n", + "from functools import partial\n", "from dgl.data.utils import split_dataset\n", "from mp_api.client import MPRester\n", "from pytorch_lightning.loggers import CSVLogger\n", @@ -126,10 +127,7 @@ "element_types = get_element_list(structures)\n", "converter = Structure2Graph(element_types=element_types, cutoff=5.0)\n", "dataset = MGLDataset(\n", - " threebody_cutoff=4.0,\n", - " structures=structures,\n", - " converter=converter,\n", - " labels=labels,\n", + " threebody_cutoff=4.0, structures=structures, converter=converter, labels=labels, include_line_graph=True\n", ")\n", "train_data, val_data, test_data = split_dataset(\n", " dataset,\n", @@ -137,11 +135,12 @@ " shuffle=True,\n", " random_state=42,\n", ")\n", + "my_collate_fn = partial(collate_fn_efs, include_line_graph=True)\n", "train_loader, val_loader, test_loader = MGLDataLoader(\n", " train_data=train_data,\n", " val_data=val_data,\n", " test_data=test_data,\n", - " collate_fn=collate_fn_efs,\n", + " collate_fn=my_collate_fn,\n", " batch_size=2,\n", " num_workers=0,\n", ")\n", @@ -149,7 +148,7 @@ " element_types=element_types,\n", " is_intensive=False,\n", ")\n", - "lit_module = PotentialLightningModule(model=model)" + "lit_module = PotentialLightningModule(model=model, include_line_graph=True)" ] }, { @@ -268,7 +267,7 @@ "# download a pre-trained M3GNet\n", "m3gnet_nnp = matgl.load_model(\"M3GNet-MP-2021.2.8-PES\")\n", "model_pretrained = 
m3gnet_nnp.model\n", - "lit_module_finetune = PotentialLightningModule(model=model_pretrained, lr=1e-4)" + "lit_module_finetune = PotentialLightningModule(model=model_pretrained, lr=1e-4, include_line_graph=True)" ] }, { @@ -384,7 +383,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.9.18" } }, "nbformat": 4,