Skip to content

Commit

Permalink
argcheck: restrict the type of elements in a list (#179)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Oct 25, 2023
1 parent 16f24d6 commit faf582e
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 15 deletions.
25 changes: 16 additions & 9 deletions dpgen2/entrypoint/args.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import textwrap
from typing import (
List,
)

import dargs
from dargs import (
Expand Down Expand Up @@ -53,7 +56,7 @@ def dp_dist_train_args():
doc=doc_config,
),
Argument(
"template_script", [list, str], optional=False, doc=doc_template_script
"template_script", [List[str], str], optional=False, doc=doc_template_script
),
Argument("student_model_path", str, optional=True, doc=dock_student_model_path),
]
Expand All @@ -76,11 +79,11 @@ def dp_train_args():
),
Argument("numb_models", int, optional=True, default=4, doc=doc_numb_models),
Argument(
"template_script", [list, str], optional=False, doc=doc_template_script
"template_script", [List[str], str], optional=False, doc=doc_template_script
),
Argument(
"init_models_paths",
list,
List[str],
optional=True,
default=None,
doc=doc_init_models_paths,
Expand Down Expand Up @@ -175,7 +178,7 @@ def lmp_args():
),
Argument(
"convergence",
list,
dict,
[],
[variant_conv()],
optional=False,
Expand All @@ -191,7 +194,7 @@ def lmp_args():
doc=doc_configuration,
alias=["configuration"],
),
Argument("stages", list, optional=False, doc=doc_stages),
Argument("stages", List[List[dict]], optional=False, doc=doc_stages),
]


Expand Down Expand Up @@ -266,8 +269,8 @@ def input_args():
doc_init_sys = "The inital data systems"

return [
Argument("type_map", list, optional=False, doc=doc_type_map),
Argument("mass_map", list, optional=False, doc=doc_mass_map),
Argument("type_map", List[str], optional=False, doc=doc_type_map),
Argument("mass_map", List[float], optional=False, doc=doc_mass_map),
Argument(
"init_data_prefix",
str,
Expand All @@ -280,7 +283,11 @@ def input_args():
"do_finetune", bool, optional=True, default=False, doc=doc_do_finetune
),
Argument(
"init_data_sys", [list, str], optional=False, default=None, doc=doc_init_sys
"init_data_sys",
[List[str], str],
optional=False,
default=None,
doc=doc_init_sys,
),
]

Expand Down Expand Up @@ -479,7 +486,7 @@ def submit_args(default_step_config=normalize_step_dict({})):
),
Argument(
"upload_python_packages",
[list, str],
[List[str], str],
optional=True,
default=None,
doc=doc_upload_python_packages,
Expand Down
3 changes: 2 additions & 1 deletion examples/almg/input.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
"bohrium_config": {
"username": "__write your bohrium username",
"password": "__write your bohrium password",
"project_id" : "__write your bohrium project ID",
"_project_id" : "__write your bohrium project ID",
"project_id": 123456,
"_comment" : "all"
},

Expand Down
3 changes: 2 additions & 1 deletion examples/chno/input.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
"bohrium_config": {
"username": "__write your bohrium username",
"password": "__write your bohrium password",
"project_id" : "__write your bohrium project ID",
"_project_id" : "__write your bohrium project ID",
"project_id" : 123456,
"_comment" : "all"
},

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ docs = [
'myst_parser',
'deepmodeling_sphinx',
'sphinx-argparse',
"dargs>=0.3.1",
"dargs>=0.4.1",
]
test = [
'fakegaussian>=0.0.3',
Expand Down
29 changes: 29 additions & 0 deletions tests/check_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import json
import unittest
from pathlib import (
Path,
)

from dpgen2.entrypoint.args import (
normalize,
)

p_examples = Path(__file__).parent.parent / "examples"

input_files = (
p_examples / "almg" / "input.json",
# p_examples / "almg" / "input-v005.json",
# p_examples / "almg" / "dp_template.json",
p_examples / "ch4" / "input_dist.json",
# p_examples / "chno" / "dpa_manyi.json",
p_examples / "chno" / "input.json",
)


class TestExamples(unittest.TestCase):
def test_arguments(self):
for fn in input_files:
with self.subTest(fn=fn):
with open(fn) as f:
jdata = json.load(f)
normalize(jdata)
6 changes: 3 additions & 3 deletions tests/entrypoint/test_submit_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test(self):
new_data["bohrium_config"],
None,
)
self.assertEqual(old_data["model_devi_jobs"], new_data["explore"]["stages"])
self.assertEqual(old_data["model_devi_jobs"], new_data["explore"]["stages"][0])
new_data["explore"]["configurations"][0].pop("type")
self.assertEqual(old_data["sys_configs"], new_data["explore"]["configurations"])
self.assertEqual(
Expand Down Expand Up @@ -413,10 +413,10 @@ def test_bohrium(self):
"concentration" : [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
}
],
"stages": [
"stages": [[
{ "_idx": 0, "ensemble": "nvt", "nsteps": 20, "press": [1.0,2.0], "sys_idx": [0], "temps": [50,100], "trj_freq": 10, "n_sample" : 3
}
],
]],
"_comment" : "all"
},
"fp" : {
Expand Down

0 comments on commit faf582e

Please sign in to comment.