diff --git a/dpgen2/exploration/task/lmp/lmp_input.py b/dpgen2/exploration/task/lmp/lmp_input.py index e9626154..c2a22b60 100644 --- a/dpgen2/exploration/task/lmp/lmp_input.py +++ b/dpgen2/exploration/task/lmp/lmp_input.py @@ -101,10 +101,17 @@ def make_lmp_input( graph_list = "" for ii in graphs: graph_list += ii + " " - model_devi_file_name = lmp_pimd_model_devi_name % pimd_bead if pimd_bead is not None else lmp_model_devi_name + model_devi_file_name = ( + lmp_pimd_model_devi_name % pimd_bead + if pimd_bead is not None + else lmp_model_devi_name + ) if Version(deepmd_version) < Version("1"): # 0.x - ret += "pair_style deepmd %s ${THERMO_FREQ} %s\n" % (graph_list, model_devi_file_name) + ret += "pair_style deepmd %s ${THERMO_FREQ} %s\n" % ( + graph_list, + model_devi_file_name, + ) else: # 1.x keywords = "" @@ -118,9 +125,10 @@ def make_lmp_input( keywords += "fparam ${ELE_TEMP}" if ele_temp_a is not None: keywords += "aparam ${ELE_TEMP}" - ret += ( - "pair_style deepmd %s out_freq ${THERMO_FREQ} out_file %s %s\n" - % (graph_list, model_devi_file_name, keywords) + ret += "pair_style deepmd %s out_freq ${THERMO_FREQ} out_file %s %s\n" % ( + graph_list, + model_devi_file_name, + keywords, ) ret += "pair_coeff * *\n" ret += "\n" @@ -129,7 +137,9 @@ def make_lmp_input( if trj_seperate_files: ret += "dump 1 all custom ${DUMP_FREQ} traj/*.lammpstrj id type x y z fx fy fz\n" else: - lmp_traj_file_name = lmp_pimd_traj_name % pimd_bead if pimd_bead is not None else lmp_traj_name + lmp_traj_file_name = ( + lmp_pimd_traj_name % pimd_bead if pimd_bead is not None else lmp_traj_name + ) ret += ( "dump 1 all custom ${DUMP_FREQ} %s id type x y z fx fy fz\n" % lmp_traj_file_name diff --git a/dpgen2/exploration/task/lmp_template_task_group.py b/dpgen2/exploration/task/lmp_template_task_group.py index 26ad3b4a..d0b362f5 100644 --- a/dpgen2/exploration/task/lmp_template_task_group.py +++ b/dpgen2/exploration/task/lmp_template_task_group.py @@ -63,7 +63,9 @@ def set_lmp( self.extra_pair_style_args, self.pimd_bead, ) - self.lmp_template = revise_lmp_input_dump(self.lmp_template, self.traj_freq, self.pimd_bead) + self.lmp_template = revise_lmp_input_dump( + self.lmp_template, self.traj_freq, self.pimd_bead + ) if plm_template_fname is not None: self.plm_template = Path(plm_template_fname).read_text().split("\n") self.plm_set = True @@ -150,28 +152,36 @@ def find_only_one_key(lmp_lines, key): def revise_lmp_input_model( - lmp_lines, task_model_list, trj_freq, extra_pair_style_args="", pimd_bead=None, deepmd_version="1" + lmp_lines, + task_model_list, + trj_freq, + extra_pair_style_args="", + pimd_bead=None, + deepmd_version="1", ): idx = find_only_one_key(lmp_lines, ["pair_style", "deepmd"]) if extra_pair_style_args: extra_pair_style_args = " " + extra_pair_style_args graph_list = " ".join(task_model_list) - model_devi_file_name = lmp_pimd_model_devi_name % pimd_bead if pimd_bead is not None else lmp_model_devi_name - lmp_lines[idx] = ( - "pair_style deepmd %s out_freq %d out_file %s%s" - % ( - graph_list, - trj_freq, - model_devi_file_name, - extra_pair_style_args, - ) + model_devi_file_name = ( + lmp_pimd_model_devi_name % pimd_bead + if pimd_bead is not None + else lmp_model_devi_name + ) + lmp_lines[idx] = "pair_style deepmd %s out_freq %d out_file %s%s" % ( + graph_list, + trj_freq, + model_devi_file_name, + extra_pair_style_args, ) return lmp_lines def revise_lmp_input_dump(lmp_lines, trj_freq, pimd_bead=None): idx = find_only_one_key(lmp_lines, ["dump", "dpgen_dump"]) - lmp_traj_file_name = lmp_pimd_traj_name % pimd_bead if pimd_bead is not None else lmp_traj_name + lmp_traj_file_name = ( + lmp_pimd_traj_name % pimd_bead if pimd_bead is not None else lmp_traj_name + ) lmp_lines[idx] = ( f"dump dpgen_dump all custom %d {lmp_traj_file_name} id type x y z" % trj_freq diff --git a/tests/op/test_run_lmp.py b/tests/op/test_run_lmp.py index 1b1ac6cc..5b7f4542 100644 --- a/tests/op/test_run_lmp.py +++ b/tests/op/test_run_lmp.py @@ -294,7 +294,8 @@ class TestMergePIMDFiles(unittest.TestCase): def test_merge_pimd_files(self): for i in range(1, 3): with open("traj.%s.dump" % i, "w") as f: - f.write("""ITEM: TIMESTEP + f.write( + """ITEM: TIMESTEP 0 ITEM: NUMBER OF ATOMS 3 @@ -318,13 +319,16 @@ def test_merge_pimd_files(self): 1 8 7.23103 0.814939 4.59892 2 1 7.96453 0.61699 5.19158 3 1 6.43661 0.370311 5.09854 -""") +""" + ) for i in range(1, 3): with open("model_devi.%s.out" % i, "w") as f: - f.write("""# step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f + f.write( + """# step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f 0 9.023897e-17 3.548771e-17 5.237314e-17 8.196123e-16 1.225653e-16 3.941002e-16 10 1.081667e-16 4.141596e-17 7.534462e-17 9.070597e-16 1.067947e-16 4.153524e-16 -""") +""" + ) merge_pimd_files() self.assertTrue(os.path.exists(lmp_traj_name)) @@ -335,6 +339,13 @@ def test_merge_pimd_files(self): assert model_devi.shape[0] == 4 def tearDown(self): - for f in [lmp_traj_name, "traj.1.dump", "traj.2.dump", lmp_model_devi_name, "model_devi.1.out", "model_devi.2.out"]: + for f in [ + lmp_traj_name, + "traj.1.dump", + "traj.2.dump", + lmp_model_devi_name, + "model_devi.1.out", + "model_devi.2.out", + ]: if os.path.exists(f): os.remove(f)