Skip to content

Commit

Permalink
add split_last_iter_valid_ratio
Browse files Browse the repository at this point in the history
Signed-off-by: zjgemi <[email protected]>
  • Loading branch information
zjgemi committed Oct 14, 2024
1 parent 5f263d3 commit 7c36020
Showing 1 changed file with 85 additions and 0 deletions.
85 changes: 85 additions & 0 deletions dpgen2/op/run_dp_train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import glob
import json
import logging
import math
import os
import random
import shutil
from pathlib import (
Path,
Expand Down Expand Up @@ -197,6 +199,10 @@ def execute(
valid_data = ip["valid_data"]
iter_data_old_exp = _expand_all_multi_sys_to_sys(iter_data[:-1])
iter_data_new_exp = _expand_all_multi_sys_to_sys(iter_data[-1:])
if config["split_last_iter_valid_ratio"] is not None:
train_systems, valid_systems = split_valid(iter_data_new_exp, config["split_last_iter_valid_ratio"])
iter_data_new_exp = train_systems
valid_data = append_valid_data(config, valid_data, valid_systems)
iter_data_exp = iter_data_old_exp + iter_data_new_exp
work_dir = Path(task_name)
init_model_with_finetune = config["init_model_with_finetune"]
Expand Down Expand Up @@ -517,6 +523,7 @@ def training_args():
doc_head = "Head to use in the multitask training"
doc_init_model_with_finetune = "Use finetune for init model"
doc_train_args = "Extra arguments for dp train"
doc_split_last_iter_valid_ratio = "Ratio of valid data if split data of last iter"
return [
Argument(
"command",
Expand Down Expand Up @@ -618,6 +625,13 @@ def training_args():
default="",
doc=doc_train_args,
),
Argument(
"split_last_iter_valid_ratio",
float,
optional=True,
default=None,
doc=doc_split_last_iter_valid_ratio,
),
]

@staticmethod
Expand Down Expand Up @@ -672,4 +686,75 @@ def _expand_all_multi_sys_to_sys(list_multi_sys):
return all_sys_dirs


def split_valid(systems: List[str], valid_ratio: float):
    """Randomly split each system into a training part and a validation part.

    Every system directory in ``systems`` is loaded with dpdata, its frames
    are randomly partitioned, and the two parts are written to
    ``"train_data/" + system`` and ``"valid_data/" + system`` respectively.
    A part that ends up with no frames is not written and not reported.

    Parameters
    ----------
    systems : List[str]
        Paths of the system directories to split.
    valid_ratio : float
        Fraction of the frames of each system to move to the validation set.

    Returns
    -------
    train_systems : list of str
        Paths of the written training systems.
    valid_systems : list of str
        Paths of the written validation systems.
    """
    train_systems = []
    valid_systems = []
    for system in systems:
        # A system is treated as mixed-type when any of its sub-directories
        # contains a ``real_atom_types.npy`` file.
        mixed_type = len(glob.glob("%s/*/real_atom_types.npy" % system)) > 0
        d = dpdata.MultiSystems()
        if mixed_type:
            d.load_systems_from_file(system, fmt="deepmd/npy/mixed")
        else:
            d.append(dpdata.LabeledSystem(system, fmt="deepmd/npy"))

        train_multi_systems = dpdata.MultiSystems()
        valid_multi_systems = dpdata.MultiSystems()
        for s in d:
            train_indices, valid_indices = _split_frame_indices(len(s), valid_ratio)
            if valid_indices:
                valid_multi_systems.append(s.sub_system(valid_indices))
            if train_indices:
                train_multi_systems.append(s.sub_system(train_indices))

        if len(train_multi_systems) > 0:
            train_systems.append(
                _dump_split_systems(
                    train_multi_systems, "train_data/" + system, mixed_type
                )
            )
        if len(valid_multi_systems) > 0:
            valid_systems.append(
                _dump_split_systems(
                    valid_multi_systems, "valid_data/" + system, mixed_type
                )
            )

    return train_systems, valid_systems


def _split_frame_indices(nframes: int, valid_ratio: float):
    """Randomly partition ``range(nframes)`` into (train, valid) index lists."""
    # Stochastic rounding: round up with probability equal to the fractional
    # part, so the expected number of validation frames is nframes * valid_ratio.
    nvalid = math.floor(nframes * valid_ratio)
    if random.random() < nframes * valid_ratio - nvalid:
        nvalid += 1
    valid_indices = random.sample(range(nframes), nvalid)
    # Sort explicitly: iteration order of a set is not guaranteed, and the
    # training frames should keep their original relative order.
    train_indices = sorted(set(range(nframes)).difference(valid_indices))
    return train_indices, valid_indices


def _dump_split_systems(multi_systems, target: str, mixed_type: bool) -> str:
    """Write one half of a split to ``target`` and return ``target``."""
    if mixed_type:
        # The multisystem is loaded from one dir, thus we can safely keep one dir
        tmp_dir = "%s.tmp" % target
        multi_systems.to_deepmd_npy_mixed(tmp_dir)  # type: ignore
        fs = os.listdir(tmp_dir)
        assert len(fs) == 1, "expected exactly one system in %s, got %d" % (
            tmp_dir,
            len(fs),
        )
        # Promote the single dumped sub-directory to ``target`` and drop the
        # now-empty temporary directory.
        os.rename(os.path.join(tmp_dir, fs[0]), target)
        os.rmdir(tmp_dir)
    else:
        # Not mixed-type: the container holds the single LabeledSystem that
        # was loaded from the system directory.
        multi_systems[0].to_deepmd_npy(target)  # type: ignore
    return target


def append_valid_data(config, valid_data, valid_systems):
    """Return ``valid_data`` extended with ``valid_systems``.

    The inputs are never modified in place; new containers are built when
    something has to be appended.

    Parameters
    ----------
    config : dict
        Training config. ``config["multitask"]`` selects the layout of the
        validation data; when it is true, ``config["head"]`` names the entry
        that receives the new systems.
    valid_data : list, dict or None
        Existing validation data: a list of systems, or, in the multitask
        case, a dict mapping a head name to a list of systems.
    valid_systems : list
        Systems to append.

    Returns
    -------
    list, dict or None
        ``valid_data`` itself when ``valid_systems`` is empty, otherwise a new
        list (single task) or a new dict (multitask) holding the old and the
        new systems.
    """
    if not valid_systems:
        return valid_data
    if config["multitask"]:
        head = config["head"]
        # Shallow-copy the mapping and rebuild only the list of the touched
        # head, so the caller's dict and lists stay untouched.
        merged = dict(valid_data) if valid_data else {}
        merged[head] = list(merged.get(head) or []) + list(valid_systems)
        return merged
    return (list(valid_data) if valid_data else []) + list(valid_systems)


# Module-level alias: exposes RunDPTrain.training_args (the callable itself,
# not its result) under the name `config_args`.
config_args = RunDPTrain.training_args

0 comments on commit 7c36020

Please sign in to comment.