Skip to content

Commit

Permalink
restructure matgl modules for CHGNet implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
kenko911 committed May 5, 2024
1 parent a95c104 commit e035d1c
Show file tree
Hide file tree
Showing 22 changed files with 281 additions and 1,306 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
"predicted = []\n",
"mp = []\n",
"os.environ[\"MPRESTER_MUTE_PROGRESS_BARS\"] = \"true\"\n",
"mpr = MPRester(\"FwTXcju8unkI2VbInEgZDTN8coDB6S6U\")\n",
"mpr = MPRester(\"YOUR_API_KEY\")\n",
"\n",
"# Load the pre-trained M3GNet Potential\n",
"pot = matgl.load_model(\"M3GNet-MP-2021.2.8-PES\")\n",
Expand Down Expand Up @@ -265,7 +265,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.14"
},
"vscode": {
"interpreter": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@
"\n",
"import matgl\n",
"from matgl.ext.pymatgen import Structure2Graph, get_element_list\n",
"from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn_efs\n",
"from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn_pes\n",
"from matgl.models import M3GNet\n",
"from matgl.utils.training import PotentialLightningModule\n",
"from matgl.config import DEFAULT_ELEMENTS\n",
"\n",
"# To suppress warnings for clearer output\n",
"warnings.simplefilter(\"ignore\")"
Expand Down Expand Up @@ -123,7 +124,7 @@
},
"outputs": [],
"source": [
"element_types = get_element_list(structures)\n",
"element_types = DEFAULT_ELEMENTS\n",
"converter = Structure2Graph(element_types=element_types, cutoff=5.0)\n",
"dataset = MGLDataset(\n",
" threebody_cutoff=4.0, structures=structures, converter=converter, labels=labels, include_line_graph=True\n",
Expand All @@ -134,7 +135,7 @@
" shuffle=True,\n",
" random_state=42,\n",
")\n",
"my_collate_fn = partial(collate_fn_efs, include_line_graph=True)\n",
"my_collate_fn = partial(collate_fn_pes, include_line_graph=True)\n",
"train_loader, val_loader, test_loader = MGLDataLoader(\n",
" train_data=train_data,\n",
" val_data=val_data,\n",
Expand Down Expand Up @@ -239,7 +240,7 @@
"source": [
"# save trained model\n",
"model_export_path = \"./trained_model/\"\n",
"model.save(model_export_path)\n",
"lit_module.model.save(model_export_path)\n",
"\n",
"# load trained model\n",
"model = matgl.load_model(path=model_export_path)"
Expand Down Expand Up @@ -335,7 +336,7 @@
"source": [
"# save trained model\n",
"model_save_path = \"./finetuned_model/\"\n",
"model_pretrained.save(model_save_path)\n",
"lit_module_finetune.model.save(model_save_path)\n",
"# load trained model\n",
"trained_model = matgl.load_model(path=model_save_path)"
]
Expand Down Expand Up @@ -382,7 +383,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit e035d1c

Please sign in to comment.