Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

load files according to filename suffix #1255

Merged
merged 4 commits into from
Jun 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions dpgen/data/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@

import argparse
import glob
import json
import os
import re
import shutil
import subprocess as sp
import sys
import warnings

import dpdata
import numpy as np
Expand All @@ -33,6 +31,7 @@
from dpgen.generator.lib.utils import symlink_user_forward_files
from dpgen.generator.lib.vasp import incar_upper
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import load_file


def create_path(path, back=False):
Expand Down Expand Up @@ -1465,22 +1464,9 @@ def run_abacus_md(jdata, mdata):


def gen_init_bulk(args):
try:
import ruamel
from monty.serialization import loadfn

warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning)
jdata = loadfn(args.PARAM)
if args.MACHINE is not None:
mdata = loadfn(args.MACHINE)
except Exception:
with open(args.PARAM) as fp:
jdata = json.load(fp)
if args.MACHINE is not None:
with open(args.MACHINE) as fp:
mdata = json.load(fp)

jdata = load_file(args.PARAM)
if args.MACHINE is not None:
mdata = load_file(args.MACHINE)
# Selecting a proper machine
mdata = convert_mdata(mdata, ["fp"])
# disp = make_dispatcher(mdata["fp_machine"])
Expand Down
20 changes: 3 additions & 17 deletions dpgen/data/reaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,16 @@
"""

import glob
import json
import os
import random
import warnings

import dpdata

from dpgen import dlog
from dpgen.dispatcher.Dispatcher import make_submission_compat
from dpgen.generator.run import create_path, make_fp_task_name
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import normalize, sepline
from dpgen.util import load_file, normalize, sepline

from .arginfo import init_reaction_jdata_arginfo

Expand Down Expand Up @@ -214,20 +212,8 @@ def convert_data(jdata):


def gen_init_reaction(args):
try:
import ruamel
from monty.serialization import loadfn

warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning)
jdata = loadfn(args.PARAM)
if args.MACHINE is not None:
mdata = loadfn(args.MACHINE)
except Exception:
with open(args.PARAM) as fp:
jdata = json.load(fp)
if args.MACHINE is not None:
with open(args.MACHINE) as fp:
mdata = json.load(fp)
jdata = load_file(args.PARAM)
mdata = load_file(args.MACHINE)

jdata_arginfo = init_reaction_jdata_arginfo()
jdata = normalize(jdata_arginfo, jdata)
Expand Down
19 changes: 3 additions & 16 deletions dpgen/data/surf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@

import argparse
import glob
import json
import os
import re
import shutil
import subprocess as sp
import sys
import warnings

import numpy as np
from ase.build import general_surface
Expand All @@ -29,6 +27,7 @@
from dpgen.dispatcher.Dispatcher import make_submission_compat
from dpgen.generator.lib.utils import symlink_user_forward_files
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import load_file


def create_path(path):
Expand Down Expand Up @@ -602,26 +601,14 @@ def run_vasp_relax(jdata, mdata):


def gen_init_surf(args):
try:
import ruamel
from monty.serialization import loadfn

warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning)
jdata = loadfn(args.PARAM)
if args.MACHINE is not None:
mdata = loadfn(args.MACHINE)
except Exception:
with open(args.PARAM) as fp:
jdata = json.load(fp)
if args.MACHINE is not None:
with open(args.MACHINE) as fp:
mdata = json.load(fp)
jdata = load_file(args.PARAM)

out_dir = out_dir_name(jdata)
jdata["out_dir"] = out_dir
dlog.info("# working dir %s" % out_dir)

if args.MACHINE is not None:
mdata = load_file(args.MACHINE)
# Decide a proper machine
mdata = convert_mdata(mdata, ["fp"])
# disp = make_dispatcher(mdata["fp_machine"])
Expand Down
17 changes: 5 additions & 12 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
from dpgen.util import (
convert_training_data_to_hdf5,
expand_sys_str,
load_file,
normalize,
sepline,
set_directory,
Expand Down Expand Up @@ -4494,25 +4495,17 @@ def set_version(mdata):


def run_iter(param_file, machine_file):
try:
import ruamel
from monty.serialization import dumpfn, loadfn

warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning)
jdata = loadfn(param_file)
mdata = loadfn(machine_file)
except Exception:
with open(param_file) as fp:
jdata = json.load(fp)
with open(machine_file) as fp:
mdata = json.load(fp)
jdata = load_file(param_file)
mdata = load_file(machine_file)

jdata_arginfo = run_jdata_arginfo()
jdata = normalize(jdata_arginfo, jdata, strict_check=False)

update_mass_map(jdata)

if jdata.get("pretty_print", False):
from monty.serialization import dumpfn

# assert(jdata["pretty_format"] in ['json','yaml'])
fparam = (
SHORT_CMD
Expand Down
19 changes: 3 additions & 16 deletions dpgen/simplify/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
02: fp (optional, if the original dataset do not have fp data, same as generator)
"""
import glob
import json
import logging
import os
import queue
import warnings
from collections import defaultdict
from typing import List, Union

