From 7c36020d403284ea9ff98e4936ed8d86e401e2b2 Mon Sep 17 00:00:00 2001
From: zjgemi
Date: Mon, 14 Oct 2024 15:41:06 +0800
Subject: [PATCH] add split_last_iter_valid_ratio

Signed-off-by: zjgemi
---
 dpgen2/op/run_dp_train.py | 99 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 99 insertions(+)

diff --git a/dpgen2/op/run_dp_train.py b/dpgen2/op/run_dp_train.py
index dccbc518..2f066ad3 100644
--- a/dpgen2/op/run_dp_train.py
+++ b/dpgen2/op/run_dp_train.py
@@ -1,7 +1,9 @@
 import glob
 import json
 import logging
+import math
 import os
+import random
 import shutil
 from pathlib import (
     Path,
@@ -197,6 +199,10 @@ def execute(
         valid_data = ip["valid_data"]
         iter_data_old_exp = _expand_all_multi_sys_to_sys(iter_data[:-1])
         iter_data_new_exp = _expand_all_multi_sys_to_sys(iter_data[-1:])
+        if config["split_last_iter_valid_ratio"] is not None:
+            train_systems, valid_systems = split_valid(iter_data_new_exp, config["split_last_iter_valid_ratio"])
+            iter_data_new_exp = train_systems
+            valid_data = append_valid_data(config, valid_data, valid_systems)
         iter_data_exp = iter_data_old_exp + iter_data_new_exp
         work_dir = Path(task_name)
         init_model_with_finetune = config["init_model_with_finetune"]
@@ -517,6 +523,7 @@ def training_args():
         doc_head = "Head to use in the multitask training"
         doc_init_model_with_finetune = "Use finetune for init model"
         doc_train_args = "Extra arguments for dp train"
+        doc_split_last_iter_valid_ratio = "Ratio of valid data if split data of last iter"
         return [
             Argument(
                 "command",
@@ -618,6 +625,13 @@ def training_args():
                 default="",
                 doc=doc_train_args,
             ),
+            Argument(
+                "split_last_iter_valid_ratio",
+                float,
+                optional=True,
+                default=None,
+                doc=doc_split_last_iter_valid_ratio,
+            ),
         ]
 
     @staticmethod
@@ -672,4 +686,89 @@ def _expand_all_multi_sys_to_sys(list_multi_sys):
     return all_sys_dirs
 
 
+def split_valid(systems: List[str], valid_ratio: float):
+    """Randomly split the frames of each system into train and valid parts.
+
+    Each system is loaded with dpdata (deepmd/npy/mixed if it has a
+    ``*/real_atom_types.npy`` file, else deepmd/npy) and about
+    ``valid_ratio`` of its frames are drawn at random for validation.
+    Non-empty parts go to ``train_data/<system>`` and ``valid_data/<system>``.
+
+    Returns ``(train_systems, valid_systems)``, the lists of written paths.
+    Raises ``ValueError`` if ``valid_ratio`` is outside ``[0, 1]``.
+    """
+    if not 0.0 <= valid_ratio <= 1.0:
+        raise ValueError("valid_ratio must be in [0, 1], got %s" % valid_ratio)
+    train_systems = []
+    valid_systems = []
+    for system in systems:
+        d = dpdata.MultiSystems()
+        mixed_type = len(glob.glob("%s/*/real_atom_types.npy" % system)) > 0
+        if mixed_type:
+            d.load_systems_from_file(system, fmt="deepmd/npy/mixed")
+        else:
+            d.append(dpdata.LabeledSystem(system, fmt="deepmd/npy"))
+
+        train_multi_systems = dpdata.MultiSystems()
+        valid_multi_systems = dpdata.MultiSystems()
+        for s in d:
+            nframes = len(s)
+            nvalid = math.floor(nframes * valid_ratio)
+            # Stochastic rounding: the expected count is nframes * valid_ratio.
+            if random.random() < nframes * valid_ratio - nvalid:
+                nvalid += 1
+            valid_indices = random.sample(range(nframes), nvalid)
+            # sorted(): set iteration order is unspecified; keep frame order.
+            train_indices = sorted(set(range(nframes)).difference(valid_indices))
+            if valid_indices:
+                valid_multi_systems.append(s.sub_system(valid_indices))
+            if train_indices:
+                train_multi_systems.append(s.sub_system(train_indices))
+
+        # Not os.path.join: an absolute ``system`` must stay below the prefix.
+        if len(train_multi_systems) > 0:
+            target = "train_data/" + system
+            _dump_split(train_multi_systems, target, mixed_type)
+            train_systems.append(target)
+        if len(valid_multi_systems) > 0:
+            target = "valid_data/" + system
+            _dump_split(valid_multi_systems, target, mixed_type)
+            valid_systems.append(target)
+
+    return train_systems, valid_systems
+
+
+def _dump_split(multi_systems, target, mixed_type):
+    """Write one part of a split system to the directory ``target``."""
+    if not mixed_type:
+        multi_systems[0].to_deepmd_npy(target)
+        return
+    # The multisystem is loaded from one dir, thus we can safely keep one dir
+    tmp_dir = "%s.tmp" % target
+    multi_systems.to_deepmd_npy_mixed(tmp_dir)
+    fs = os.listdir(tmp_dir)
+    if len(fs) != 1:
+        raise RuntimeError("expected exactly one entry in %s" % tmp_dir)
+    os.rename(os.path.join(tmp_dir, fs[0]), target)
+    os.rmdir(tmp_dir)
+
+
+def append_valid_data(config, valid_data, valid_systems):
+    """Return ``valid_data`` extended by ``valid_systems``.
+
+    For multitask training ``valid_data`` is treated as a dict and the
+    systems are added under ``config["head"]``; otherwise it is a list.
+    The inputs are not modified; ``valid_data`` itself is returned when
+    ``valid_systems`` is empty.
+    """
+    if not valid_systems:
+        return valid_data
+    if config["multitask"]:
+        head = config["head"]
+        merged = dict(valid_data) if valid_data else {}
+        merged[head] = list(merged.get(head, [])) + list(valid_systems)
+        return merged
+    return (list(valid_data) if valid_data else []) + list(valid_systems)
+
+
 config_args = RunDPTrain.training_args