From ad132a26f8ab200e259aafc3df7fe82dabc42a8a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 18 Jun 2023 13:41:58 -0400 Subject: [PATCH 1/4] load files according to filename suffix The current `try...catch` does not report what's wrong in the YAML file if it is invalid. Signed-off-by: Jinzhe Zeng --- dpgen/data/gen.py | 18 +++--------------- dpgen/data/reaction.py | 18 +++--------------- dpgen/data/surf.py | 17 +++-------------- dpgen/generator/run.py | 16 ++++------------ dpgen/simplify/simplify.py | 17 +++-------------- dpgen/util.py | 33 +++++++++++++++++++++++++++++++++ tests/sample.json | 1 + tests/sample.yaml | 1 + tests/test_load_file.py | 17 +++++++++++++++++ 9 files changed, 68 insertions(+), 70 deletions(-) create mode 100644 tests/sample.json create mode 100644 tests/sample.yaml create mode 100644 tests/test_load_file.py diff --git a/dpgen/data/gen.py b/dpgen/data/gen.py index 83bd59745..2a4f3ece0 100644 --- a/dpgen/data/gen.py +++ b/dpgen/data/gen.py @@ -33,6 +33,7 @@ from dpgen.generator.lib.utils import symlink_user_forward_files from dpgen.generator.lib.vasp import incar_upper from dpgen.remote.decide_machine import convert_mdata +from dpgen.util import load_file def create_path(path, back=False): @@ -1465,22 +1466,9 @@ def run_abacus_md(jdata, mdata): def gen_init_bulk(args): - try: - import ruamel - from monty.serialization import loadfn - - warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning) - jdata = loadfn(args.PARAM) - if args.MACHINE is not None: - mdata = loadfn(args.MACHINE) - except Exception: - with open(args.PARAM) as fp: - jdata = json.load(fp) - if args.MACHINE is not None: - with open(args.MACHINE) as fp: - mdata = json.load(fp) - + jdata = load_file(args.PARAM) if args.MACHINE is not None: + mdata = load_file(args.MACHINE) # Selecting a proper machine mdata = convert_mdata(mdata, ["fp"]) # disp = make_dispatcher(mdata["fp_machine"]) diff --git a/dpgen/data/reaction.py b/dpgen/data/reaction.py index adbd453ac..c243f4fbb 100644 --- a/dpgen/data/reaction.py +++ b/dpgen/data/reaction.py @@ -18,7 +18,7 @@ from dpgen.dispatcher.Dispatcher import make_submission_compat from dpgen.generator.run import create_path, make_fp_task_name from dpgen.remote.decide_machine import convert_mdata -from dpgen.util import normalize, sepline +from dpgen.util import load_file, normalize, sepline from .arginfo import init_reaction_jdata_arginfo @@ -214,20 +214,8 @@ def convert_data(jdata): def gen_init_reaction(args): - try: - import ruamel - from monty.serialization import loadfn - - warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning) - jdata = loadfn(args.PARAM) - if args.MACHINE is not None: - mdata = loadfn(args.MACHINE) - except Exception: - with open(args.PARAM) as fp: - jdata = json.load(fp) - if args.MACHINE is not None: - with open(args.MACHINE) as fp: - mdata = json.load(fp) + jdata = load_file(args.PARAM) + mdata = load_file(args.MACHINE) jdata_arginfo = init_reaction_jdata_arginfo() jdata = normalize(jdata_arginfo, jdata) diff --git a/dpgen/data/surf.py b/dpgen/data/surf.py index 4a61ac418..284a0e97f 100644 --- a/dpgen/data/surf.py +++ b/dpgen/data/surf.py @@ -29,6 +29,7 @@ from dpgen.dispatcher.Dispatcher import make_submission_compat from dpgen.generator.lib.utils import symlink_user_forward_files from dpgen.remote.decide_machine import convert_mdata +from dpgen.util import load_file def create_path(path): @@ -602,26 +603,14 @@ def run_vasp_relax(jdata, mdata): def gen_init_surf(args): - try: - import ruamel - from monty.serialization import loadfn - - warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning) - jdata = loadfn(args.PARAM) - if args.MACHINE is not None: - mdata = loadfn(args.MACHINE) - except Exception: - with open(args.PARAM) as fp: - jdata = json.load(fp) - if args.MACHINE is not None: - with open(args.MACHINE) as fp: - mdata = json.load(fp) + jdata = load_file(args.PARAM) out_dir = out_dir_name(jdata) jdata["out_dir"] = out_dir dlog.info("# working dir %s" % out_dir) if args.MACHINE is not None: + mdata = load_file(args.MACHINE) # Decide a proper machine mdata = convert_mdata(mdata, ["fp"]) # disp = make_dispatcher(mdata["fp_machine"]) diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 40dcfbd78..3a8b65147 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -92,6 +92,7 @@ from dpgen.util import ( convert_training_data_to_hdf5, expand_sys_str, + load_file, normalize, sepline, set_directory, @@ -4494,18 +4495,8 @@ def set_version(mdata): def run_iter(param_file, machine_file): - try: - import ruamel - from monty.serialization import dumpfn, loadfn - - warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning) - jdata = loadfn(param_file) - mdata = loadfn(machine_file) - except Exception: - with open(param_file) as fp: - jdata = json.load(fp) - with open(machine_file) as fp: - mdata = json.load(fp) + jdata = load_file(param_file) + mdata = load_file(machine_file) jdata_arginfo = run_jdata_arginfo() jdata = normalize(jdata_arginfo, jdata, strict_check=False) @@ -4513,6 +4504,7 @@ def run_iter(param_file, machine_file): update_mass_map(jdata) if jdata.get("pretty_print", False): + from monty.serialization import dumpfn # assert(jdata["pretty_format"] in ['json','yaml']) fparam = ( SHORT_CMD diff --git a/dpgen/simplify/simplify.py b/dpgen/simplify/simplify.py index 3124fe115..84ebf6457 100644 --- a/dpgen/simplify/simplify.py +++ b/dpgen/simplify/simplify.py @@ -46,7 +46,7 @@ train_task_fmt, ) from dpgen.remote.decide_machine import convert_mdata -from dpgen.util import expand_sys_str, normalize, sepline +from dpgen.util import expand_sys_str, load_file, normalize, sepline from .arginfo import simplify_jdata_arginfo @@ -433,19 +433,8 @@ def run_iter(param_file, machine_file): 07 run_fp (same as generator) 08 post_fp (same as generator) """ - # TODO: function of handling input json should be combined as one function - try: - import ruamel - from monty.serialization import loadfn - - warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning) - jdata = loadfn(param_file) - mdata = loadfn(machine_file) - except Exception: - with open(param_file) as fp: - jdata = json.load(fp) - with open(machine_file) as fp: - mdata = json.load(fp) + jdata = load_file(param_file) + mdata = load_file(machine_file) jdata_arginfo = simplify_jdata_arginfo() jdata = normalize(jdata_arginfo, jdata) diff --git a/dpgen/util.py b/dpgen/util.py index 2c7bd3a26..6f48bffb9 100644 --- a/dpgen/util.py +++ b/dpgen/util.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import json import os +import warnings from contextlib import ( contextmanager, ) @@ -175,3 +176,35 @@ def set_directory(path: Path): yield finally: os.chdir(cwd) + + +def load_file(filename: Union[str, os.PathLike]) -> dict: + """Load data from a JSON or YAML file. + + Parameters + ---------- + filename : str or os.PathLike + The filename to load data from, whose suffix should be .json, .yaml, or .yml + + Returns + ------- + dict + The data loaded from the file + + Raises + ------ + ValueError + If the file format is not supported + """ + filename = str(filename) + if filename.endswith(".json"): + with open(filename, "r") as fp: + data = json.load(fp) + elif filename.endswith(".yaml") or filename.endswith(".yml"): + from ruamel.yaml import YAML + yaml = YAML(typ="safe", pure=True) + with open(filename, "r") as fp: + data = yaml.load(fp) + else: + raise ValueError(f"Unsupported file format: {filename}") + return data diff --git a/tests/sample.json b/tests/sample.json new file mode 100644 index 000000000..c068f91ac --- /dev/null +++ b/tests/sample.json @@ -0,0 +1 @@ +{"aa": "bb"} \ No newline at end of file diff --git a/tests/sample.yaml b/tests/sample.yaml new file mode 100644 index 000000000..9f408e971 --- /dev/null +++ b/tests/sample.yaml @@ -0,0 +1 @@ +aa: bb \ No newline at end of file diff --git a/tests/test_load_file.py b/tests/test_load_file.py new file mode 100644 index 000000000..05b7024c5 --- /dev/null +++ b/tests/test_load_file.py @@ -0,0 +1,17 @@ +import unittest +from pathlib import Path + +from dpgen.util import load_file + +this_directory = Path(__file__).parent + +class TestLoadFile(unittest.TestCase): + def test_load_json_file(self): + ref = {"aa": "bb"} + jdata = load_file(this_directory / "sample.json") + self.assertEqual(jdata, ref) + + def test_load_yaml_file(self): + ref = {"aa": "bb"} + jdata = load_file(this_directory / "sample.yaml") + self.assertEqual(jdata, ref) \ No newline at end of file From dcdcf43fcddafb4615387c5f7137d40aab55c8b9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Jun 2023 00:52:18 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpgen/data/gen.py | 2 -- dpgen/data/reaction.py | 2 -- dpgen/data/surf.py | 2 -- dpgen/generator/run.py | 1 + dpgen/simplify/simplify.py | 2 -- dpgen/util.py | 8 ++++---- tests/test_load_file.py | 5 +++-- 7 files changed, 8 insertions(+), 14 deletions(-) diff --git a/dpgen/data/gen.py b/dpgen/data/gen.py index 2a4f3ece0..6f7856cf0 100644 --- a/dpgen/data/gen.py +++ b/dpgen/data/gen.py @@ -2,13 +2,11 @@ import argparse import glob -import json import os import re import shutil import subprocess as sp import sys -import warnings import dpdata import numpy as np diff --git a/dpgen/data/reaction.py b/dpgen/data/reaction.py index c243f4fbb..a76766846 100644 --- a/dpgen/data/reaction.py +++ b/dpgen/data/reaction.py @@ -7,10 +7,8 @@ """ import glob -import json import os import random -import warnings import dpdata diff --git a/dpgen/data/surf.py b/dpgen/data/surf.py index 284a0e97f..75b2281d3 100644 --- a/dpgen/data/surf.py +++ b/dpgen/data/surf.py @@ -2,13 +2,11 @@ import argparse import glob -import json import os import re import shutil import subprocess as sp import sys -import warnings import numpy as np from ase.build import general_surface diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py index 3a8b65147..0963ebdd1 100644 --- a/dpgen/generator/run.py +++ b/dpgen/generator/run.py @@ -4505,6 +4505,7 @@ def run_iter(param_file, machine_file): if jdata.get("pretty_print", False): from monty.serialization import dumpfn + # assert(jdata["pretty_format"] in ['json','yaml']) fparam = ( SHORT_CMD diff --git a/dpgen/simplify/simplify.py b/dpgen/simplify/simplify.py index 84ebf6457..2cf610e91 100644 --- a/dpgen/simplify/simplify.py +++ b/dpgen/simplify/simplify.py @@ -9,11 +9,9 @@ 02: fp (optional, if the original dataset do not have fp data, same as generator) """ import glob -import json import logging import os import queue -import warnings from collections import defaultdict from typing import List, Union diff --git a/dpgen/util.py b/dpgen/util.py index 6f48bffb9..6cdc28982 100644 --- a/dpgen/util.py +++ b/dpgen/util.py @@ -1,7 +1,6 @@ #!/usr/bin/env python import json import os -import warnings from contextlib import ( contextmanager, ) @@ -180,7 +179,7 @@ def set_directory(path: Path): def load_file(filename: Union[str, os.PathLike]) -> dict: """Load data from a JSON or YAML file. - + Parameters ---------- filename : str or os.PathLike @@ -198,12 +197,13 @@ def load_file(filename: Union[str, os.PathLike]) -> dict: """ filename = str(filename) if filename.endswith(".json"): - with open(filename, "r") as fp: + with open(filename) as fp: data = json.load(fp) elif filename.endswith(".yaml") or filename.endswith(".yml"): from ruamel.yaml import YAML + yaml = YAML(typ="safe", pure=True) - with open(filename, "r") as fp: + with open(filename) as fp: data = yaml.load(fp) else: raise ValueError(f"Unsupported file format: {filename}") diff --git a/tests/test_load_file.py b/tests/test_load_file.py index 05b7024c5..583d9c1aa 100644 --- a/tests/test_load_file.py +++ b/tests/test_load_file.py @@ -5,13 +5,14 @@ this_directory = Path(__file__).parent + class TestLoadFile(unittest.TestCase): def test_load_json_file(self): ref = {"aa": "bb"} jdata = load_file(this_directory / "sample.json") self.assertEqual(jdata, ref) - + def test_load_yaml_file(self): ref = {"aa": "bb"} jdata = load_file(this_directory / "sample.yaml") - self.assertEqual(jdata, ref) \ No newline at end of file + self.assertEqual(jdata, ref) From c079947769913b1c247d1fa5581b8ed730de7e3c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 21 Jun 2023 21:52:35 -0400 Subject: [PATCH 3/4] use load_file in a test --- tests/tools/test_convert_mdata.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tools/test_convert_mdata.py b/tests/tools/test_convert_mdata.py index 57e15b52c..90bcecc67 100644 --- a/tests/tools/test_convert_mdata.py +++ b/tests/tools/test_convert_mdata.py @@ -7,6 +7,7 @@ sys.path.insert(0, os.path.join(test_dir, "..")) __package__ = "tools" from dpgen.remote.decide_machine import convert_mdata +from dpgen.util import load_file from .context import setUpModule # noqa: F401 @@ -15,7 +16,7 @@ class TestConvertMdata(unittest.TestCase): machine_file = "machine_fp_single.json" def test_convert_mdata(self): - mdata = json.load(open(self.machine_file)) + mdata = load_file(self.machine_file) mdata = convert_mdata(mdata, ["fp"]) self.assertEqual(mdata["fp_command"], "vasp_std") self.assertEqual(mdata["fp_group_size"], 8) From a330e3cbad25095fd64873634c2b59c4871d6ff9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Jun 2023 01:52:49 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tools/test_convert_mdata.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tools/test_convert_mdata.py b/tests/tools/test_convert_mdata.py index 90bcecc67..f5d76c7a0 100644 --- a/tests/tools/test_convert_mdata.py +++ b/tests/tools/test_convert_mdata.py @@ -1,4 +1,3 @@ -import json import os import sys import unittest