Expand Down Expand Up @@ -46,7 +44,7 @@
train_task_fmt,
)
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import expand_sys_str, normalize, sepline
from dpgen.util import expand_sys_str, load_file, normalize, sepline

from .arginfo import simplify_jdata_arginfo

Expand Down Expand Up @@ -433,19 +431,8 @@ def run_iter(param_file, machine_file):
07 run_fp (same as generator)
08 post_fp (same as generator)
"""
# TODO: function of handling input json should be combined as one function
try:
import ruamel
from monty.serialization import loadfn

warnings.simplefilter("ignore", ruamel.yaml.error.MantissaNoDotYAML1_1Warning)
jdata = loadfn(param_file)
mdata = loadfn(machine_file)
except Exception:
with open(param_file) as fp:
jdata = json.load(fp)
with open(machine_file) as fp:
mdata = json.load(fp)
jdata = load_file(param_file)
mdata = load_file(machine_file)

jdata_arginfo = simplify_jdata_arginfo()
jdata = normalize(jdata_arginfo, jdata)
Expand Down
33 changes: 33 additions & 0 deletions dpgen/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,36 @@ def set_directory(path: Path):
yield
finally:
os.chdir(cwd)


def load_file(filename: Union[str, os.PathLike]) -> dict:
"""Load data from a JSON or YAML file.

Parameters
----------
filename : str or os.PathLike
The filename to load data from, whose suffix should be .json, .yaml, or .yml

Returns
-------
dict
The data loaded from the file

Raises
------
ValueError
If the file format is not supported
"""
filename = str(filename)
if filename.endswith(".json"):
with open(filename) as fp:
data = json.load(fp)
elif filename.endswith(".yaml") or filename.endswith(".yml"):
from ruamel.yaml import YAML

yaml = YAML(typ="safe", pure=True)
with open(filename) as fp:
data = yaml.load(fp)
else:
raise ValueError(f"Unsupported file format: {filename}")
return data
1 change: 1 addition & 0 deletions tests/sample.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"aa": "bb"}
1 change: 1 addition & 0 deletions tests/sample.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
aa: bb
18 changes: 18 additions & 0 deletions tests/test_load_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import unittest
from pathlib import Path

from dpgen.util import load_file

this_directory = Path(__file__).parent


class TestLoadFile(unittest.TestCase):
def test_load_json_file(self):
ref = {"aa": "bb"}
jdata = load_file(this_directory / "sample.json")
self.assertEqual(jdata, ref)

def test_load_yaml_file(self):
ref = {"aa": "bb"}
jdata = load_file(this_directory / "sample.yaml")
self.assertEqual(jdata, ref)
4 changes: 2 additions & 2 deletions tests/tools/test_convert_mdata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import os
import sys
import unittest
Expand All @@ -7,6 +6,7 @@
sys.path.insert(0, os.path.join(test_dir, ".."))
__package__ = "tools"
from dpgen.remote.decide_machine import convert_mdata
from dpgen.util import load_file

from .context import setUpModule # noqa: F401

Expand All @@ -15,7 +15,7 @@ class TestConvertMdata(unittest.TestCase):
machine_file = "machine_fp_single.json"

def test_convert_mdata(self):
mdata = json.load(open(self.machine_file))
mdata = load_file(self.machine_file)
mdata = convert_mdata(mdata, ["fp"])
self.assertEqual(mdata["fp_command"], "vasp_std")
self.assertEqual(mdata["fp_group_size"], 8)
Expand Down