feat: support the deepmd-kit v3 #207

Merged: 56 commits, Mar 30, 2024

Commits
10773e9  specify machine group to run CI test (zjgemi, Aug 21, 2023)
820acc6  Merge branch 'master' of github.com:zjgemi/dpgen2 (zjgemi, Sep 22, 2023)
0a42f2b  add support for deepmd-pytorch (zjgemi, Sep 22, 2023)
90f3b93  add wf name to config; log error before raise Error for better format (zjgemi, Sep 25, 2023)
0c633c6  Merge branch 'wfname-and-exception' into deepmd-pytorch (zjgemi, Sep 25, 2023)
5ef9a92  add support for fpop ABACUS (zjgemi, Oct 11, 2023)
20d96cc  Merge branch 'integrate-fpop' into deepmd-pytorch (zjgemi, Oct 11, 2023)
aa54843  merge test.yml (zjgemi, Oct 11, 2023)
164c45a  fix merge bug (zjgemi, Oct 11, 2023)
a7ccc00  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 11, 2023)
81ffa0b  add args method to FpOpAbacusInputs and RunFpOpAbacus; try to import … (zjgemi, Oct 12, 2023)
4cea535  Merge branch 'integrate-fpop' of github.com:zjgemi/dpgen2 into integr… (zjgemi, Oct 12, 2023)
1dd5cf0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 12, 2023)
7f61ac3  Merge branch 'integrate-fpop' into deepmd-pytorch (zjgemi, Oct 12, 2023)
2093d40  fix configuration_prefix (zjgemi, Oct 12, 2023)
43e55d0  Merge branch 'integrate-fpop' into deepmd-pytorch (zjgemi, Oct 12, 2023)
1fbef60  fpop does not support multisystem as confs (zjgemi, Oct 13, 2023)
bfe4217  Merge branch 'integrate-fpop' into deepmd-pytorch (zjgemi, Oct 13, 2023)
a07a6d5  remove atom types with 0 atom from type map, for abacus need pp_files… (zjgemi, Oct 13, 2023)
496007b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 13, 2023)
05961e4  Merge branch 'integrate-fpop' into deepmd-pytorch (zjgemi, Oct 13, 2023)
765446c  try to import fpop (zjgemi, Oct 13, 2023)
781d2a8  Merge branch 'integrate-fpop' of github.com:zjgemi/dpgen2 into integr… (zjgemi, Oct 13, 2023)
703c1e5  Merge branch 'integrate-fpop' into deepmd-pytorch (zjgemi, Oct 13, 2023)
cca194f  Merge branch 'master' into deepmd-pytorch (zjgemi, Oct 19, 2023)
a4e85f4  resolve conflicts (zjgemi, Oct 19, 2023)
5198337  add restart of dp train for 2 cases: (zjgemi, Nov 1, 2023)
10645b8  Merge branch 'master' into deepmd-pytorch (zjgemi, Dec 1, 2023)
3dc44f9  support valid data (zjgemi, Dec 6, 2023)
d764c22  support multitask (zjgemi, Dec 6, 2023)
8617fa6  Use finetune for init model (zjgemi, Dec 8, 2023)
44be244  fix v undefined (zjgemi, Dec 8, 2023)
356b9e3  fix KeyError validation_data (zjgemi, Dec 8, 2023)
dac14df  allow executor in debug mode (zjgemi, Jan 26, 2024)
7a307ef  Merge branch 'master' into deepmd-pytorch (zjgemi, Jan 31, 2024)
c2ba028  transfer deepmd_pytorch to deepmd-kit v3 (zjgemi, Mar 15, 2024)
77002b1  Merge branch 'master' into deepmd-v3 (zjgemi, Mar 16, 2024)
10ba65a  fix merge (zjgemi, Mar 17, 2024)
a409d47  model.pt -> model.ckpt.pt (zjgemi, Mar 22, 2024)
201812d  handle extension of DP models correctly (zjgemi, Mar 22, 2024)
0da3190  sort reused steps whose startedAt are identical by key (zjgemi, Mar 23, 2024)
d9f9255  sort reused steps whose startedAt are identical by key in resubmit (zjgemi, Mar 24, 2024)
8ab3605  fix dpgen and multitask-dpgen for deepmd v3 (zjgemi, Mar 27, 2024)
2243bc7  Merge branch 'sort-keys' into deepmd-v3 (zjgemi, Mar 27, 2024)
16ce26e  fix multitask init model (zjgemi, Mar 28, 2024)
3221a6b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 28, 2024)
501890a  fix UT (zjgemi, Mar 28, 2024)
23bf3b9  fix UT and pyright (zjgemi, Mar 28, 2024)
cf59e84  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 28, 2024)
08a712b  fix UT (zjgemi, Mar 28, 2024)
ad5f24a  fix UT (zjgemi, Mar 28, 2024)
bc6eba3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 28, 2024)
0cbbbcc  fix UT (zjgemi, Mar 28, 2024)
a6f3462  Merge branch 'deepmd-v3' of github.com:zjgemi/dpgen2 into deepmd-v3 (zjgemi, Mar 28, 2024)
02897f9  fix system prefix and expand_sys_str (zjgemi, Mar 29, 2024)
1c12011  add deepmd v3 examples (zjgemi, Mar 29, 2024)

Files changed
dpgen2/constants.py (3 changes: 2 additions & 1 deletion)

@@ -3,7 +3,8 @@
 train_script_name = "input.json"
 train_log_name = "train.log"
 model_name_pattern = "model.%03d.pb"
-model_name_match_pattern = r"model\.[0-9]{3,}\.pb"
+pytorch_model_name_pattern = "model.%03d.pth"
+model_name_match_pattern = r"model\.[0-9]{3,}(\.pb|\.pth)"
 lmp_index_pattern = "%06d"
 lmp_task_pattern = "task." + lmp_index_pattern
 lmp_conf_name = "conf.lmp"
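
A quick way to see what the widened pattern accepts (a minimal sketch, not part of the PR; the file names are made up):

    import re

    # The pattern from dpgen2/constants.py after this change:
    model_name_match_pattern = r"model\.[0-9]{3,}(\.pb|\.pth)"

    # TensorFlow (.pb) and PyTorch (.pth) model names both match now.
    for name in ["model.000.pb", "model.001.pth", "model.000.ckpt"]:
        print(name, bool(re.fullmatch(model_name_match_pattern, name)))
    # model.000.pb True / model.001.pth True / model.000.ckpt False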
dpgen2/entrypoint/args.py (45 changes: 44 additions & 1 deletion)

@@ -270,6 +270,14 @@ def input_args():
     doc_do_finetune = textwrap.dedent(doc_do_finetune)
     doc_init_data_prefix = "The prefix of initial data systems"
     doc_init_sys = "The initial data systems"
+    doc_multitask = "Do multitask training"
+    doc_head = "Head to use in the multitask training"
+    doc_multi_init_data = (
+        "The initial data for multitask. It should be a dict whose keys are task names and each value is "
+        "a dict containing fields `prefix` and `sys` for the initial data of each task"
+    )
+    doc_valid_data_prefix = "The prefix of validation data systems"
+    doc_valid_sys = "The validation data systems"
 
     return [
         Argument("type_map", List[str], optional=False, doc=doc_type_map),
Expand All @@ -288,10 +296,45 @@ def input_args():
Argument(
"init_data_sys",
[List[str], str],
optional=False,
optional=True,
default=None,
doc=doc_init_sys,
),
Argument(
"multitask",
bool,
optional=True,
default=False,
doc=doc_multitask,
),
Argument(
"head",
str,
optional=True,
default=None,
doc=doc_head,
),
Argument(
"multi_init_data",
dict,
optional=True,
default=None,
doc=doc_multi_init_data,
),
Argument(
"valid_data_prefix",
str,
optional=True,
default=None,
doc=doc_valid_data_prefix,
),
Argument(
"valid_data_sys",
[List[str], str],
optional=True,
default=None,
doc=doc_valid_sys,
),
]


Expand Down
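
Taken together, the new arguments enable multitask training and validation data in the `inputs` section. A sketch of a config using them (all paths, task names, and the head name are hypothetical):

    inputs = {
        "type_map": ["O", "H"],
        "multitask": True,
        "head": "water",
        # One entry per task; `sys` may be a str or a list of str,
        # and `prefix` is joined onto each entry when it is not None.
        "multi_init_data": {
            "water": {"prefix": "init_data", "sys": ["water/sys1", "water/sys2"]},
            "ice": {"prefix": "init_data", "sys": "ice/sys1"},
        },
        # Optional validation systems, resolved like the init data.
        "valid_data_prefix": "valid_data",
        "valid_data_sys": ["water/valid"],
    }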
dpgen2/entrypoint/submit.py (60 changes: 51 additions & 9 deletions)

@@ -145,6 +145,7 @@
     collect_data_config: dict = default_config,
     cl_step_config: dict = default_config,
     upload_python_packages: Optional[List[os.PathLike]] = None,
+    valid_data: Optional[S3Artifact] = None,
 ):
     if train_style in ("dp", "dp-dist"):
         prep_run_train_op = PrepRunDPTrain(
@@ -154,6 +155,7 @@
             prep_config=prep_train_config,
             run_config=run_train_config,
             upload_python_packages=upload_python_packages,
+            valid_data=valid_data,
         )
     else:
         raise RuntimeError(f"unknown train_style {train_style}")
@@ -387,6 +389,7 @@
     init_models,
     init_data,
     iter_data,
+    valid_data=None,
 ):
     finetune_optional_parameter = {
         "mixed_type": config["inputs"]["mixed_type"],
@@ -401,6 +404,7 @@
         run_config=run_train_config,
         upload_python_packages=upload_python_packages,
         finetune=True,
+        valid_data=valid_data,
     )
     finetune_step = Step(
         "finetune-step",
@@ -466,6 +470,15 @@
         ]
         upload_python_packages = _upload_python_packages
 
+    valid_data = config["inputs"]["valid_data_sys"]
+    if valid_data is not None:
+        valid_data_prefix = config["inputs"]["valid_data_prefix"]
+        valid_data = [valid_data] if isinstance(valid_data, str) else valid_data
+        assert isinstance(valid_data, list)
+        if valid_data_prefix is not None:
+            valid_data = [os.path.join(valid_data_prefix, ii) for ii in valid_data]
+        valid_data = [expand_sys_str(ii) for ii in valid_data]
+        valid_data = upload_artifact(valid_data)
     concurrent_learning_op = make_concurrent_learning_op(
         train_style,
         explore_style,
@@ -480,6 +493,7 @@
         collect_data_config=collect_data_config,
         cl_step_config=cl_step_config,
         upload_python_packages=upload_python_packages,
+        valid_data=valid_data,
     )
     scheduler = make_naive_exploration_scheduler(config)
 
@@ -500,7 +514,7 @@
             explore_config["teacher_model_path"]
         ), f"No such file: {explore_config['teacher_model_path']}"
         explore_config["teacher_model_path"] = BinaryFileInput(
-            explore_config["teacher_model_path"], "pb"
+            explore_config["teacher_model_path"]
         )
 
     fp_config = {}
@@ -517,15 +531,37 @@
             fp_config["run"]["teacher_model_path"]
         ), f"No such file: {fp_config['run']['teacher_model_path']}"
         fp_config["run"]["teacher_model_path"] = BinaryFileInput(
-            fp_config["run"]["teacher_model_path"], "pb"
+            fp_config["run"]["teacher_model_path"]
         )
 
-    init_data_prefix = config["inputs"]["init_data_prefix"]
-    init_data = config["inputs"]["init_data_sys"]
-    if init_data_prefix is not None:
-        init_data = [os.path.join(init_data_prefix, ii) for ii in init_data]
-    if isinstance(init_data, str):
-        init_data = expand_sys_str(init_data)
+    multitask = config["inputs"]["multitask"]
+    if multitask:
+        head = config["inputs"]["head"]
+        multi_init_data = config["inputs"]["multi_init_data"]
+        init_data = []
+        multi_init_data_idx = {}
+        for k, v in multi_init_data.items():
+            sys = v["sys"]
+            sys = [sys] if isinstance(sys, str) else sys
+            assert isinstance(sys, list)
+            if v["prefix"] is not None:
+                sys = [os.path.join(v["prefix"], ii) for ii in sys]
+            sys = [expand_sys_str(ii) for ii in sys]
+            istart = len(init_data)
+            init_data += sys
+            iend = len(init_data)
+            multi_init_data_idx[k] = list(range(istart, iend))
+        train_config["multitask"] = True
+        train_config["head"] = head
+        train_config["multi_init_data_idx"] = multi_init_data_idx
+        explore_config["head"] = head
+    else:
+        init_data_prefix = config["inputs"]["init_data_prefix"]
+        init_data = config["inputs"]["init_data_sys"]
+        if init_data_prefix is not None:
+            init_data = [os.path.join(init_data_prefix, ii) for ii in init_data]
+        if isinstance(init_data, str):
+            init_data = expand_sys_str(init_data)
     init_data = upload_artifact(init_data)
     iter_data = upload_artifact([])
     if init_models_paths is not None:
@@ -550,6 +586,7 @@
             init_models,
             init_data,
             iter_data,
+            valid_data=valid_data,
         )
 
     init_models = finetune_step.outputs.artifacts["models"]
@@ -734,7 +771,10 @@
 
 def successful_step_keys(wf):
     all_step_keys = []
-    for step in wf.query_step():
+    steps = wf.query_step()
+    # For reused steps whose startedAt are identical, sort them by key
+    steps.sort(key=lambda x: "%s-%s" % (x.startedAt, x.key))
+    for step in steps:
         if step.key is not None and step.phase == "Succeeded":
             all_step_keys.append(step.key)
     return all_step_keys
@@ -868,6 +908,8 @@
             reused_folded_keys[k] = [k]
     reused_keys = sum(reused_folded_keys.values(), [])
     reuse_step = old_wf.query_step(key=reused_keys)
+    # For reused steps whose startedAt are identical, sort them by key
+    reuse_step.sort(key=lambda x: "%s-%s" % (x.startedAt, x.key))
 
     wf = submit_concurrent_learning(
         wf_config,
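
The sort key concatenates `startedAt` and `key` into one string, so steps that started at the same instant (typical for reused steps, which share the reuse time) fall back to lexicographic key order. A standalone sketch with a hypothetical `Step` stand-in:

    from dataclasses import dataclass

    @dataclass
    class Step:  # hypothetical stand-in for a workflow step object
        startedAt: str
        key: str

    steps = [
        Step("2024-03-23T10:00:00Z", "iter-000001--prep-run-train"),
        Step("2024-03-23T10:00:00Z", "iter-000000--prep-run-train"),
        Step("2024-03-23T09:00:00Z", "init--scheduler"),
    ]
    # Identical timestamps tie-break on the key, giving a deterministic order.
    steps.sort(key=lambda x: "%s-%s" % (x.startedAt, x.key))
    print([s.key for s in steps])
    # ['init--scheduler', 'iter-000000--prep-run-train', 'iter-000001--prep-run-train']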
dpgen2/fp/deepmd.py (5 changes: 2 additions & 3 deletions)

@@ -45,9 +45,6 @@
 # global static variables
 deepmd_temp_path = "one_frame_temp"
 
-# global static variables
-deepmd_teacher_model = "teacher_model.pb"
-
 
 class DeepmdInputs:
     @staticmethod
@@ -136,6 +133,8 @@ def run_task(
     def _get_dp_model(self, teacher_model_path: BinaryFileInput):
        from deepmd.infer import DeepPot  # type: ignore
 
+        ext = os.path.splitext(teacher_model_path.file_name)[-1]
+        deepmd_teacher_model = "teacher_model" + ext
         teacher_model_path.save_as_file(deepmd_teacher_model)
         dp = DeepPot(Path(deepmd_teacher_model))
 
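
With the hard-coded `.pb` name gone, the saved teacher model keeps whatever extension the uploaded file had, which lets deepmd-kit v3 PyTorch checkpoints pass through. A small illustration (file names are made up):

    import os

    for name in ["teacher.pb", "teacher.pth", "teacher.ckpt.pt"]:
        ext = os.path.splitext(name)[-1]
        # splitext keeps only the last suffix, so "teacher.ckpt.pt" -> ".pt"
        print("teacher_model" + ext)
    # teacher_model.pb / teacher_model.pth / teacher_model.pt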
dpgen2/op/prep_dp_train.py (19 changes: 14 additions & 5 deletions)

@@ -107,18 +107,27 @@
         )
         return op
 
+    def _set_desc_seed(self, desc):
+        if desc["type"] == "hybrid":
+            for desc in desc["list"]:
+                self._set_desc_seed(desc)
+        elif desc["type"] not in ["dpa1", "dpa2"]:
+            desc["seed"] = random.randrange(sys.maxsize) % (2**32)
+
     def _script_rand_seed(
         self,
         input_dict,
     ):
         jtmp = input_dict.copy()
-        if jtmp["model"]["descriptor"]["type"] == "hybrid":
-            for desc in jtmp["model"]["descriptor"]["list"]:
-                desc["seed"] = random.randrange(sys.maxsize) % (2**32)
+        if "model_dict" in jtmp["model"]:
+            for d in jtmp["model"]["model_dict"].values():
+                if isinstance(d["descriptor"], str):
+                    self._set_desc_seed(jtmp["model"]["shared_dict"][d["descriptor"]])
+                d["fitting_net"]["seed"] = random.randrange(sys.maxsize) % (2**32)
         else:
-            jtmp["model"]["descriptor"]["seed"] = random.randrange(sys.maxsize) % (
+            self._set_desc_seed(jtmp["model"]["descriptor"])
+            jtmp["model"]["fitting_net"]["seed"] = random.randrange(sys.maxsize) % (
                 2**32
             )
-        jtmp["model"]["fitting_net"]["seed"] = random.randrange(sys.maxsize) % (2**32)
         jtmp["training"]["seed"] = random.randrange(sys.maxsize) % (2**32)
         return jtmp
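
The net effect of the seeding changes: hybrid descriptors are seeded per sub-descriptor, `dpa1`/`dpa2` descriptors are left unseeded, and multitask scripts (a `model_dict` whose descriptors are string references into `shared_dict`) get a seed per shared descriptor and per task fitting net. A standalone sketch of that logic, outside the class and with a made-up multitask script:

    import random
    import sys

    def set_desc_seed(desc):
        # Mirrors _set_desc_seed above: recurse into hybrid descriptors,
        # skip dpa1/dpa2, seed everything else.
        if desc["type"] == "hybrid":
            for sub in desc["list"]:
                set_desc_seed(sub)
        elif desc["type"] not in ["dpa1", "dpa2"]:
            desc["seed"] = random.randrange(sys.maxsize) % (2**32)

    script = {
        "model": {
            "shared_dict": {"shared_desc": {"type": "se_e2_a"}},
            "model_dict": {
                "water": {"descriptor": "shared_desc", "fitting_net": {}},
                "ice": {"descriptor": "shared_desc", "fitting_net": {}},
            },
        },
        "training": {},
    }

    # Multitask branch: seed each shared descriptor once per reference,
    # then seed every task's fitting net and the training section.
    for d in script["model"]["model_dict"].values():
        if isinstance(d["descriptor"], str):
            set_desc_seed(script["model"]["shared_dict"][d["descriptor"]])
        d["fitting_net"]["seed"] = random.randrange(sys.maxsize) % (2**32)
    script["training"]["seed"] = random.randrange(sys.maxsize) % (2**32)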