From 6408de5a2ae5760ad2f69c9272a3a26d37f83c01 Mon Sep 17 00:00:00 2001 From: BowenD-UCB <84425382+BowenD-UCB@users.noreply.github.com> Date: Mon, 29 Apr 2024 16:27:36 -0700 Subject: [PATCH] Adding function to initialize StructureData from vasp dir --- chgnet/data/dataset.py | 49 +++++++++++- chgnet/utils/vasp_utils.py | 12 ++- examples/fine_tuning.ipynb | 155 +++++++++++++++++++++++++------------ tests/test_dataset.py | 2 +- tests/test_vasp_utils.py | 16 +++- 5 files changed, 181 insertions(+), 53 deletions(-) diff --git a/chgnet/data/dataset.py b/chgnet/data/dataset.py index e2352951..084a8ca5 100644 --- a/chgnet/data/dataset.py +++ b/chgnet/data/dataset.py @@ -35,7 +35,7 @@ def __init__( forces: list[Sequence[Sequence[float]]], stresses: list[Sequence[Sequence[float]]] | None = None, magmoms: list[Sequence[Sequence[float]]] | None = None, - structure_ids: list[str] | None = None, + structure_ids: list | None = None, graph_converter: CrystalGraphConverter | None = None, shuffle: bool = True, ) -> None: @@ -49,7 +49,7 @@ def __init__( Default = None magmoms (list[list[float]], optional): [data_size, n_atoms, 1] Default = None - structure_ids (list[str], optional): a list of ids to track the structures + structure_ids (list, optional): a list of ids to track the structures Default = None graph_converter (CrystalGraphConverter, optional): Converts the structures to graphs. If None, it will be set to CHGNet 0.3.0 converter @@ -87,6 +87,51 @@ def __init__( self.failed_idx: list[int] = [] self.failed_graph_id: dict[str, str] = {} + @classmethod + def from_vasp( + cls, + file_root: str, + check_electronic_convergence: bool = True, + save_path: str | None = None, + graph_converter: CrystalGraphConverter | None = None, + shuffle: bool = True, + ): + """Parse VASP output files into structures and labels and feed into the dataset. + + Args: + file_root (str): the directory of the VASP calculation outputs + check_electronic_convergence (bool): if set to True, this function will + raise Exception to VASP calculation that did not achieve + electronic convergence. + Default = True + save_path (str): path to save the parsed VASP labels + Default = None + graph_converter (CrystalGraphConverter, optional): Converts the structures + to graphs. If None, it will be set to CHGNet 0.3.0 converter + with AtomGraph cutoff = 6A. + shuffle (bool): whether to shuffle the sequence of dataset + Default = True + """ + result_dict = utils.parse_vasp_dir( + file_root=file_root, + check_electronic_convergence=check_electronic_convergence, + save_path=save_path, + ) + return cls( + structures=result_dict["structure"], + energies=result_dict["energy_per_atom"], + forces=result_dict["force"], + stresses=None + if result_dict["stress"] in [None, []] + else result_dict["stress"], + magmoms=None + if result_dict["magmom"] in [None, []] + else result_dict["magmom"], + structure_ids=np.arange(len(result_dict["structure"])), + graph_converter=graph_converter, + shuffle=shuffle, + ) + def __len__(self) -> int: """Get the number of structures in this dataset.""" return len(self.keys) diff --git a/chgnet/utils/vasp_utils.py b/chgnet/utils/vasp_utils.py index 6cead739..0fdcf70e 100644 --- a/chgnet/utils/vasp_utils.py +++ b/chgnet/utils/vasp_utils.py @@ -7,12 +7,16 @@ from monty.io import reverse_readfile from pymatgen.io.vasp.outputs import Oszicar, Vasprun +from chgnet.utils import write_json + if TYPE_CHECKING: from pymatgen.core import Structure def parse_vasp_dir( - file_root: str, check_electronic_convergence: bool = True + file_root: str, + check_electronic_convergence: bool = True, + save_path: str | None = None, ) -> dict[str, list]: """Parse VASP output files into structures and labels By default, the magnetization is read from mag_x from VASP, @@ -22,6 +26,8 @@ def parse_vasp_dir( file_root (str): the directory of the VASP calculation outputs check_electronic_convergence (bool): if set to True, this function will raise Exception to VASP calculation that did not achieve electronic convergence. + Default = True + save_path (str): path to save the parsed VASP labels """ if os.path.exists(file_root) is False: raise FileNotFoundError("No such file or directory") @@ -153,6 +159,10 @@ def parse_vasp_dir( if dataset["uncorrected_total_energy"] == []: raise RuntimeError(f"No data parsed from {file_root}!") + if save_path is not None: + save_dict = dataset.copy() + save_dict["structure"] = [struct.as_dict() for struct in dataset["structure"]] + write_json(save_dict, save_path) return dataset diff --git a/examples/fine_tuning.ipynb b/examples/fine_tuning.ipynb index c8b26eb8..07de02ad 100644 --- a/examples/fine_tuning.ipynb +++ b/examples/fine_tuning.ipynb @@ -69,7 +69,7 @@ "from chgnet.utils import parse_vasp_dir\n", "\n", "# ./my_vasp_calc_dir contains vasprun.xml OSZICAR etc.\n", - "dataset_dict = parse_vasp_dir(file_root=\"./my_vasp_calc_dir\")\n", + "dataset_dict = parse_vasp_dir(file_root=\"./my_vasp_calc_dir\", save_path='./my_vasp_calc_dir/chgnet_dataset.json')\n", "print(list(dataset_dict))" ] }, @@ -78,19 +78,31 @@ "id": "6", "metadata": {}, "source": [ - "After the DFT calculations are parsed, we can save the parsed structures and labels to disk,\n", - "so that they can be easily reloaded during multiple rounds of training.\n", - "The Pymatgen structures can be saved in either json, pickle, cif, or CHGNet graph.\n", + "The parsed python dictionary includes information for CHGNet inputs (structures), and CHGNet prediction labels (energy, force, stress ,magmom). \n", + "\n", + "we can save the parsed structures and labels to disk, so that they can be easily reloaded during multiple rounds of training.\n", + "\n", + "The json file can be saved by providing the save_path" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "The Pymatgen structures can be saved separately if you're interested to take a look into each structure.\n", + "\n", + "Below are the example codes to save the structures in either json, pickle, cif, or CHGNet graph.\n", "\n", "For super-large training dataset, like MPtrj dataset, we recommend [converting them to CHGNet graphs](https://github.com/CederGroupHub/chgnet/blob/main/examples/make_graphs.py). This will save significant memory and graph computing time.\n", "\n", - "Below are the example codes to save the structures.\n" + "\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "8", "metadata": { "pycharm": { "name": "#%%\n" @@ -128,7 +140,7 @@ }, { "cell_type": "markdown", - "id": "8", + "id": "9", "metadata": {}, "source": [ "For other types of DFT calculations, please refer to their interfaces\n", @@ -143,7 +155,7 @@ }, { "cell_type": "markdown", - "id": "9", + "id": "10", "metadata": {}, "source": [ "## 1. Prepare Training Data\n" @@ -151,17 +163,41 @@ }, { "cell_type": "markdown", - "id": "10", + "id": "11", "metadata": {}, "source": [ - "Below we will create a dummy fine-tuning dataset by using CHGNet prediction with some random noise.\n", - "For your purpose of fine-tuning to a specific chemical system or AIMD data, please modify the block below\n" + "If you have parsed your VASP labels from step 0, you can reload the saved json file." ] }, { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "from chgnet.utils import read_json\n", + "\n", + "dataset_dict = read_json('./my_vasp_calc_dir/chgnet_dataset.json')\n", + "structures = [Structure.from_dict(struct) for struct in dataset_dict['structure']]\n", + "energies = dataset_dict['energy_per_atom']\n", + "forces = dataset_dict[\"force\"]\n", + "stresses = None if result_dict['stress'] in [None, []] else result_dict['stress']\n", + "magmoms = None if result_dict['magmom'] in [None, []] else result_dict['magmom']" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "If you don't have any DFT calculations now, we can create a dummy fine-tuning dataset by using CHGNet prediction with some random noise." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -198,17 +234,18 @@ }, { "cell_type": "markdown", - "id": "12", + "id": "15", "metadata": {}, "source": [ "Note that the stress output from CHGNet is in unit of GPa, here the -10 unit conversion\n", - "modifies it to be kbar in VASP raw unit. We do this since by default, StructureData\n", - "dataset class takes in VASP units.\n" + "modifies it to be kbar in VASP raw unit. \n", + "If you're using stress labels from VASP, you don't need to do any unit conversions\n", + "StructureData dataset class takes in VASP units.\n" ] }, { "cell_type": "markdown", - "id": "13", + "id": "16", "metadata": {}, "source": [ "## 2. Define DataSet\n" @@ -217,7 +254,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -227,7 +264,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "18", "metadata": {}, "outputs": [ { @@ -253,7 +290,29 @@ }, { "cell_type": "markdown", - "id": "16", + "id": "19", + "metadata": {}, + "source": [ + "Alternatively, the dataset can be directly created from VASP calculation dir.\n", + "This function essentially parse the VASP directory first, save the labels to json file, and create the StructureData class" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = StructureData.from_vasp(\n", + " file_root=\"./my_vasp_calc_dir\", \n", + " save_path='./my_vasp_calc_dir/chgnet_dataset.json'\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "21", "metadata": {}, "source": [ "The training set is used to optimize the CHGNet through gradient descent, the validation set is used to see validation error at the end of each epoch, and the test set is used to see the final test error at the end of training. The test set can be optional.\n", @@ -265,7 +324,7 @@ }, { "cell_type": "markdown", - "id": "17", + "id": "22", "metadata": {}, "source": [ "## 3. Define model and trainer\n" @@ -274,7 +333,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "23", "metadata": {}, "outputs": [ { @@ -296,7 +355,7 @@ }, { "cell_type": "markdown", - "id": "19", + "id": "24", "metadata": {}, "source": [ "It's optional to freeze the weights inside some layers. This is a common technique to retain the learned knowledge during fine-tuning in large pretrained neural networks. You can choose the layers you want to freeze.\n" @@ -305,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -327,7 +386,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -347,7 +406,7 @@ }, { "cell_type": "markdown", - "id": "22", + "id": "27", "metadata": {}, "source": [ "## 4. Start training\n" @@ -356,7 +415,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "28", "metadata": {}, "outputs": [ { @@ -396,7 +455,7 @@ }, { "cell_type": "markdown", - "id": "24", + "id": "29", "metadata": {}, "source": [ "After training, the trained model can be found in the directory of today's date. Or it can be accessed by:\n" @@ -405,7 +464,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -415,7 +474,7 @@ }, { "cell_type": "markdown", - "id": "26", + "id": "31", "metadata": {}, "source": [ "## Extras 1: GGA / GGA+U compatibility\n" @@ -423,7 +482,7 @@ }, { "cell_type": "markdown", - "id": "27", + "id": "32", "metadata": {}, "source": [ "### Q: Why and when do you care about this?\n", @@ -440,7 +499,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28", + "id": "33", "metadata": {}, "outputs": [ { @@ -460,7 +519,7 @@ }, { "cell_type": "markdown", - "id": "29", + "id": "34", "metadata": {}, "source": [ "You can look for the energy correction applied to each element in :\n", @@ -472,7 +531,7 @@ }, { "cell_type": "markdown", - "id": "30", + "id": "35", "metadata": {}, "source": [ "To demystify `MaterialsProject2020Compatibility`, basically all that's happening is:\n" @@ -481,7 +540,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "36", "metadata": {}, "outputs": [ { @@ -506,7 +565,7 @@ }, { "cell_type": "markdown", - "id": "32", + "id": "37", "metadata": {}, "source": [ "You can also apply the `MaterialsProject2020Compatibility` through pymatgen\n" @@ -515,7 +574,7 @@ { "cell_type": "code", "execution_count": null, - "id": "33", + "id": "38", "metadata": {}, "outputs": [ { @@ -542,7 +601,7 @@ }, { "cell_type": "markdown", - "id": "34", + "id": "39", "metadata": {}, "source": [ "Now use this corrected energy as labels to tune CHGNet, you're good to go!\n" @@ -550,7 +609,7 @@ }, { "cell_type": "markdown", - "id": "35", + "id": "40", "metadata": {}, "source": [ "## Extras 2: AtomRef\n" @@ -558,7 +617,7 @@ }, { "cell_type": "markdown", - "id": "36", + "id": "41", "metadata": {}, "source": [ "### Q: Why and when do you care about this?\n", @@ -572,7 +631,7 @@ }, { "cell_type": "markdown", - "id": "37", + "id": "42", "metadata": {}, "source": [ "### A quick and easy way to turn on training of AtomRef in the trainer (this is by default off):\n" @@ -581,7 +640,7 @@ { "cell_type": "code", "execution_count": null, - "id": "38", + "id": "43", "metadata": {}, "outputs": [ { @@ -621,7 +680,7 @@ }, { "cell_type": "markdown", - "id": "39", + "id": "44", "metadata": {}, "source": [ "### The more regorous way is to solve for the per-atom contribution by linear regression in your fine-tuning dataset\n" @@ -630,7 +689,7 @@ { "cell_type": "code", "execution_count": null, - "id": "40", + "id": "45", "metadata": {}, "outputs": [ { @@ -665,7 +724,7 @@ { "cell_type": "code", "execution_count": null, - "id": "41", + "id": "46", "metadata": {}, "outputs": [], "source": [ @@ -696,7 +755,7 @@ { "cell_type": "code", "execution_count": null, - "id": "42", + "id": "47", "metadata": {}, "outputs": [ { @@ -721,7 +780,7 @@ { "cell_type": "code", "execution_count": null, - "id": "43", + "id": "48", "metadata": {}, "outputs": [ { @@ -765,7 +824,7 @@ ], "metadata": { "kernelspec": { - "display_name": "py311", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -779,7 +838,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.9.7" } }, "nbformat": 4, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 26a2e210..b6e9d275 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -96,7 +96,7 @@ def test_dataset_no_shuffling(): structure_ids = list(range(n_samples)) structure_data = StructureData( - structures=[NaCl.copy().perturb(0.1) for _ in range(n_samples)], + structures=[NaCl.copy() for _ in range(n_samples)], energies=np.random.random(n_samples), forces=np.random.random([n_samples, 2, 3]), stresses=np.random.random([n_samples, 3, 3]), diff --git a/tests/test_vasp_utils.py b/tests/test_vasp_utils.py index e94a1af3..10cc8f16 100644 --- a/tests/test_vasp_utils.py +++ b/tests/test_vasp_utils.py @@ -8,6 +8,7 @@ from pymatgen.core import Structure from chgnet import ROOT +from chgnet.data.dataset import StructureData from chgnet.utils import parse_vasp_dir if TYPE_CHECKING: @@ -18,7 +19,7 @@ def test_parse_vasp_dir_with_magmoms(tmp_path: Path): with ZipFile(f"{ROOT}/tests/files/parse-vasp-with-magmoms.zip") as zip_ref: zip_ref.extractall(tmp_path) vasp_path = os.path.join(tmp_path, "parse-vasp-with-magmoms") - dataset_dict = parse_vasp_dir(vasp_path) + dataset_dict = parse_vasp_dir(vasp_path, save_path=f"{tmp_path}/tmp.json") assert isinstance(dataset_dict, dict) assert len(dataset_dict["structure"]) > 0 @@ -67,3 +68,16 @@ def test_parse_vasp_dir_no_data(): # test existing directory without VASP files with pytest.raises(RuntimeError, match="No data parsed from"): parse_vasp_dir(f"{ROOT}/tests/files") + + +def test_dataset_from_vasp_dir(tmp_path: Path): + with ZipFile(f"{ROOT}/tests/files/parse-vasp-with-magmoms.zip") as zip_ref: + zip_ref.extractall(tmp_path) + vasp_path = os.path.join(tmp_path, "parse-vasp-with-magmoms") + dataset = StructureData.from_vasp(vasp_path, save_path=f"{tmp_path}/tmp.json") + assert len(dataset.structures) > 0 + assert isinstance(dataset.structures[0], Structure) + assert len(dataset.structures) == len(dataset.energies) + assert len(dataset.structures) == len(dataset.forces) + assert len(dataset.structures) == len(dataset.stresses) + assert len(dataset.structures) == len(dataset.magmoms)